From 4a7a52c987ec81e4445bfc65d5dde9715bfa3db4 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 4 Feb 2025 08:07:56 -0500 Subject: [PATCH 01/34] Add cat proposition --- src/cat.jl | 75 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) create mode 100644 src/cat.jl diff --git a/src/cat.jl b/src/cat.jl new file mode 100644 index 0000000..7c1375a --- /dev/null +++ b/src/cat.jl @@ -0,0 +1,75 @@ +# separate out in module to have the abliity of defining cat +""" + module Cat + +This module provides a slight modification to the Base implementation of concatenation. +In particular, the biggest hindrance there is that the output is selected based solely on the first input argument. +Here, we remedy this, and along the way leave some more flexible entry points. +Where possible, the final default implementation hooks back into Base, to minimize the required code. + +For users, this implements `cat(!)`, which up to a slight modification of the signature follow their Base counterparts. + +Developers can specialize the behavior and implementations of these functions, +either changing the destination through [`cat_size_shape`](@ref) and [`cat_similar`](@ref), +or the filling procedure via [`copy_or_fill!`](@ref), [`cat_offset1!`](@ref) or [`cat_offset!`](@ref) +""" +module Cat + +public cat, cat! +public cat_size_shape, cat_similar +public cat_offset!, cat_offset1!, copy_or_fill! + +# This is mostly a copy of the Base implementation, with the main difference being +# that the destination is chosen based on all inputs instead of just the first. + +# The entry points for deciding the destination are cat_size_shape and cat_similar(T, shape, args...) + +# Hooking into the actual concatenation machinery can be done in two ways: +# - specializing cat_offset!(dest, shape, catdims, offsets, x) on dest and/or x +# - specializing copy_or_fill!(dest, inds, x) on dest and/or x + +function cat(dims, args...) + T = promote_eltypeof(args...) + catdims = Base.dims2cat(dims) + shape = cat_size_shape(catdims, args...) + dest = cat_similar(T, shape, args...) + if count(!iszero, catdims)::Int > 1 + zero!(dest) + end + return cat!(dest, shape, catdims, args...) +end + +function cat!(dest, shape, catdims, args...) + offsets = ntuple(zero, ndims(dest)) + return cat_offset!(dest, shape, catdims, offsets, args...) +end + +# Write in terms of a generic cat_offset!, which in term aims to specialize on 1 argument +# at a time via cat_offset1! to avoid having to write too many specializations +function cat_offset!(dest, shape, catdims, offsets, x, X...) + dest, newoffsets = cat_offset1!(dest, shape, catdims, offsets, x) + return cat_offset!(dest, shape, newoffsets, X...) +end +cat_offset!(dest, shape, catdims, offsets) = dest + +# this is the typical specialization point, which is no longer vararg. +# it simply computes indices and calls out to copy_or_fill!, so if that +# pattern works you can also overload that function +function cat_offset1!(dest, shape, catdims, offsets, x) + inds = ntuple(length(offests)) do i + (i ≤ length(catdims) && catdims[i]) ? offsets[i] + cat_indices(x, i) : 1:shape[i] + end + copy_or_fill!(dest, inds, x) + newoffsets = ntuple(length(offsets)) do i + (i ≤ length(catdims) && catdims[i]) ? offsets[i] + cat_size(x, i) : offsets[i] + end + return dest, newoffsets +end + +# utility functions, default to their base counterparts but defined here to +# have the option to hook into (promote to public) +copy_or_fill!(dest, inds, x) = Base._copy_or_fill!(dest, inds, x) +cat_size_shape(catdims, args...) = Base.cat_size_shape(catdims, args...) +cat_similar(::Type{T}, shape, args...) = Base.cat_similar(args[1], T, shape) + +end From efd288c28e59c9acab43715a2f541fa52b6937f7 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 4 Feb 2025 08:14:14 -0500 Subject: [PATCH 02/34] Formatter --- src/cat.jl | 51 ++++++++++++++++++++++++++++----------------------- 1 file changed, 28 insertions(+), 23 deletions(-) diff --git a/src/cat.jl b/src/cat.jl index 7c1375a..258985d 100644 --- a/src/cat.jl +++ b/src/cat.jl @@ -15,9 +15,14 @@ or the filling procedure via [`copy_or_fill!`](@ref), [`cat_offset1!`](@ref) or """ module Cat -public cat, cat! -public cat_size_shape, cat_similar -public cat_offset!, cat_offset1!, copy_or_fill! +# this seems to break the formatter? +# public cat +# public cat! +# public cat_size_shape +# public cat_similar +# public cat_offset! +# public cat_offset1! +# public copy_or_fill! # This is mostly a copy of the Base implementation, with the main difference being # that the destination is chosen based on all inputs instead of just the first. @@ -29,26 +34,26 @@ public cat_offset!, cat_offset1!, copy_or_fill! # - specializing copy_or_fill!(dest, inds, x) on dest and/or x function cat(dims, args...) - T = promote_eltypeof(args...) - catdims = Base.dims2cat(dims) - shape = cat_size_shape(catdims, args...) - dest = cat_similar(T, shape, args...) - if count(!iszero, catdims)::Int > 1 - zero!(dest) - end - return cat!(dest, shape, catdims, args...) + T = promote_eltypeof(args...) + catdims = Base.dims2cat(dims) + shape = cat_size_shape(catdims, args...) + dest = cat_similar(T, shape, args...) + if count(!iszero, catdims)::Int > 1 + zero!(dest) + end + return cat!(dest, shape, catdims, args...) end function cat!(dest, shape, catdims, args...) - offsets = ntuple(zero, ndims(dest)) - return cat_offset!(dest, shape, catdims, offsets, args...) + offsets = ntuple(zero, ndims(dest)) + return cat_offset!(dest, shape, catdims, offsets, args...) end # Write in terms of a generic cat_offset!, which in term aims to specialize on 1 argument # at a time via cat_offset1! to avoid having to write too many specializations function cat_offset!(dest, shape, catdims, offsets, x, X...) - dest, newoffsets = cat_offset1!(dest, shape, catdims, offsets, x) - return cat_offset!(dest, shape, newoffsets, X...) + dest, newoffsets = cat_offset1!(dest, shape, catdims, offsets, x) + return cat_offset!(dest, shape, newoffsets, X...) end cat_offset!(dest, shape, catdims, offsets) = dest @@ -56,14 +61,14 @@ cat_offset!(dest, shape, catdims, offsets) = dest # it simply computes indices and calls out to copy_or_fill!, so if that # pattern works you can also overload that function function cat_offset1!(dest, shape, catdims, offsets, x) - inds = ntuple(length(offests)) do i - (i ≤ length(catdims) && catdims[i]) ? offsets[i] + cat_indices(x, i) : 1:shape[i] - end - copy_or_fill!(dest, inds, x) - newoffsets = ntuple(length(offsets)) do i - (i ≤ length(catdims) && catdims[i]) ? offsets[i] + cat_size(x, i) : offsets[i] - end - return dest, newoffsets + inds = ntuple(length(offests)) do i + (i ≤ length(catdims) && catdims[i]) ? offsets[i] + cat_indices(x, i) : 1:shape[i] + end + copy_or_fill!(dest, inds, x) + newoffsets = ntuple(length(offsets)) do i + (i ≤ length(catdims) && catdims[i]) ? offsets[i] + cat_size(x, i) : offsets[i] + end + return dest, newoffsets end # utility functions, default to their base counterparts but defined here to From 399b82f20b2c2a4d0f7072822c27a42861ce2a42 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 11 Feb 2025 11:01:27 -0500 Subject: [PATCH 03/34] Rework cat to include `Concatenated` object --- src/cat.jl | 141 ++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 106 insertions(+), 35 deletions(-) diff --git a/src/cat.jl b/src/cat.jl index 258985d..e4eca6f 100644 --- a/src/cat.jl +++ b/src/cat.jl @@ -1,6 +1,6 @@ # separate out in module to have the abliity of defining cat """ - module Cat + module Concatenate This module provides a slight modification to the Base implementation of concatenation. In particular, the biggest hindrance there is that the output is selected based solely on the first input argument. @@ -13,42 +13,117 @@ Developers can specialize the behavior and implementations of these functions, either changing the destination through [`cat_size_shape`](@ref) and [`cat_similar`](@ref), or the filling procedure via [`copy_or_fill!`](@ref), [`cat_offset1!`](@ref) or [`cat_offset!`](@ref) """ -module Cat - -# this seems to break the formatter? -# public cat -# public cat! -# public cat_size_shape -# public cat_similar -# public cat_offset! -# public cat_offset1! -# public copy_or_fill! - -# This is mostly a copy of the Base implementation, with the main difference being -# that the destination is chosen based on all inputs instead of just the first. - -# The entry points for deciding the destination are cat_size_shape and cat_similar(T, shape, args...) - -# Hooking into the actual concatenation machinery can be done in two ways: -# - specializing cat_offset!(dest, shape, catdims, offsets, x) on dest and/or x -# - specializing copy_or_fill!(dest, inds, x) on dest and/or x - -function cat(dims, args...) - T = promote_eltypeof(args...) - catdims = Base.dims2cat(dims) - shape = cat_size_shape(catdims, args...) - dest = cat_similar(T, shape, args...) - if count(!iszero, catdims)::Int > 1 - zero!(dest) +module Concatenate + +#= +This is mostly a copy of the Base implementation, with the main difference being +that the destination is chosen based on all inputs instead of just the first. + +Additionally, we have an intermediate representation in terms of a Concatenated object, +reminiscent of how Broadcast works. + +The various entry points for specializing behavior are: + +* Destination selection can be achieved through + Base.similar(cat::Concatenated{Interface}, ::Type{T}, axes) where {Interface} + +* Implementation for moving one or more arguments into the destionation through + copy_offset!(dest, shape, catdims, offsets, args...) + copy_offset1!(dest, shape, catdims, offsets, x) + +* Custom implementations: + Base.copy(cat::Concatenated{Interface}) # custom implementation of concatenate + Base.copyto!(dest, cat::Concatenated{Interface}) # custom implementation of concatenate! based on interface + Base.copyto!(dest, cat::Concatenated{Nothing}) # custom implementation of concatenate! based on typeof(dest) +=# + +export concatenate, concatenate! +public Concatenated, cat_offset!, cat_offset1!, copy_or_fill! + +using Base: promote_eltypeof +using .DerivableInterfaces: AbstractInterface, interface + +# TODO: named like this because Catted looks so ugly and Cat is module name? +# cfr Broadcast - Broadcasted +""" + Concatenated{Interface,Dims,Args<:Tuple} + +Lazy representation of the concatenation of various `Args` along `Dims`, in order to provide +hooks to customize the implementation. +""" +struct Concatenated{Interface,Dims,Args<:Tuple} + interface::Interface + dims::Val{Dims} + args::Args + + function Concatenated( + interface::Union{Nothing,AbstractInterface}, dims::Val{Dims}, args::Tuple + ) where {Dims} + return new{typeof(interface),Dims,typeof(args)}(interface, dims, args) + end + function Concatenated(dims, args::Tuple) + return Concatenated(interface(args...), dims, args) + end + function Concatenated{Interface}(dims, args) where {Interface} + return Concatenated(Interface(), dims, args) + end + function Concatenated{Interface,Dims}(args) where {Interface,Dims} + return new{Interface,Dims,typeof(args)}(Interface(), Val(Dims), args) end - return cat!(dest, shape, catdims, args...) end -function cat!(dest, shape, catdims, args...) +dims(::Concatenated{A,D}) where {A,D} = D +DerivableInterfaces.interface(cat::Concatenated) = cat.interface + +concatenated(args...; dims) = Concatenated(args, Val(dims)) + +# allocating the destination container +# ------------------------------------ +Base.similar(cat::Concatenated) = similar(cat, promote_eltypeof(cat.args...)) +Base.similar(cat::Concatenated, ::Type{T}) where {T} = similar(cat, T, axes(cat)) +function Base.similar(cat::Concatenated, ::Type{T}, ax) where {T} + return cat_similar(interface(cat), T, ax) +end + +# For now, simply couple back to base implementation +function Base.axes(cat::Concatenated) + catdims = Base.dims2cat(dims(cat)) + return cat_size_shape(catdims, cat.args...) +end + +cat_size_shape(catdims, args...) = Base.cat_size_shape(catdims, args...) +cat_similar(::Type{T}, shape, args...) = Base.cat_similar(args[1], T, shape) + +# Main logic +# ---------- +concatenate(args...; dims) = Base.materialize(concatenated(dims, args...)) +Base.materialize(cat::Concatenated) = copy(cat) + +function concatenate!(dest, args...; dims) + Base.materialize!(dest, concatenated(dims, args...)) + return dest +end +Base.materialize!(dest, cat::Concatenated) = copyto!(dest, cat) + +Base.copy(cat::Concatenated) = copyto!(similar(cat), cat) + +# default falls back to replacing interface with Nothing +# this permits specializing on typeof(dest) without ambiguities +@inline Base.copyto!(dest, cat::Concatenated) = + copyto!(dest, convert(Concatenated{Nothing}, cat)) + +function Base.copyto!(dest::AbstractArray, cat::Concatenated{Nothing}) + # if concatenation along multiple directions, holes need to be zero. + catdims = Base.dims2cat(dims(cat)) + count(!iszero, catdims)::Int > 1 && zero!(dest) + + shape = cat_size_shape(catdims, cat.args...) offsets = ntuple(zero, ndims(dest)) - return cat_offset!(dest, shape, catdims, offsets, args...) + return cat_offset!(dest, shape, catdims, offsets, cat.args...) end +# Array implementation +# -------------------- # Write in terms of a generic cat_offset!, which in term aims to specialize on 1 argument # at a time via cat_offset1! to avoid having to write too many specializations function cat_offset!(dest, shape, catdims, offsets, x, X...) @@ -71,10 +146,6 @@ function cat_offset1!(dest, shape, catdims, offsets, x) return dest, newoffsets end -# utility functions, default to their base counterparts but defined here to -# have the option to hook into (promote to public) copy_or_fill!(dest, inds, x) = Base._copy_or_fill!(dest, inds, x) -cat_size_shape(catdims, args...) = Base.cat_size_shape(catdims, args...) -cat_similar(::Type{T}, shape, args...) = Base.cat_similar(args[1], T, shape) end From e29b271cfb26291b72285f3e20f4521a480c0949 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 11 Feb 2025 11:10:08 -0500 Subject: [PATCH 04/34] Fix some docstrings --- Project.toml | 2 ++ src/cat.jl | 36 +++++++++++++++++++++--------------- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/Project.toml b/Project.toml index 55cf1b9..7f51a0b 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.3.15" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ExproniconLite = "55351af7-c7e9-48d6-89ff-24e801d99491" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" @@ -15,6 +16,7 @@ TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138" [compat] Adapt = "4.1.1" ArrayLayouts = "1.11.0" +Compat = "3.47,4.10" ExproniconLite = "0.10.13" LinearAlgebra = "1.10" MLStyle = "0.4.17" diff --git a/src/cat.jl b/src/cat.jl index e4eca6f..e323585 100644 --- a/src/cat.jl +++ b/src/cat.jl @@ -2,20 +2,8 @@ """ module Concatenate -This module provides a slight modification to the Base implementation of concatenation. -In particular, the biggest hindrance there is that the output is selected based solely on the first input argument. -Here, we remedy this, and along the way leave some more flexible entry points. -Where possible, the final default implementation hooks back into Base, to minimize the required code. +Alternative implementation for `Base.cat` through [`concatenate(!)`](@ref). -For users, this implements `cat(!)`, which up to a slight modification of the signature follow their Base counterparts. - -Developers can specialize the behavior and implementations of these functions, -either changing the destination through [`cat_size_shape`](@ref) and [`cat_similar`](@ref), -or the filling procedure via [`copy_or_fill!`](@ref), [`cat_offset1!`](@ref) or [`cat_offset!`](@ref) -""" -module Concatenate - -#= This is mostly a copy of the Base implementation, with the main difference being that the destination is chosen based on all inputs instead of just the first. @@ -25,20 +13,26 @@ reminiscent of how Broadcast works. The various entry points for specializing behavior are: * Destination selection can be achieved through + Base.similar(cat::Concatenated{Interface}, ::Type{T}, axes) where {Interface} * Implementation for moving one or more arguments into the destionation through + copy_offset!(dest, shape, catdims, offsets, args...) copy_offset1!(dest, shape, catdims, offsets, x) * Custom implementations: + Base.copy(cat::Concatenated{Interface}) # custom implementation of concatenate Base.copyto!(dest, cat::Concatenated{Interface}) # custom implementation of concatenate! based on interface Base.copyto!(dest, cat::Concatenated{Nothing}) # custom implementation of concatenate! based on typeof(dest) -=# +""" +module Concatenate + +using Compat: @compat export concatenate, concatenate! -public Concatenated, cat_offset!, cat_offset1!, copy_or_fill! +@compat public Concatenated, cat_offset!, cat_offset1!, copy_or_fill! using Base: promote_eltypeof using .DerivableInterfaces: AbstractInterface, interface @@ -96,9 +90,21 @@ cat_similar(::Type{T}, shape, args...) = Base.cat_similar(args[1], T, shape) # Main logic # ---------- +""" + concatenate(args...; dims) + +Concatenate the supplied `args` along dimensions `dims`. + +See also [`concatenate!`](@ref). +""" concatenate(args...; dims) = Base.materialize(concatenated(dims, args...)) Base.materialize(cat::Concatenated) = copy(cat) +""" + concatenate!(dest, args...; dims) + +Concatenate the suppliled `args` along dimensions `dims`, placing the result into `dest`. +""" function concatenate!(dest, args...; dims) Base.materialize!(dest, concatenated(dims, args...)) return dest From 3ad0b84aaf6351b1a25ead45ec3467abcdd89997 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 11 Feb 2025 12:50:57 -0500 Subject: [PATCH 05/34] simplify some functions --- src/cat.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/cat.jl b/src/cat.jl index e323585..8d3e31c 100644 --- a/src/cat.jl +++ b/src/cat.jl @@ -76,7 +76,7 @@ concatenated(args...; dims) = Concatenated(args, Val(dims)) Base.similar(cat::Concatenated) = similar(cat, promote_eltypeof(cat.args...)) Base.similar(cat::Concatenated, ::Type{T}) where {T} = similar(cat, T, axes(cat)) function Base.similar(cat::Concatenated, ::Type{T}, ax) where {T} - return cat_similar(interface(cat), T, ax) + return similar(interface(cat), T, ax) end # For now, simply couple back to base implementation @@ -143,15 +143,16 @@ cat_offset!(dest, shape, catdims, offsets) = dest # pattern works you can also overload that function function cat_offset1!(dest, shape, catdims, offsets, x) inds = ntuple(length(offests)) do i - (i ≤ length(catdims) && catdims[i]) ? offsets[i] + cat_indices(x, i) : 1:shape[i] + (i ≤ length(catdims) && catdims[i]) ? offsets[i] + axes(x, i) : 1:shape[i] end copy_or_fill!(dest, inds, x) newoffsets = ntuple(length(offsets)) do i - (i ≤ length(catdims) && catdims[i]) ? offsets[i] + cat_size(x, i) : offsets[i] + (i ≤ length(catdims) && catdims[i]) ? offsets[i] + size(x, i) : offsets[i] end return dest, newoffsets end copy_or_fill!(dest, inds, x) = Base._copy_or_fill!(dest, inds, x) +zero!(x::AbstractArray) = fill!(x, zero(eltype(x))) end From 13e97a937d9900380aede8391f3bf1866be72ea6 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 11 Feb 2025 12:53:21 -0500 Subject: [PATCH 06/34] remove out-of-date comments --- src/cat.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/cat.jl b/src/cat.jl index 8d3e31c..27cdc03 100644 --- a/src/cat.jl +++ b/src/cat.jl @@ -1,4 +1,3 @@ -# separate out in module to have the abliity of defining cat """ module Concatenate @@ -37,8 +36,6 @@ export concatenate, concatenate! using Base: promote_eltypeof using .DerivableInterfaces: AbstractInterface, interface -# TODO: named like this because Catted looks so ugly and Cat is module name? -# cfr Broadcast - Broadcasted """ Concatenated{Interface,Dims,Args<:Tuple} From 4363c3c3e96b0cd16b87bddd5062a97e5d0f94f5 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 11 Feb 2025 12:55:02 -0500 Subject: [PATCH 07/34] rename and include file --- src/DerivableInterfaces.jl | 3 +++ src/{cat.jl => concatenate.jl} | 0 2 files changed, 3 insertions(+) rename src/{cat.jl => concatenate.jl} (100%) diff --git a/src/DerivableInterfaces.jl b/src/DerivableInterfaces.jl index 42ac21f..06eac51 100644 --- a/src/DerivableInterfaces.jl +++ b/src/DerivableInterfaces.jl @@ -9,4 +9,7 @@ include("abstractarrayinterface.jl") include("defaultarrayinterface.jl") include("traits.jl") +# Specific AbstractArray alternatives +include("concatenate.jl") + end diff --git a/src/cat.jl b/src/concatenate.jl similarity index 100% rename from src/cat.jl rename to src/concatenate.jl From 18898079150744551b8a7acd5291374b882541a8 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 11 Feb 2025 13:07:24 -0500 Subject: [PATCH 08/34] small simplifications --- src/concatenate.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/concatenate.jl b/src/concatenate.jl index 27cdc03..68fc697 100644 --- a/src/concatenate.jl +++ b/src/concatenate.jl @@ -70,21 +70,20 @@ concatenated(args...; dims) = Concatenated(args, Val(dims)) # allocating the destination container # ------------------------------------ -Base.similar(cat::Concatenated) = similar(cat, promote_eltypeof(cat.args...)) +Base.similar(cat::Concatenated) = similar(cat, eltype(cat)) Base.similar(cat::Concatenated, ::Type{T}) where {T} = similar(cat, T, axes(cat)) function Base.similar(cat::Concatenated, ::Type{T}, ax) where {T} return similar(interface(cat), T, ax) end +Base.eltype(cat::Concatenated) = promote_eltypeof(cat.args...) + # For now, simply couple back to base implementation function Base.axes(cat::Concatenated) catdims = Base.dims2cat(dims(cat)) - return cat_size_shape(catdims, cat.args...) + return Base.cat_size_shape(catdims, cat.args...) end -cat_size_shape(catdims, args...) = Base.cat_size_shape(catdims, args...) -cat_similar(::Type{T}, shape, args...) = Base.cat_similar(args[1], T, shape) - # Main logic # ---------- """ From 00c30a2bc4e6755752bc44a8c52be0f47129aea6 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 11 Feb 2025 13:08:33 -0500 Subject: [PATCH 09/34] remove previous implementation --- src/abstractarrayinterface.jl | 99 ----------------------------------- src/traits.jl | 1 - test/SparseArrayDOKs.jl | 2 + 3 files changed, 2 insertions(+), 100 deletions(-) diff --git a/src/abstractarrayinterface.jl b/src/abstractarrayinterface.jl index 7ea82a5..28a5483 100644 --- a/src/abstractarrayinterface.jl +++ b/src/abstractarrayinterface.jl @@ -250,102 +250,3 @@ end ## @interface ::AbstractMatrixInterface function Base.*(a1, a2) ## return ArrayLayouts.mul(a1, a2) ## end - -# Concatenation - -axis_cat(a1::AbstractUnitRange, a2::AbstractUnitRange) = Base.OneTo(length(a1) + length(a2)) -function axis_cat( - a1::AbstractUnitRange, a2::AbstractUnitRange, a_rest::AbstractUnitRange... -) - return axis_cat(axis_cat(a1, a2), a_rest...) -end - -unval(x) = x -unval(::Val{x}) where {x} = x - -function cat_axes(as::AbstractArray...; dims) - return ntuple(length(first(axes.(as)))) do dim - return if dim in unval(dims) - axis_cat(map(axes -> axes[dim], axes.(as))...) - else - axes(first(as))[dim] - end - end -end - -function cat! end - -# Represents concatenating `args` over `dims`. -struct Cat{Args<:Tuple{Vararg{AbstractArray}},dims} - args::Args -end -to_cat_dims(dim::Integer) = Int(dim) -to_cat_dims(dim::Int) = (dim,) -to_cat_dims(dims::Val) = to_cat_dims(unval(dims)) -to_cat_dims(dims::Tuple) = dims -Cat(args::AbstractArray...; dims) = Cat{typeof(args),to_cat_dims(dims)}(args) -cat_dims(::Cat{<:Any,dims}) where {dims} = dims - -function Base.axes(a::Cat) - return cat_axes(a.args...; dims=cat_dims(a)) -end -Base.eltype(a::Cat) = promote_type(eltype.(a.args)...) -function Base.similar(a::Cat) - ax = axes(a) - elt = eltype(a) - # TODO: This drops GPU information, maybe use MemoryLayout? - return similar(arraytype(interface(a.args...), elt), ax) -end - -# https://github.com/JuliaLang/julia/blob/v1.11.1/base/abstractarray.jl#L1748-L1857 -# https://docs.julialang.org/en/v1/base/arrays/#Concatenation-and-permutation -# This is very similar to the `Base.cat` implementation but handles zero values better. -function cat_offset!( - a_dest::AbstractArray, offsets, a1::AbstractArray, a_rest::AbstractArray...; dims -) - inds = ntuple(ndims(a_dest)) do dim - dim in unval(dims) ? offsets[dim] .+ axes(a1, dim) : axes(a_dest, dim) - end - a_dest[inds...] = a1 - new_offsets = ntuple(ndims(a_dest)) do dim - dim in unval(dims) ? offsets[dim] + size(a1, dim) : offsets[dim] - end - cat_offset!(a_dest, new_offsets, a_rest...; dims) - return a_dest -end -function cat_offset!(a_dest::AbstractArray, offsets; dims) - return a_dest -end - -@interface ::AbstractArrayInterface function cat!( - a_dest::AbstractArray, as::AbstractArray...; dims -) - offsets = ntuple(zero, ndims(a_dest)) - # TODO: Fill `a_dest` with zeros if needed using `zero!`. - cat_offset!(a_dest, offsets, as...; dims) - return a_dest -end - -function cat_along(dims, as::AbstractArray...) - return @interface interface(as...) cat_along(dims, as...) -end - -@interface interface::AbstractArrayInterface function cat_along(dims, as::AbstractArray...) - a_dest = similar(Cat(as...; dims)) - @interface interface cat!(a_dest, as...; dims) - return a_dest -end - -@interface interface::AbstractArrayInterface function Base.cat(as::AbstractArray...; dims) - return @interface interface cat_along(dims, as...) -end - -# TODO: Use `@derive` instead: -# ```julia -# @derive (T=AbstractArray,) begin -# cat!(a_dest::AbstractArray, as::T...; dims) -# end -# ``` -function cat!(a_dest::AbstractArray, as::AbstractArray...; dims) - return @interface interface(as...) cat!(a_dest, as...; dims) -end diff --git a/src/traits.jl b/src/traits.jl index a05bd75..ccb743d 100644 --- a/src/traits.jl +++ b/src/traits.jl @@ -45,7 +45,6 @@ function derive(::Val{:AbstractArrayOps}, type) Base.permutedims!(::Any, ::$type, ::Any) Broadcast.BroadcastStyle(::Type{<:$type}) Base.copyto!(::$type, ::Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{0}}) - Base.cat(::$type...; kwargs...) ArrayLayouts.MemoryLayout(::Type{<:$type}) LinearAlgebra.mul!(::AbstractMatrix, ::$type, ::$type, ::Number, ::Number) end diff --git a/test/SparseArrayDOKs.jl b/test/SparseArrayDOKs.jl index bf54893..d4bf3b0 100644 --- a/test/SparseArrayDOKs.jl +++ b/test/SparseArrayDOKs.jl @@ -268,4 +268,6 @@ DerivableInterfaces.interface(::Type{<:SparseArrayDOK}) = SparseArrayInterface() # DerivableInterfaces the interface for the type. @derive AnySparseArrayDOK AbstractArrayOps +Base._cat(dims, args::SparseArrayDOK...) = DerivableInterfaces.concatenate(args...; dims) + end From eff68ce659afde554a0dec396c8bb3cdbe065722 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 11 Feb 2025 13:19:16 -0500 Subject: [PATCH 10/34] remove arraytype in favor of `similar(interface, T, ax)` --- src/abstractarrayinterface.jl | 8 ++------ test/SparseArrayDOKs.jl | 3 ++- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/abstractarrayinterface.jl b/src/abstractarrayinterface.jl index 28a5483..5e5b56b 100644 --- a/src/abstractarrayinterface.jl +++ b/src/abstractarrayinterface.jl @@ -13,9 +13,6 @@ function interface(::Type{<:Broadcast.Broadcasted{<:Style}}) where {Style} return interface(Style) end -# TODO: Define as `Array{T}`. -arraytype(::AbstractArrayInterface, T::Type) = error("Not implemented.") - using ArrayLayouts: ArrayLayouts @interface ::AbstractArrayInterface function Base.getindex(a::AbstractArray, I...) @@ -85,7 +82,7 @@ end @interface interface::AbstractArrayInterface function Base.similar( a::AbstractArray, T::Type, size::Tuple{Vararg{Int}} ) - return similar(arraytype(interface, T), size) + return similar(interface, T, size) end @interface ::AbstractArrayInterface function Base.copy(a::AbstractArray) @@ -105,8 +102,7 @@ end @interface interface::AbstractArrayInterface function Base.similar( bc::Broadcast.Broadcasted, T::Type, axes::Tuple ) - # `arraytype(::AbstractInterface)` determines the default array type associated with the interface. - return similar(arraytype(interface, T), axes) + return similar(interface, T, axes) end using MapBroadcast: Mapped diff --git a/test/SparseArrayDOKs.jl b/test/SparseArrayDOKs.jl index d4bf3b0..ec6c2a9 100644 --- a/test/SparseArrayDOKs.jl +++ b/test/SparseArrayDOKs.jl @@ -70,7 +70,7 @@ DerivableInterfaces.interface(::Type{<:SparseArrayStyle}) = SparseArrayInterface @derive SparseArrayStyle AbstractArrayStyleOps -DerivableInterfaces.arraytype(::SparseArrayInterface, T::Type) = SparseArrayDOK{T} +Base.similar(::SparseArrayInterface, ::Type{T}, ax) where {T} = similar(SparseArrayDOK{T}, ax) # Interface functions. @interface ::SparseArrayInterface function Broadcast.BroadcastStyle(type::Type) @@ -226,6 +226,7 @@ struct SparseArrayDOK{T,N} <: AbstractArray{T,N} end storage(a::SparseArrayDOK) = a.storage Base.size(a::SparseArrayDOK) = a.size +Base.similar(::Type{SparseArrayDOK{T}}, axes) = SparseArrayDOK{T}(undef, axes) function SparseArrayDOK{T}(size::Int...) where {T} N = length(size) return SparseArrayDOK{T,N}(Dict{CartesianIndex{N},T}(), size) From 029391798d906ad442f52d962cffb88bf8af6aea Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 11 Feb 2025 13:46:42 -0500 Subject: [PATCH 11/34] Various fixes --- src/concatenate.jl | 21 +++++++++++++-------- test/SparseArrayDOKs.jl | 3 +-- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/concatenate.jl b/src/concatenate.jl index 68fc697..6bfd684 100644 --- a/src/concatenate.jl +++ b/src/concatenate.jl @@ -34,7 +34,7 @@ export concatenate, concatenate! @compat public Concatenated, cat_offset!, cat_offset1!, copy_or_fill! using Base: promote_eltypeof -using .DerivableInterfaces: AbstractInterface, interface +using ..DerivableInterfaces: DerivableInterfaces, AbstractInterface, interface """ Concatenated{Interface,Dims,Args<:Tuple} @@ -66,7 +66,11 @@ end dims(::Concatenated{A,D}) where {A,D} = D DerivableInterfaces.interface(cat::Concatenated) = cat.interface -concatenated(args...; dims) = Concatenated(args, Val(dims)) +concatenated(args...; dims) = Concatenated(Val(dims), args) + +function Base.convert(::Type{Concatenated{NewInterface}}, cat::Concatenated{<:Any,Dims,Args}) where {NewInterface,Dims,Args} + return Concatenated{NewInterface}(cat.dims, cat.args)::Concatenated{NewInterface,Dims,Args} +end # allocating the destination container # ------------------------------------ @@ -93,7 +97,7 @@ Concatenate the supplied `args` along dimensions `dims`. See also [`concatenate!`](@ref). """ -concatenate(args...; dims) = Base.materialize(concatenated(dims, args...)) +concatenate(args...; dims) = Base.materialize(concatenated(args...; dims)) Base.materialize(cat::Concatenated) = copy(cat) """ @@ -111,7 +115,8 @@ Base.copy(cat::Concatenated) = copyto!(similar(cat), cat) # default falls back to replacing interface with Nothing # this permits specializing on typeof(dest) without ambiguities -@inline Base.copyto!(dest, cat::Concatenated) = +# Note: this needs to be defined for AbstractArray specifically to avoid ambiguities with Base. +@inline Base.copyto!(dest::AbstractArray, cat::Concatenated) = copyto!(dest, convert(Concatenated{Nothing}, cat)) function Base.copyto!(dest::AbstractArray, cat::Concatenated{Nothing}) @@ -119,7 +124,7 @@ function Base.copyto!(dest::AbstractArray, cat::Concatenated{Nothing}) catdims = Base.dims2cat(dims(cat)) count(!iszero, catdims)::Int > 1 && zero!(dest) - shape = cat_size_shape(catdims, cat.args...) + shape = Base.cat_size_shape(catdims, cat.args...) offsets = ntuple(zero, ndims(dest)) return cat_offset!(dest, shape, catdims, offsets, cat.args...) end @@ -130,7 +135,7 @@ end # at a time via cat_offset1! to avoid having to write too many specializations function cat_offset!(dest, shape, catdims, offsets, x, X...) dest, newoffsets = cat_offset1!(dest, shape, catdims, offsets, x) - return cat_offset!(dest, shape, newoffsets, X...) + return cat_offset!(dest, shape, catdims, newoffsets, X...) end cat_offset!(dest, shape, catdims, offsets) = dest @@ -138,8 +143,8 @@ cat_offset!(dest, shape, catdims, offsets) = dest # it simply computes indices and calls out to copy_or_fill!, so if that # pattern works you can also overload that function function cat_offset1!(dest, shape, catdims, offsets, x) - inds = ntuple(length(offests)) do i - (i ≤ length(catdims) && catdims[i]) ? offsets[i] + axes(x, i) : 1:shape[i] + inds = ntuple(length(offsets)) do i + (i ≤ length(catdims) && catdims[i]) ? offsets[i] .+ axes(x, i) : 1:shape[i] end copy_or_fill!(dest, inds, x) newoffsets = ntuple(length(offsets)) do i diff --git a/test/SparseArrayDOKs.jl b/test/SparseArrayDOKs.jl index ec6c2a9..bcf5923 100644 --- a/test/SparseArrayDOKs.jl +++ b/test/SparseArrayDOKs.jl @@ -226,7 +226,6 @@ struct SparseArrayDOK{T,N} <: AbstractArray{T,N} end storage(a::SparseArrayDOK) = a.storage Base.size(a::SparseArrayDOK) = a.size -Base.similar(::Type{SparseArrayDOK{T}}, axes) = SparseArrayDOK{T}(undef, axes) function SparseArrayDOK{T}(size::Int...) where {T} N = length(size) return SparseArrayDOK{T,N}(Dict{CartesianIndex{N},T}(), size) @@ -269,6 +268,6 @@ DerivableInterfaces.interface(::Type{<:SparseArrayDOK}) = SparseArrayInterface() # DerivableInterfaces the interface for the type. @derive AnySparseArrayDOK AbstractArrayOps -Base._cat(dims, args::SparseArrayDOK...) = DerivableInterfaces.concatenate(args...; dims) +Base._cat(dims, args::SparseArrayDOK...) = DerivableInterfaces.Concatenate.concatenate(args...; dims) end From d07403389c8366df04217d2e3cde8e0e05728d46 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 11 Feb 2025 13:49:47 -0500 Subject: [PATCH 12/34] formatter --- src/concatenate.jl | 8 ++++++-- test/SparseArrayDOKs.jl | 8 ++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/concatenate.jl b/src/concatenate.jl index 6bfd684..52fc5df 100644 --- a/src/concatenate.jl +++ b/src/concatenate.jl @@ -68,8 +68,12 @@ DerivableInterfaces.interface(cat::Concatenated) = cat.interface concatenated(args...; dims) = Concatenated(Val(dims), args) -function Base.convert(::Type{Concatenated{NewInterface}}, cat::Concatenated{<:Any,Dims,Args}) where {NewInterface,Dims,Args} - return Concatenated{NewInterface}(cat.dims, cat.args)::Concatenated{NewInterface,Dims,Args} +function Base.convert( + ::Type{Concatenated{NewInterface}}, cat::Concatenated{<:Any,Dims,Args} +) where {NewInterface,Dims,Args} + return Concatenated{NewInterface}( + cat.dims, cat.args + )::Concatenated{NewInterface,Dims,Args} end # allocating the destination container diff --git a/test/SparseArrayDOKs.jl b/test/SparseArrayDOKs.jl index bcf5923..fd8993e 100644 --- a/test/SparseArrayDOKs.jl +++ b/test/SparseArrayDOKs.jl @@ -70,7 +70,9 @@ DerivableInterfaces.interface(::Type{<:SparseArrayStyle}) = SparseArrayInterface @derive SparseArrayStyle AbstractArrayStyleOps -Base.similar(::SparseArrayInterface, ::Type{T}, ax) where {T} = similar(SparseArrayDOK{T}, ax) +function Base.similar(::SparseArrayInterface, ::Type{T}, ax) where {T} + return similar(SparseArrayDOK{T}, ax) +end # Interface functions. @interface ::SparseArrayInterface function Broadcast.BroadcastStyle(type::Type) @@ -268,6 +270,8 @@ DerivableInterfaces.interface(::Type{<:SparseArrayDOK}) = SparseArrayInterface() # DerivableInterfaces the interface for the type. @derive AnySparseArrayDOK AbstractArrayOps -Base._cat(dims, args::SparseArrayDOK...) = DerivableInterfaces.Concatenate.concatenate(args...; dims) +function Base._cat(dims, args::SparseArrayDOK...) + return DerivableInterfaces.Concatenate.concatenate(args...; dims) +end end From 11f20135638ec87e90e3ac30f3c035c605e03d62 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 18 Feb 2025 10:44:38 -0500 Subject: [PATCH 13/34] Avoid `Base._copy_or_fill!` --- src/concatenate.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/concatenate.jl b/src/concatenate.jl index 52fc5df..61571a1 100644 --- a/src/concatenate.jl +++ b/src/concatenate.jl @@ -157,7 +157,10 @@ function cat_offset1!(dest, shape, catdims, offsets, x) return dest, newoffsets end -copy_or_fill!(dest, inds, x) = Base._copy_or_fill!(dest, inds, x) +# copy of Base._copy_or_fill! +copy_or_fill!(A, inds, x) = fill!(view(A, inds...), x) +copy_or_fill!(A, inds, x::AbstractArray) = (A[inds...] = x) + zero!(x::AbstractArray) = fill!(x, zero(eltype(x))) end From 4c70818cba04bd6220936de4d0d151010f4c4d95 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 18 Feb 2025 11:20:17 -0500 Subject: [PATCH 14/34] rename `concatenate(!)` to `cat(!)` --- src/concatenate.jl | 64 ++++++++++++++++++++--------------------- test/SparseArrayDOKs.jl | 2 +- 2 files changed, 32 insertions(+), 34 deletions(-) diff --git a/src/concatenate.jl b/src/concatenate.jl index 61571a1..184a32f 100644 --- a/src/concatenate.jl +++ b/src/concatenate.jl @@ -1,7 +1,7 @@ """ module Concatenate -Alternative implementation for `Base.cat` through [`concatenate(!)`](@ref). +Alternative implementation for `Base.cat` through [`cat(!)`](@ref cat). This is mostly a copy of the Base implementation, with the main difference being that the destination is chosen based on all inputs instead of just the first. @@ -13,7 +13,7 @@ The various entry points for specializing behavior are: * Destination selection can be achieved through - Base.similar(cat::Concatenated{Interface}, ::Type{T}, axes) where {Interface} + Base.similar(concat::Concatenated{Interface}, ::Type{T}, axes) where {Interface} * Implementation for moving one or more arguments into the destionation through @@ -22,15 +22,13 @@ The various entry points for specializing behavior are: * Custom implementations: - Base.copy(cat::Concatenated{Interface}) # custom implementation of concatenate - Base.copyto!(dest, cat::Concatenated{Interface}) # custom implementation of concatenate! based on interface - Base.copyto!(dest, cat::Concatenated{Nothing}) # custom implementation of concatenate! based on typeof(dest) + Base.copy(concat::Concatenated{Interface}) # custom implementation of cat + Base.copyto!(dest, concat::Concatenated{Interface}) # custom implementation of cat! based on interface + Base.copyto!(dest, concat::Concatenated{Nothing}) # custom implementation of cat! based on typeof(dest) """ module Concatenate using Compat: @compat - -export concatenate, concatenate! @compat public Concatenated, cat_offset!, cat_offset1!, copy_or_fill! using Base: promote_eltypeof @@ -64,73 +62,73 @@ struct Concatenated{Interface,Dims,Args<:Tuple} end dims(::Concatenated{A,D}) where {A,D} = D -DerivableInterfaces.interface(cat::Concatenated) = cat.interface +DerivableInterfaces.interface(concat::Concatenated) = concat.interface concatenated(args...; dims) = Concatenated(Val(dims), args) function Base.convert( - ::Type{Concatenated{NewInterface}}, cat::Concatenated{<:Any,Dims,Args} + ::Type{Concatenated{NewInterface}}, concat::Concatenated{<:Any,Dims,Args} ) where {NewInterface,Dims,Args} return Concatenated{NewInterface}( - cat.dims, cat.args + concat.dims, concat.args )::Concatenated{NewInterface,Dims,Args} end # allocating the destination container # ------------------------------------ -Base.similar(cat::Concatenated) = similar(cat, eltype(cat)) -Base.similar(cat::Concatenated, ::Type{T}) where {T} = similar(cat, T, axes(cat)) -function Base.similar(cat::Concatenated, ::Type{T}, ax) where {T} - return similar(interface(cat), T, ax) +Base.similar(concat::Concatenated) = similar(concat, eltype(cat)) +Base.similar(concat::Concatenated, ::Type{T}) where {T} = similar(concat, T, axes(cat)) +function Base.similar(concat::Concatenated, ::Type{T}, ax) where {T} + return similar(interface(concat), T, ax) end -Base.eltype(cat::Concatenated) = promote_eltypeof(cat.args...) +Base.eltype(concat::Concatenated) = promote_eltypeof(concat.args...) # For now, simply couple back to base implementation -function Base.axes(cat::Concatenated) - catdims = Base.dims2cat(dims(cat)) - return Base.cat_size_shape(catdims, cat.args...) +function Base.axes(concat::Concatenated) + catdims = Base.dims2cat(dims(concat)) + return Base.cat_size_shape(catdims, concat.args...) end # Main logic # ---------- """ - concatenate(args...; dims) + Concatenate.cat(args...; dims) Concatenate the supplied `args` along dimensions `dims`. -See also [`concatenate!`](@ref). +See also [`cat!`](@ref). """ -concatenate(args...; dims) = Base.materialize(concatenated(args...; dims)) -Base.materialize(cat::Concatenated) = copy(cat) +cat(args...; dims) = Base.materialize(concatenated(args...; dims)) +Base.materialize(concat::Concatenated) = copy(concat) """ - concatenate!(dest, args...; dims) + Concatenate.cat!(dest, args...; dims) -Concatenate the suppliled `args` along dimensions `dims`, placing the result into `dest`. +Concatenate the supplied `args` along dimensions `dims`, placing the result into `dest`. """ -function concatenate!(dest, args...; dims) +function cat!(dest, args...; dims) Base.materialize!(dest, concatenated(dims, args...)) return dest end -Base.materialize!(dest, cat::Concatenated) = copyto!(dest, cat) +Base.materialize!(dest, concat::Concatenated) = copyto!(dest, concat) -Base.copy(cat::Concatenated) = copyto!(similar(cat), cat) +Base.copy(concat::Concatenated) = copyto!(similar(concat), concat) # default falls back to replacing interface with Nothing # this permits specializing on typeof(dest) without ambiguities # Note: this needs to be defined for AbstractArray specifically to avoid ambiguities with Base. -@inline Base.copyto!(dest::AbstractArray, cat::Concatenated) = - copyto!(dest, convert(Concatenated{Nothing}, cat)) +@inline Base.copyto!(dest::AbstractArray, concat::Concatenated) = + copyto!(dest, convert(Concatenated{Nothing}, concat)) -function Base.copyto!(dest::AbstractArray, cat::Concatenated{Nothing}) +function Base.copyto!(dest::AbstractArray, concat::Concatenated{Nothing}) # if concatenation along multiple directions, holes need to be zero. - catdims = Base.dims2cat(dims(cat)) + catdims = Base.dims2cat(dims(concat)) count(!iszero, catdims)::Int > 1 && zero!(dest) - shape = Base.cat_size_shape(catdims, cat.args...) + shape = Base.cat_size_shape(catdims, concat.args...) offsets = ntuple(zero, ndims(dest)) - return cat_offset!(dest, shape, catdims, offsets, cat.args...) + return cat_offset!(dest, shape, catdims, offsets, concat.args...) end # Array implementation diff --git a/test/SparseArrayDOKs.jl b/test/SparseArrayDOKs.jl index fd8993e..20fadc3 100644 --- a/test/SparseArrayDOKs.jl +++ b/test/SparseArrayDOKs.jl @@ -271,7 +271,7 @@ DerivableInterfaces.interface(::Type{<:SparseArrayDOK}) = SparseArrayInterface() @derive AnySparseArrayDOK AbstractArrayOps function Base._cat(dims, args::SparseArrayDOK...) - return DerivableInterfaces.Concatenate.concatenate(args...; dims) + return DerivableInterfaces.Concatenate.cat(args...; dims) end end From af8c0edc18d96b1d11f366681df9d17e13b52b45 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 18 Feb 2025 11:24:37 -0500 Subject: [PATCH 15/34] Add comment about overloading `Base._cat` --- test/SparseArrayDOKs.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/SparseArrayDOKs.jl b/test/SparseArrayDOKs.jl index 20fadc3..46de38d 100644 --- a/test/SparseArrayDOKs.jl +++ b/test/SparseArrayDOKs.jl @@ -270,6 +270,7 @@ DerivableInterfaces.interface(::Type{<:SparseArrayDOK}) = SparseArrayInterface() # DerivableInterfaces the interface for the type. @derive AnySparseArrayDOK AbstractArrayOps +# avoid overloading `Base.cat` because of method invalidations function Base._cat(dims, args::SparseArrayDOK...) return DerivableInterfaces.Concatenate.cat(args...; dims) end From e4f0d89683538218a9506cc342afe3242d2b9301 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 18 Feb 2025 11:30:35 -0500 Subject: [PATCH 16/34] remove specialization points and simplify --- src/concatenate.jl | 45 ++++++++++++++++++++++----------------------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/src/concatenate.jl b/src/concatenate.jl index 184a32f..a0dc296 100644 --- a/src/concatenate.jl +++ b/src/concatenate.jl @@ -29,7 +29,7 @@ The various entry points for specializing behavior are: module Concatenate using Compat: @compat -@compat public Concatenated, cat_offset!, cat_offset1!, copy_or_fill! +@compat public Concatenated using Base: promote_eltypeof using ..DerivableInterfaces: DerivableInterfaces, AbstractInterface, interface @@ -121,43 +121,42 @@ Base.copy(concat::Concatenated) = copyto!(similar(concat), concat) @inline Base.copyto!(dest::AbstractArray, concat::Concatenated) = copyto!(dest, convert(Concatenated{Nothing}, concat)) +# couple back to Base implementation if no specialization exists: +# https://github.com/JuliaLang/julia/blob/29da86bb983066dd076439c2c7bc5e28dbd611bb/base/abstractarray.jl#L1852 function Base.copyto!(dest::AbstractArray, concat::Concatenated{Nothing}) - # if concatenation along multiple directions, holes need to be zero. catdims = Base.dims2cat(dims(concat)) - count(!iszero, catdims)::Int > 1 && zero!(dest) - shape = Base.cat_size_shape(catdims, concat.args...) - offsets = ntuple(zero, ndims(dest)) - return cat_offset!(dest, shape, catdims, offsets, concat.args...) + count(!iszero, catdims)::Int > 1 && zero!(dest) + return Base.__cat(dest, shape, catdims, concat.args...) end # Array implementation # -------------------- # Write in terms of a generic cat_offset!, which in term aims to specialize on 1 argument # at a time via cat_offset1! to avoid having to write too many specializations -function cat_offset!(dest, shape, catdims, offsets, x, X...) - dest, newoffsets = cat_offset1!(dest, shape, catdims, offsets, x) - return cat_offset!(dest, shape, catdims, newoffsets, X...) -end -cat_offset!(dest, shape, catdims, offsets) = dest +# function cat_offset!(dest, shape, catdims, offsets, x, X...) +# dest, newoffsets = cat_offset1!(dest, shape, catdims, offsets, x) +# return cat_offset!(dest, shape, catdims, newoffsets, X...) +# end +# cat_offset!(dest, shape, catdims, offsets) = dest # this is the typical specialization point, which is no longer vararg. # it simply computes indices and calls out to copy_or_fill!, so if that # pattern works you can also overload that function -function cat_offset1!(dest, shape, catdims, offsets, x) - inds = ntuple(length(offsets)) do i - (i ≤ length(catdims) && catdims[i]) ? offsets[i] .+ axes(x, i) : 1:shape[i] - end - copy_or_fill!(dest, inds, x) - newoffsets = ntuple(length(offsets)) do i - (i ≤ length(catdims) && catdims[i]) ? offsets[i] + size(x, i) : offsets[i] - end - return dest, newoffsets -end +# function cat_offset1!(dest, shape, catdims, offsets, x) +# inds = ntuple(length(offsets)) do i +# (i ≤ length(catdims) && catdims[i]) ? offsets[i] .+ axes(x, i) : 1:shape[i] +# end +# copy_or_fill!(dest, inds, x) +# newoffsets = ntuple(length(offsets)) do i +# (i ≤ length(catdims) && catdims[i]) ? offsets[i] + size(x, i) : offsets[i] +# end +# return dest, newoffsets +# end # copy of Base._copy_or_fill! -copy_or_fill!(A, inds, x) = fill!(view(A, inds...), x) -copy_or_fill!(A, inds, x::AbstractArray) = (A[inds...] = x) +# copy_or_fill!(A, inds, x) = fill!(view(A, inds...), x) +# copy_or_fill!(A, inds, x::AbstractArray) = (A[inds...] = x) zero!(x::AbstractArray) = fill!(x, zero(eltype(x))) From db1054e666d5e76416ef2cc980abdc228f066457 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 18 Feb 2025 07:49:59 -0500 Subject: [PATCH 17/34] remove traits --- src/DerivableInterfaces.jl | 1 - src/derive_macro.jl | 43 ---------------------------- src/traits.jl | 58 -------------------------------------- 3 files changed, 102 deletions(-) delete mode 100644 src/traits.jl diff --git a/src/DerivableInterfaces.jl b/src/DerivableInterfaces.jl index 06eac51..1a254d5 100644 --- a/src/DerivableInterfaces.jl +++ b/src/DerivableInterfaces.jl @@ -7,7 +7,6 @@ include("interface_macro.jl") include("wrappedarrays.jl") include("abstractarrayinterface.jl") include("defaultarrayinterface.jl") -include("traits.jl") # Specific AbstractArray alternatives include("concatenate.jl") diff --git a/src/derive_macro.jl b/src/derive_macro.jl index 0d17527..8579b6f 100644 --- a/src/derive_macro.jl +++ b/src/derive_macro.jl @@ -68,26 +68,6 @@ function derive_expr(interface::Union{Symbol,Expr}, types::Expr, funcs::Expr) end end -#== -```julia -@derive SparseArrayDOK AbstractArrayOps -``` -==# -function derive_expr(type::Union{Symbol,Expr}, trait::Symbol) - return derive_trait(type, trait) -end - -#== -```julia -@derive SparseArrayInterface() SparseArrayDOK AbstractArrayOps -``` -==# -function derive_expr( - interface::Union{Symbol,Expr}, types::Union{Symbol,Expr}, trait::Symbol -) - return derive_trait(interface, types, trait) -end - function derive_funcs(args...) interface_and_or_types = Base.front(args) funcs = last(args) @@ -217,26 +197,3 @@ function derive_interface_func(interface::Union{Symbol,Expr}, func::Expr) # namespace when `@derive` is called. return globalref_derive(codegen_ast(jlfn)) end - -#= -```julia -@derive SparseArrayInterface() SparseArrayDOK AbstractArrayOps -``` -=# -function derive_trait( - interface::Union{Symbol,Expr}, type::Union{Symbol,Expr}, trait::Symbol -) - funcs = Expr(:block, derive(Val(trait), type).args...) - return derive_funcs(interface, funcs) -end - -#= -```julia -@derive SparseArrayDOK AbstractArrayOps -``` -=# -function derive_trait(type::Union{Symbol,Expr}, trait::Symbol) - types = :((T=$type,)) - funcs = Expr(:block, derive(Val(trait), :T).args...) - return derive_funcs(types, funcs) -end diff --git a/src/traits.jl b/src/traits.jl deleted file mode 100644 index ccb743d..0000000 --- a/src/traits.jl +++ /dev/null @@ -1,58 +0,0 @@ -using ArrayLayouts: ArrayLayouts -using LinearAlgebra: LinearAlgebra - -# TODO: Create a macro: -#= -``` -@derive_def AbstractArrayOps T begin - Base.getindex(::T, ::Any...) - Base.getindex(::T, ::Int...) - Base.setindex!(::T, ::Any, ::Int...) - Base.similar(::T, ::Type, ::Tuple{Vararg{Int}}) -end -``` -=# -# TODO: Define an `AbstractMatrixOps` trait, which is where -# matrix multiplication should be defined (both `mul!` and `*`). -#= -```julia -@derive SparseArrayDOK AbstractArrayOps -@derive SparseArrayInterface SparseArrayDOK AbstractArrayOps -``` -=# -function derive(::Val{:AbstractArrayOps}, type) - return quote - Base.getindex(::$type, ::Any...) - Base.getindex(::$type, ::Int...) - Base.setindex!(::$type, ::Any, ::Any...) - Base.setindex!(::$type, ::Any, ::Int...) - Base.similar(::$type, ::Type, ::Tuple{Vararg{Int}}) - Base.similar(::$type, ::Type, ::Tuple{Base.OneTo,Vararg{Base.OneTo}}) - Base.copy(::$type) - Base.copy!(::AbstractArray, ::$type) - Base.copyto!(::AbstractArray, ::$type) - Base.map(::Any, ::$type...) - Base.map!(::Any, ::AbstractArray, ::$type...) - Base.mapreduce(::Any, ::Any, ::$type...; kwargs...) - Base.reduce(::Any, ::$type...; kwargs...) - Base.all(::Function, ::$type) - Base.all(::$type) - Base.iszero(::$type) - Base.real(::$type) - Base.fill!(::$type, ::Any) - ArrayLayouts.zero!(::$type) - Base.zero(::$type) - Base.permutedims!(::Any, ::$type, ::Any) - Broadcast.BroadcastStyle(::Type{<:$type}) - Base.copyto!(::$type, ::Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{0}}) - ArrayLayouts.MemoryLayout(::Type{<:$type}) - LinearAlgebra.mul!(::AbstractMatrix, ::$type, ::$type, ::Number, ::Number) - end -end - -function derive(::Val{:AbstractArrayStyleOps}, type) - return quote - Base.similar(::Broadcast.Broadcasted{<:$type}, ::Type, ::Tuple) - Base.copyto!(::AbstractArray, ::Broadcast.Broadcasted{<:$type}) - end -end From 3650f1d3b200d384b058965d8c89382f5ff8a307 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 18 Feb 2025 07:56:48 -0500 Subject: [PATCH 18/34] Refactor overdubbing --- src/interface_function.jl | 35 +++++++++++++++++++++++++---------- 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/src/interface_function.jl b/src/interface_function.jl index 0ebde58..fce3a4f 100644 --- a/src/interface_function.jl +++ b/src/interface_function.jl @@ -1,15 +1,30 @@ -#= -Rewrite `f(args...)` to `DerivableInterfaces.call(interface, f, args...)`. -Similar to `Cassette.overdub`. +# noinline trick to make compiler avoid allocating a string +@noinline _warn_no_impl(interface, f, args) = + "The function `$f` does not have a `$interface` implementation for arguments of type `$(typeof(args))`" -This errors for debugging, but probably should be defined as: -```julia -call(interface, f, args...) = f(args...) -``` -=# -call(interface, f, args...; kwargs...) = error("Not implemented") +""" + call(interface, f, args...; kwargs...) -# Change the behavior of a function to use a certain interface. +Call the overdubbed function implementing `f(args...; kwargs...)` for a given interface. + +See also [`@interface`](@ref). +""" +function call(interface, f, args...; kwargs...) + @warn _warn_no_impl(interface, f, args) maxlog = 1 + return f(args...; kwargs...) +end + +""" + struct InterfaceFunction{I,F} <: Function + +Callable struct to overdub a function `f::F` with a custom implementation based on +an interface `interface::I`. + +## Fields + +- `interface::I`: interface struct +- `f::F`: function to overdub +""" struct InterfaceFunction{Interface,F} <: Function interface::Interface f::F From a93b7d7eaa04606ede6e1db3d403de0aab563e32 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 18 Feb 2025 12:00:15 -0500 Subject: [PATCH 19/34] refactor AbstractArrayInterface and DefaultArrayInterface --- src/DerivableInterfaces.jl | 3 +- src/abstractarrayinterface.jl | 248 ---------------------------------- src/abstractinterface.jl | 1 + src/arrayinterface.jl | 32 +++++ src/defaultarrayinterface.jl | 32 ----- 5 files changed, 34 insertions(+), 282 deletions(-) delete mode 100644 src/abstractarrayinterface.jl create mode 100644 src/arrayinterface.jl delete mode 100644 src/defaultarrayinterface.jl diff --git a/src/DerivableInterfaces.jl b/src/DerivableInterfaces.jl index 1a254d5..08e1c6b 100644 --- a/src/DerivableInterfaces.jl +++ b/src/DerivableInterfaces.jl @@ -5,8 +5,7 @@ include("abstractinterface.jl") include("derive_macro.jl") include("interface_macro.jl") include("wrappedarrays.jl") -include("abstractarrayinterface.jl") -include("defaultarrayinterface.jl") +include("arrayinterface.jl") # Specific AbstractArray alternatives include("concatenate.jl") diff --git a/src/abstractarrayinterface.jl b/src/abstractarrayinterface.jl deleted file mode 100644 index 5e5b56b..0000000 --- a/src/abstractarrayinterface.jl +++ /dev/null @@ -1,248 +0,0 @@ -# TODO: Add `ndims` type parameter. -abstract type AbstractArrayInterface <: AbstractInterface end - -function interface(::Type{<:Broadcast.AbstractArrayStyle}) - return DefaultArrayInterface() -end - -function interface(::Type{<:Broadcast.Broadcasted{Nothing}}) - return DefaultArrayInterface() -end - -function interface(::Type{<:Broadcast.Broadcasted{<:Style}}) where {Style} - return interface(Style) -end - -using ArrayLayouts: ArrayLayouts - -@interface ::AbstractArrayInterface function Base.getindex(a::AbstractArray, I...) - return ArrayLayouts.layout_getindex(a, I...) -end - -@interface interface::AbstractArrayInterface function Base.setindex!( - a::AbstractArray, value, I... -) - # TODO: Change to this once broadcasting in `@interface` is supported: - # @interface interface a[I...] .= value - @interface interface map!(identity, @view(a[I...]), value) - return a -end - -# TODO: Maybe define as `ArrayLayouts.layout_getindex(a, I...)` or -# `invoke(getindex, Tuple{AbstractArray,Vararg{Any}}, a, I...)`. -# TODO: Use `MethodError`? -@interface ::AbstractArrayInterface function Base.getindex( - a::AbstractArray{<:Any,N}, I::Vararg{Int,N} -) where {N} - return error("Not implemented.") -end - -# TODO: Make this more general, use `Base.to_index`. -@interface interface::AbstractArrayInterface function Base.getindex( - a::AbstractArray{<:Any,N}, I::CartesianIndex{N} -) where {N} - return @interface interface getindex(a, Tuple(I)...) -end - -# Linear indexing. -@interface interface ::AbstractArrayInterface function Base.getindex( - a::AbstractArray, I::Int -) - return @interface interface getindex(a, CartesianIndices(a)[I]) -end - -# TODO: Use `MethodError`? -@interface ::AbstractArrayInterface function Base.setindex!( - a::AbstractArray{<:Any,N}, value, I::Vararg{Int,N} -) where {N} - return error("Not implemented.") -end - -# Linear indexing. -@interface interface ::AbstractArrayInterface function Base.setindex!( - a::AbstractArray, value, I::Int -) - return @interface interface setindex!(a, value, CartesianIndices(a)[I]) -end - -# TODO: Make this more general, use `Base.to_index`. -@interface interface::AbstractArrayInterface function Base.setindex!( - a::AbstractArray{<:Any,N}, value, I::CartesianIndex{N} -) where {N} - return @interface interface setindex!(a, value, Tuple(I)...) -end - -@interface ::AbstractArrayInterface function Broadcast.BroadcastStyle(type::Type) - return Broadcast.DefaultArrayStyle{ndims(type)}() -end - -# TODO: Maybe define as `Array{T}(undef, size...)` or -# `invoke(Base.similar, Tuple{AbstractArray,Type,Vararg{Int}}, a, T, size)`. -# TODO: Use `MethodError`? -@interface interface::AbstractArrayInterface function Base.similar( - a::AbstractArray, T::Type, size::Tuple{Vararg{Int}} -) - return similar(interface, T, size) -end - -@interface ::AbstractArrayInterface function Base.copy(a::AbstractArray) - a_dest = similar(a) - return a_dest .= a -end - -# TODO: Use `Base.to_shape(axes)` or -# `Base.invoke(similar, Tuple{AbstractArray,Type,Tuple{Union{Integer,Base.OneTo},Vararg{Union{Integer,Base.OneTo}}}}, a, T, axes)`. -# TODO: Make this more general, handle mixtures of integers and ranges (`Union{Integer,Base.OneTo}`). -@interface interface::AbstractArrayInterface function Base.similar( - a::AbstractArray, T::Type, axes::Tuple{Base.OneTo,Vararg{Base.OneTo}} -) - return @interface interface similar(a, T, Base.to_shape(axes)) -end - -@interface interface::AbstractArrayInterface function Base.similar( - bc::Broadcast.Broadcasted, T::Type, axes::Tuple -) - return similar(interface, T, axes) -end - -using MapBroadcast: Mapped -# TODO: Turn this into an `@interface AbstractArrayInterface` function? -# TODO: Look into `SparseArrays.capturescalars`: -# https://github.com/JuliaSparse/SparseArrays.jl/blob/1beb0e4a4618b0399907b0000c43d9f66d34accc/src/higherorderfns.jl#L1092-L1102 -@interface interface::AbstractArrayInterface function Base.copyto!( - a_dest::AbstractArray, bc::Broadcast.Broadcasted -) - m = Mapped(bc) - return @interface interface map!(m.f, a_dest, m.args...) -end - -# This captures broadcast expressions such as `a .= 2`. -# Ideally this would be handled by `map!(f, a_dest)` but that isn't defined yet: -# https://github.com/JuliaLang/julia/issues/31677 -# https://github.com/JuliaLang/julia/pull/40632 -@interface interface::AbstractArrayInterface function Base.copyto!( - a_dest::AbstractArray, bc::Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{0}} -) - @interface interface fill!(a_dest, bc.f(bc.args...)[]) -end - -# This is defined in this way so we can rely on the Broadcast logic -# for determining the destination of the operation (element type, shape, etc.). -@interface ::AbstractArrayInterface function Base.map(f, as::AbstractArray...) - # TODO: Should this be `@interface interface ...`? That doesn't support - # broadcasting yet. - # Broadcasting is used here to determine the destination array but that - # could be done manually here. - return f.(as...) -end - -# TODO: Maybe define as -# `invoke(Base.map!, Tuple{Any,AbstractArray,Vararg{AbstractArray}}, f, dest, as...)`. -# TODO: Use `MethodError`? -@interface ::AbstractArrayInterface function Base.map!( - f, a_dest::AbstractArray, a_srcs::AbstractArray... -) - return error("Not implemented.") -end - -@interface interface::AbstractArrayInterface function Base.fill!(a::AbstractArray, value) - @interface interface map!(Returns(value), a, a) -end - -using ArrayLayouts: zero! - -# `zero!` isn't defined in `Base`, but it is defined in `ArrayLayouts` -# and is useful for sparse array logic, since it can be used to empty -# the sparse array storage. -# We use a single function definition to minimize method ambiguities. -@interface interface::AbstractArrayInterface function ArrayLayouts.zero!(a::AbstractArray) - # More generally, the first codepath could be taking if `zero(eltype(a))` - # is defined and the elements are immutable. - f = eltype(a) <: Number ? Returns(zero(eltype(a))) : zero! - return @interface interface map!(f, a, a) -end - -# Specialized version of `Base.zero` written in terms of `ArrayLayouts.zero!`. -# This is friendlier for sparse arrays since `ArrayLayouts.zero!` makes it easier -# to handle the logic of dropping all elements of the sparse array when possible. -# We use a single function definition to minimize method ambiguities. -@interface interface::AbstractArrayInterface function Base.zero(a::AbstractArray) - # More generally, the first codepath could be taking if `zero(eltype(a))` - # is defined and the elements are immutable. - if eltype(a) <: Number - return @interface interface zero!(similar(a)) - end - return @interface interface map(interface(zero), a) -end - -@interface ::AbstractArrayInterface function Base.mapreduce( - f, op, as::AbstractArray...; kwargs... -) - return error("Not implemented.") -end - -# TODO: Generalize to multiple inputs. -@interface interface::AbstractInterface function Base.reduce(f, a::AbstractArray; kwargs...) - return @interface interface mapreduce(identity, f, a; kwargs...) -end - -@interface interface::AbstractArrayInterface function Base.all(a::AbstractArray) - return @interface interface reduce(&, a; init=true) -end - -@interface interface::AbstractArrayInterface function Base.all( - f::Function, a::AbstractArray -) - return @interface interface mapreduce(f, &, a; init=true) -end - -@interface interface::AbstractArrayInterface function Base.iszero(a::AbstractArray) - return @interface interface all(iszero, a) -end - -@interface interface::AbstractArrayInterface function Base.isreal(a::AbstractArray) - return @interface interface all(isreal, a) -end - -@interface interface::AbstractArrayInterface function Base.permutedims!( - a_dest::AbstractArray, a_src::AbstractArray, perm -) - return @interface interface map!(identity, a_dest, PermutedDimsArray(a_src, perm)) -end - -@interface interface::AbstractArrayInterface function Base.copyto!( - a_dest::AbstractArray, a_src::AbstractArray -) - return @interface interface map!(identity, a_dest, a_src) -end - -@interface interface::AbstractArrayInterface function Base.copy!( - a_dest::AbstractArray, a_src::AbstractArray -) - return @interface interface map!(identity, a_dest, a_src) -end - -using LinearAlgebra: LinearAlgebra -# This then requires overloading: -# function ArrayLayouts.materialize!( -# m::MatMulMatAdd{<:AbstractSparseLayout,<:AbstractSparseLayout,<:AbstractSparseLayout} -# ) -# # Matmul implementation. -# end -@interface ::AbstractArrayInterface function LinearAlgebra.mul!( - a_dest::AbstractVecOrMat, a1::AbstractVecOrMat, a2::AbstractVecOrMat, α::Number, β::Number -) - return ArrayLayouts.mul!(a_dest, a1, a2, α, β) -end - -@interface ::AbstractArrayInterface function ArrayLayouts.MemoryLayout(type::Type) - # TODO: Define as `UnknownLayout()`? - # TODO: Use `MethodError`? - return error("Not implemented.") -end - -## TODO: Define `const AbstractMatrixInterface = AbstractArrayInterface{2}`, -## requires adding `ndims` type parameter to `AbstractArrayInterface`. -## @interface ::AbstractMatrixInterface function Base.*(a1, a2) -## return ArrayLayouts.mul(a1, a2) -## end diff --git a/src/abstractinterface.jl b/src/abstractinterface.jl index 1cee893..65c2790 100644 --- a/src/abstractinterface.jl +++ b/src/abstractinterface.jl @@ -6,6 +6,7 @@ interface(x1, x_rest...) = combine_interfaces(x1, x_rest...) # Adapted from `Base.Broadcast.combine_styles`. # Get the combined interfaces of the input objects. +# TODO: make rule definitions symmetric function combine_interfaces(x1, x2, x_rest...) return combine_interfaces(combine_interfaces(x1, x2), x_rest...) end diff --git a/src/arrayinterface.jl b/src/arrayinterface.jl new file mode 100644 index 0000000..9b41da6 --- /dev/null +++ b/src/arrayinterface.jl @@ -0,0 +1,32 @@ +""" +`AbstractArrayInterface{N} <: AbstractInterface` is the abstract supertype for any interface +associated with an `AbstractArray` type. +The `N` parameter is the dimensionality, which can be handy for array types that only support +specific dimensionalities. +""" +abstract type AbstractArrayInterface{N} <: AbstractInterface end + +""" +`DefaultArrayInterface{N}()` is the interface indicating that an object behaves as an `N`-dimensional +array, but hasn't defined a specialized interface. In the absence of overrides from other +`AbstractArrayInterface` arguments, this results in non-overdubbed function calls. +""" +struct DefaultArrayInterface{N} <: AbstractArrayInterface{N} end + +# avoid emitting warnings in fallback `call` definition +call(::DefaultArrayInterface, f, args...; kwargs...) = f(args...; kwargs...) + +using TypeParameterAccessors: parenttype +# attempt to figure out interface type from parent +function interface(::Type{A}) where {A<:AbstractArray} + pA = parenttype(A) + return pA === A ? DefaultArrayInterface{ndims(A)}() : interface(pA) +end + +function interface(::Type{B}) where {B<:Broadcast.AbstractArrayStyle} + return DefaultArrayInterface{ndims(B)}() +end + +# Combination rules +combine_interface_rule(::DefaultArrayInterface, I::AbstractArrayInterface) = I +combine_interface_rule(I::AbstractArrayInterface, ::DefaultArrayInterface) = I diff --git a/src/defaultarrayinterface.jl b/src/defaultarrayinterface.jl deleted file mode 100644 index bc2f9b6..0000000 --- a/src/defaultarrayinterface.jl +++ /dev/null @@ -1,32 +0,0 @@ -# TODO: Add `ndims` type parameter. -struct DefaultArrayInterface <: AbstractArrayInterface end - -using TypeParameterAccessors: parenttype -function interface(a::Type{<:AbstractArray}) - parenttype(a) === a && return DefaultArrayInterface() - return interface(parenttype(a)) -end - -@interface ::DefaultArrayInterface function Base.getindex( - a::AbstractArray{<:Any,N}, I::Vararg{Int,N} -) where {N} - return Base.getindex(a, I...) -end - -@interface ::DefaultArrayInterface function Base.setindex!( - a::AbstractArray{<:Any,N}, value, I::Vararg{Int,N} -) where {N} - return Base.setindex!(a, value, I...) -end - -@interface ::DefaultArrayInterface function Base.map!( - f, a_dest::AbstractArray, a_srcs::AbstractArray... -) - return Base.map!(f, a_dest, a_srcs...) -end - -@interface ::DefaultArrayInterface function Base.mapreduce( - f, op, as::AbstractArray...; kwargs... -) - return Base.mapreduce(f, op, as...; kwargs...) -end From 219c7527c917f0af112705e0f6ec8ede1bb2a127 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 18 Feb 2025 12:22:14 -0500 Subject: [PATCH 20/34] Re-enable aqua tests --- test/test_aqua.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_aqua.jl b/test/test_aqua.jl index a01787c..1798cfd 100644 --- a/test/test_aqua.jl +++ b/test/test_aqua.jl @@ -3,5 +3,5 @@ using Aqua: Aqua using Test: @testset @testset "Code quality (Aqua.jl)" begin - # Aqua.test_all(DerivableInterfaces) + Aqua.test_all(DerivableInterfaces) end From e0787103fc31499d39220d5f809f9fd060c14d35 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 18 Feb 2025 12:22:21 -0500 Subject: [PATCH 21/34] remove stale dependency --- Project.toml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/Project.toml b/Project.toml index 7f51a0b..d1ca535 100644 --- a/Project.toml +++ b/Project.toml @@ -5,21 +5,17 @@ version = "0.3.15" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ExproniconLite = "55351af7-c7e9-48d6-89ff-24e801d99491" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" -MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261" TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138" [compat] Adapt = "4.1.1" -ArrayLayouts = "1.11.0" Compat = "3.47,4.10" ExproniconLite = "0.10.13" LinearAlgebra = "1.10" MLStyle = "0.4.17" -MapBroadcast = "0.1.5" TypeParameterAccessors = "0.2, 0.3" julia = "1.10" From ac8bf842c4f89b1901f29ca24b3ab3e1dc51c3e0 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 18 Feb 2025 12:22:40 -0500 Subject: [PATCH 22/34] fix method ambiguity --- src/arrayinterface.jl | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/src/arrayinterface.jl b/src/arrayinterface.jl index 9b41da6..b234e81 100644 --- a/src/arrayinterface.jl +++ b/src/arrayinterface.jl @@ -28,5 +28,18 @@ function interface(::Type{B}) where {B<:Broadcast.AbstractArrayStyle} end # Combination rules -combine_interface_rule(::DefaultArrayInterface, I::AbstractArrayInterface) = I -combine_interface_rule(I::AbstractArrayInterface, ::DefaultArrayInterface) = I +function combine_interface_rule( + ::DefaultArrayInterface{N}, I::AbstractArrayInterface{N} +) where {N} + return I +end +function combine_interface_rule( + I::AbstractArrayInterface{N}, ::DefaultArrayInterface{N} +) where {N} + return I +end +function combine_interface_rule( + ::DefaultArrayInterface{N}, ::DefaultArrayInterface{N} +) where {N} + return DefaultArrayInterface{N}() +end From a9be4c254155de9397c3996e0eb012e4b0d89168 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 18 Feb 2025 13:59:47 -0500 Subject: [PATCH 23/34] Fix `@derive (T=...,) f(x)` without module --- src/derive_macro.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/derive_macro.jl b/src/derive_macro.jl index 8579b6f..eda8a5b 100644 --- a/src/derive_macro.jl +++ b/src/derive_macro.jl @@ -185,9 +185,7 @@ function derive_interface_func(interface::Union{Symbol,Expr}, func::Expr) # TODO: Use the `@interface` macro rather than `DerivableInterfaces.call` # directly, in case we want to change the implementation. body_args = [interface; name; body_args...] - body_name = @match name begin - :($M.$f) => :(DerivableInterfaces.call) - end + body_name = :(DerivableInterfaces.call) # TODO: Remove defaults from `kwargs`. _, body, _ = split_function( codegen_ast(JLFunction(; name=body_name, args=body_args, kwargs)) From 79f9a82b85b70ec211d6ad4a35c4770e61052c88 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 18 Feb 2025 13:59:55 -0500 Subject: [PATCH 24/34] refactor tests --- test/SparseArrayDOKs.jl | 278 ---------------------------------------- test/test_basics.jl | 203 ++++++++--------------------- 2 files changed, 57 insertions(+), 424 deletions(-) delete mode 100644 test/SparseArrayDOKs.jl diff --git a/test/SparseArrayDOKs.jl b/test/SparseArrayDOKs.jl deleted file mode 100644 index 46de38d..0000000 --- a/test/SparseArrayDOKs.jl +++ /dev/null @@ -1,278 +0,0 @@ -module SparseArrayDOKs - -isstored(a::AbstractArray, I::CartesianIndex) = isstored(a, Tuple(I)...) -getstoredindex(a::AbstractArray, I::CartesianIndex) = getstoredindex(a, Tuple(I)...) -getunstoredindex(a::AbstractArray, I::CartesianIndex) = getunstoredindex(a, Tuple(I)...) -function setstoredindex!(a::AbstractArray, value, I::CartesianIndex) - return setstoredindex!(a, value, Tuple(I)...) -end -function setunstoredindex!(a::AbstractArray, value, I::CartesianIndex) - return setunstoredindex!(a, value, Tuple(I)...) -end - -# A view of the stored values of an array. -# Similar to: `@view a[collect(eachstoredindex(a))]`, but the issue -# with that is it returns a `SubArray` wrapping a sparse array, which -# is then interpreted as a sparse array. Also, that involves extra -# logic for determining if the indices are stored or not, but we know -# the indices are stored. -struct StoredValues{T,A<:AbstractArray{T},I} <: AbstractVector{T} - array::A - storedindices::I -end -StoredValues(a::AbstractArray) = StoredValues(a, collect(eachstoredindex(a))) -Base.size(a::StoredValues) = size(a.storedindices) -Base.getindex(a::StoredValues, I::Int) = getstoredindex(a.array, a.storedindices[I]) -function Base.setindex!(a::StoredValues, value, I::Int) - return setstoredindex!(a.array, value, a.storedindices[I]) -end - -storedvalues(a::AbstractArray) = StoredValues(a) - -using ArrayLayouts: ArrayLayouts, MatMulMatAdd, MemoryLayout -using DerivableInterfaces: - DerivableInterfaces, - @array_aliases, - @derive, - @interface, - AbstractArrayInterface, - interface -using LinearAlgebra: LinearAlgebra - -# Define an interface. -struct SparseArrayInterface <: AbstractArrayInterface end - -# Define interface functions. -@interface ::SparseArrayInterface function Base.getindex( - a::AbstractArray{<:Any,N}, I::Vararg{Int,N} -) where {N} - checkbounds(a, I...) - !isstored(a, I...) && return getunstoredindex(a, I...) - return getstoredindex(a, I...) -end -@interface ::SparseArrayInterface function Base.setindex!( - a::AbstractArray{<:Any,N}, value, I::Vararg{Int,N} -) where {N} - checkbounds(a, I...) - if !isstored(a, I...) - iszero(value) && return a - setunstoredindex!(a, value, I...) - return a - end - setstoredindex!(a, value, I...) - return a -end - -struct SparseArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end -SparseArrayStyle{M}(::Val{N}) where {M,N} = SparseArrayStyle{N}() - -DerivableInterfaces.interface(::Type{<:SparseArrayStyle}) = SparseArrayInterface() - -@derive SparseArrayStyle AbstractArrayStyleOps - -function Base.similar(::SparseArrayInterface, ::Type{T}, ax) where {T} - return similar(SparseArrayDOK{T}, ax) -end - -# Interface functions. -@interface ::SparseArrayInterface function Broadcast.BroadcastStyle(type::Type) - return SparseArrayStyle{ndims(type)}() -end - -struct SparseLayout <: MemoryLayout end - -@interface ::SparseArrayInterface function ArrayLayouts.MemoryLayout(type::Type) - return SparseLayout() -end - -@interface ::SparseArrayInterface function Base.map!( - f, a_dest::AbstractArray, as::AbstractArray... -) - # TODO: Define a function `preserves_unstored(a_dest, f, as...)` - # to determine if a function preserves the stored values - # of the destination sparse array. - # The current code may be inefficient since it actually - # accesses an unstored element, which in the case of a - # sparse array of arrays can allocate an array. - # Sparse arrays could be expected to define a cheap - # unstored element allocator, for example - # `get_prototypical_unstored(a::AbstractArray)`. - I = first(eachindex(as...)) - preserves_unstored = iszero(f(map(a -> getunstoredindex(a, I), as)...)) - if !preserves_unstored - # Doesn't preserve unstored values, loop over all elements. - for I in eachindex(as...) - a_dest[I] = map(f, map(a -> a[I], as)...) - end - end - # TODO: Define `eachstoredindex(as...)`. - for I in union(eachstoredindex.(as)...) - a_dest[I] = map(f, map(a -> a[I], as)...) - end - return a_dest -end - -@interface ::SparseArrayInterface function Base.mapreduce( - f, op, a::AbstractArray; kwargs... -) - # TODO: Need to select a better `init`. - return mapreduce(f, op, storedvalues(a); kwargs...) -end - -# ArrayLayouts functionality. - -function ArrayLayouts.sub_materialize(::SparseLayout, a::AbstractArray, axes::Tuple) - a_dest = similar(a) - a_dest .= a - return a_dest -end - -function ArrayLayouts.materialize!( - m::MatMulMatAdd{<:SparseLayout,<:SparseLayout,<:SparseLayout} -) - a_dest, a1, a2, α, β = m.C, m.A, m.B, m.α, m.β - for I1 in eachstoredindex(a1) - for I2 in eachstoredindex(a2) - if I1[2] == I2[1] - I_dest = CartesianIndex(I1[1], I2[2]) - a_dest[I_dest] = a1[I1] * a2[I2] * α + a_dest[I_dest] * β - end - end - end - return a_dest -end - -# Sparse array minimal interface -using LinearAlgebra: Adjoint -function isstored(a::Adjoint, i::Int, j::Int) - return isstored(parent(a), j, i) -end -function getstoredindex(a::Adjoint, i::Int, j::Int) - return getstoredindex(parent(a), j, i)' -end -function getunstoredindex(a::Adjoint, i::Int, j::Int) - return getunstoredindex(parent(a), j, i)' -end -function eachstoredindex(a::Adjoint) - return map(CartesianIndex ∘ reverse ∘ Tuple, collect(eachstoredindex(parent(a)))) -end - -perm(::PermutedDimsArray{<:Any,<:Any,p}) where {p} = p -iperm(::PermutedDimsArray{<:Any,<:Any,<:Any,ip}) where {ip} = ip - -# TODO: Use `Base.PermutedDimsArrays.genperm` or -# https://github.com/jipolanco/StaticPermutations.jl? -genperm(v, perm) = map(j -> v[j], perm) - -function isstored(a::PermutedDimsArray, I::Int...) - return isstored(parent(a), genperm(I, iperm(a))...) -end -function getstoredindex(a::PermutedDimsArray, I::Int...) - return getstoredindex(parent(a), genperm(I, iperm(a))...) -end -function getunstoredindex(a::PermutedDimsArray, I::Int...) - return getunstoredindex(parent(a), genperm(I, iperm(a))...) -end -function eachstoredindex(a::PermutedDimsArray) - return map(collect(eachstoredindex(parent(a)))) do I - return CartesianIndex(genperm(I, perm(a))) - end -end - -tuple_oneto(n) = ntuple(identity, n) -## This is an optimization for `storedvalues` for DOK. -## function valuesview(d::Dict, keys) -## return @view d.vals[[Base.ht_keyindex(d, key) for key in keys]] -## end - -function eachstoredparentindex(a::SubArray) - return filter(eachstoredindex(parent(a))) do I - return all(d -> I[d] ∈ parentindices(a)[d], 1:ndims(parent(a))) - end -end -function storedvalues(a::SubArray) - return @view parent(a)[collect(eachstoredparentindex(a))] -end -function isstored(a::SubArray, I::Int...) - return isstored(parent(a), Base.reindex(parentindices(a), I)...) -end -function getstoredindex(a::SubArray, I::Int...) - return getstoredindex(parent(a), Base.reindex(parentindices(a), I)...) -end -function getunstoredindex(a::SubArray, I::Int...) - return getunstoredindex(parent(a), Base.reindex(parentindices(a), I)...) -end -function setstoredindex!(a::SubArray, value, I::Int...) - return setstoredindex!(parent(a), value, Base.reindex(parentindices(a), I)...) -end -function setunstoredindex!(a::SubArray, value, I::Int...) - return setunstoredindex!(parent(a), value, Base.reindex(parentindices(a), I)...) -end -function eachstoredindex(a::SubArray) - nonscalardims = filter(tuple_oneto(ndims(parent(a)))) do d - return !(parentindices(a)[d] isa Real) - end - return collect(( - CartesianIndex( - map(nonscalardims) do d - return findfirst(==(I[d]), parentindices(a)[d]) - end, - ) for I in eachstoredparentindex(a) - )) -end - -# Define a type that will derive the interface. -struct SparseArrayDOK{T,N} <: AbstractArray{T,N} - storage::Dict{CartesianIndex{N},T} - size::NTuple{N,Int} -end -storage(a::SparseArrayDOK) = a.storage -Base.size(a::SparseArrayDOK) = a.size -function SparseArrayDOK{T}(size::Int...) where {T} - N = length(size) - return SparseArrayDOK{T,N}(Dict{CartesianIndex{N},T}(), size) -end -# Used in `Base.similar`. -function SparseArrayDOK{T}(::UndefInitializer, size::Tuple{Vararg{Int}}) where {T} - return SparseArrayDOK{T}(size...) -end -function isstored(a::SparseArrayDOK, I::Int...) - return CartesianIndex(I) in keys(storage(a)) -end -function getstoredindex(a::SparseArrayDOK, I::Int...) - return storage(a)[CartesianIndex(I)] -end -function getunstoredindex(a::SparseArrayDOK, I::Int...) - return zero(eltype(a)) -end -function setstoredindex!(a::SparseArrayDOK, value, I::Int...) - storage(a)[CartesianIndex(I)] = value - return a -end -function setunstoredindex!(a::SparseArrayDOK, value, I::Int...) - storage(a)[CartesianIndex(I)] = value - return a -end -eachstoredindex(a::SparseArrayDOK) = keys(storage(a)) -storedlength(a::SparseArrayDOK) = length(eachstoredindex(a)) - -function ArrayLayouts.zero!(a::SparseArrayDOK) - empty!(storage(a)) - return a -end - -# Specify the interface the type adheres to. -DerivableInterfaces.interface(::Type{<:SparseArrayDOK}) = SparseArrayInterface() - -# Define aliases like `SparseMatrixDOK`, `AnySparseArrayDOK`, etc. -@array_aliases SparseArrayDOK - -# DerivableInterfaces the interface for the type. -@derive AnySparseArrayDOK AbstractArrayOps - -# avoid overloading `Base.cat` because of method invalidations -function Base._cat(dims, args::SparseArrayDOK...) - return DerivableInterfaces.Concatenate.cat(args...; dims) -end - -end diff --git a/test/test_basics.jl b/test/test_basics.jl index d088f5a..2d42d3a 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1,153 +1,64 @@ -using ArrayLayouts: zero! -include("SparseArrayDOKs.jl") -using .SparseArrayDOKs: SparseArrayDOK, storedlength -using Test: @test, @testset - -elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) -@testset "DerivableInterfaces" for elt in elts - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - @test a isa SparseArrayDOK{elt,2} - @test size(a) == (2, 2) - @test a[1, 1] == 0 - @test a[1, 1, 1] == 0 - @test a[1, 2] == 12 - @test a[1, 2, 1] == 12 - @test storedlength(a) == 1 - - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - for b in (similar(a, Float32, (3, 3)), similar(a, Float32, Base.OneTo.((3, 3)))) - @test b isa SparseArrayDOK{Float32,2} - @test b == zeros(Float32, 3, 3) - @test size(b) == (3, 3) - end - - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - b = similar(a) - bc = Broadcast.Broadcasted(x -> 2x, (a,)) - copyto!(b, bc) - @test b isa SparseArrayDOK{elt,2} - @test b == [0 24; 0 0] - @test storedlength(b) == 1 - - a = SparseArrayDOK{elt}(3, 3, 3) - a[1, 2, 3] = 123 - b = permutedims(a, (2, 3, 1)) - @test b isa SparseArrayDOK{elt,3} - @test b[2, 3, 1] == 123 - @test storedlength(b) == 1 - - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - b = copy(a') - @test b isa SparseArrayDOK{elt,2} - @test b == [0 0; 12 0] - @test storedlength(b) == 1 - - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - b = map(x -> 2x, a) - @test b isa SparseArrayDOK{elt,2} - @test b == [0 24; 0 0] - @test storedlength(b) == 1 - - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - b = a * a' - @test b isa SparseArrayDOK{elt,2} - @test b == [144 0; 0 0] - @test storedlength(b) == 1 - - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - b = a .+ 2 .* a' - @test b isa SparseArrayDOK{elt,2} - @test b == [0 12; 24 0] - @test storedlength(b) == 2 - - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - b = a[1:2, 2] - @test b isa SparseArrayDOK{elt,1} - @test b == [12, 0] - @test storedlength(b) == 1 - - a = SparseArrayDOK{elt}(2, 2) - @test iszero(a) - a[2, 1] = 21 - a[1, 2] = 12 - @test !iszero(a) - @test isreal(a) - @test sum(a) == 33 - @test mapreduce(x -> 2x, +, a) == 66 - - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - b = similar(a) - copyto!(b, a) - @test b isa SparseArrayDOK{elt,2} - @test b == a - @test b[1, 2] == 12 - @test storedlength(b) == 1 - - a = SparseArrayDOK{elt}(2, 2) - a .= 2 - @test storedlength(a) == length(a) - for I in eachindex(a) - @test a[I] == 2 - end - - a = SparseArrayDOK{elt}(2, 2) - fill!(a, 2) - @test storedlength(a) == length(a) - for I in eachindex(a) - @test a[I] == 2 - end +using Test: @test, @testset, @test_throws +using DerivableInterfaces: DerivableInterfaces as DI +using DerivableInterfaces: @derive, @interface + +# Test setup +# ---------- +struct MyArray{T,N} <: AbstractArray{T,N} + parent::Array{T,N} +end +Base.parent(A::MyArray) = A.parent - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - zero!(a) - @test iszero(a) - @test iszero(storedlength(a)) +@derive (T=MyArray,) Base.getindex(::T, ::Int) - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - b = zero(a) - @test b isa SparseArrayDOK{elt,2} - @test iszero(b) - @test iszero(storedlength(b)) +# Interfacetype +struct MyInterface{N} <: DI.AbstractArrayInterface{N} end +DI.interface(::Type{A}) where {A<:MyArray} = MyInterface{ndims(A)}() - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - b = SparseArrayDOK{elt}(4, 4) - b[2:3, 2:3] .= a - @test isone(storedlength(b)) - @test b[2, 3] == 12 +const f_ctr = Ref(0) # used to verify if function was actually called +@interface ::MyInterface function Base.getindex(A::MyArray, i::Int) + f_ctr[] += 1 + return getindex(parent(A), i) +end +@interface ::MyInterface function Base.getindex(A::AbstractArray, i::Int) + f_ctr[] += 1 + return getindex(A::AbstractArray, i::Int) +end - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - b = SparseArrayDOK{elt}(4, 4) - b[2:3, 2:3] = a - @test isone(storedlength(b)) - @test b[2, 3] == 12 +f(A, B) = -1 +for N in 1:3 + @eval @interface ::MyInterface{$N} f(A, B) = $N +end +@derive (T=AbstractArray,) f(::T, ::T) + +# Tests +# ----- +# TODO: test type stability +@testset "@derived types" begin + ctr = f_ctr[] + A = rand(Int, 3) + B = MyArray(A) + @test A[1] == B[1] + @test f_ctr[] == ctr + 1 +end - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - b = SparseArrayDOK{elt}(4, 4) - c = @view b[2:3, 2:3] - c .= a - @test isone(storedlength(b)) - @test b[2, 3] == 12 +@testset "using @interface functions for non-derived types" begin + ctr = f_ctr[] + A = zeros(Int, 3) + @test A[1] == @interface MyInterface{1}() A[1] + @test f_ctr[] == ctr + 1 +end - a1 = SparseArrayDOK{elt}(2, 2) - a1[1, 2] = 12 - a2 = SparseArrayDOK{elt}(2, 2) - a2[2, 1] = 21 - b = cat(a1, a2; dims=(1, 2)) - @test b isa SparseArrayDOK{elt,2} - @test storedlength(b) == 2 - @test b[1, 2] == 12 - @test b[4, 3] == 21 +@testset "interface promotion rules" begin + # DefaultArrayInterface should give default + @test f(zeros(1), zeros(1)) == -1 + @test f(zeros(1), zeros(1, 1)) == -1 + # MyInterface + @test f(MyArray(zeros(1)), MyArray(zeros(1))) == 1 + @test f(MyArray(zeros(1, 1)), MyArray(zeros(1, 1))) == 2 + # Mix + @test f(MyArray(zeros(1)), zeros(1)) == 1 + @test f(zeros(1), MyArray(zeros(1))) == 1 + # undefined mix + @test f((1,), zeros(1)) == 1 end From f11399b031b1d7a28eef72e6e4757c5b7c53bc54 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 18 Feb 2025 14:00:05 -0500 Subject: [PATCH 25/34] mention infinite recursion --- src/arrayinterface.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/arrayinterface.jl b/src/arrayinterface.jl index b234e81..fb53b8e 100644 --- a/src/arrayinterface.jl +++ b/src/arrayinterface.jl @@ -14,6 +14,7 @@ array, but hasn't defined a specialized interface. In the absence of overrides f struct DefaultArrayInterface{N} <: AbstractArrayInterface{N} end # avoid emitting warnings in fallback `call` definition +# TODO: this does not work and leads to infinite recursion call(::DefaultArrayInterface, f, args...; kwargs...) = f(args...; kwargs...) using TypeParameterAccessors: parenttype From c6e506afd8d0c110f915b80dbe1a9b3a2eec9336 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 18 Feb 2025 14:17:38 -0500 Subject: [PATCH 26/34] update skeleton --- .github/workflows/IntegrationTestRequest.yml | 14 ++++++++++++++ docs/make.jl | 2 +- docs/src/reference.md | 5 +++++ test/runtests.jl | 8 +++++--- 4 files changed, 25 insertions(+), 4 deletions(-) create mode 100644 .github/workflows/IntegrationTestRequest.yml create mode 100644 docs/src/reference.md diff --git a/.github/workflows/IntegrationTestRequest.yml b/.github/workflows/IntegrationTestRequest.yml new file mode 100644 index 0000000..d42fcca --- /dev/null +++ b/.github/workflows/IntegrationTestRequest.yml @@ -0,0 +1,14 @@ +name: "Integration Test Request" + +on: + issue_comment: + types: [created] + +jobs: + integrationrequest: + if: | + github.event.issue.pull_request && + contains(fromJSON('["OWNER", "COLLABORATOR", "MEMBER"]'), github.event.comment.author_association) + uses: ITensor/ITensorActions/.github/workflows/IntegrationTestRequest.yml@main + with: + localregistry: https://github.com/ITensor/ITensorRegistry.git diff --git a/docs/make.jl b/docs/make.jl index 5655860..8ba16e2 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -16,7 +16,7 @@ makedocs(; edit_link="main", assets=String[], ), - pages=["Home" => "index.md"], + pages=["Home" => "index.md", "Reference" => "reference.md"], ) deploydocs(; diff --git a/docs/src/reference.md b/docs/src/reference.md new file mode 100644 index 0000000..7a417ab --- /dev/null +++ b/docs/src/reference.md @@ -0,0 +1,5 @@ +# Reference + +```@autodocs +Modules = [DerivableInterfaces] +``` diff --git a/test/runtests.jl b/test/runtests.jl index 2b74a89..1c52c3e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -24,9 +24,11 @@ isexamplefile(fn) = # tests in groups based on folder structure for testgroup in filter(isdir, readdir(@__DIR__)) if GROUP == "ALL" || GROUP == uppercase(testgroup) - for file in filter(istestfile, readdir(joinpath(@__DIR__, testgroup); join=true)) - @eval @safetestset $(last(splitdir(file))) begin - include($file) + groupdir = joinpath(@__DIR__, testgroup) + for file in filter(istestfile, readdir(groupdir)) + filename = joinpath(groupdir, file) + @eval @safetestset $file begin + include($filename) end end end From 52514c29d3e2d850100f970bedae67465ba39009 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 18 Feb 2025 16:48:46 -0500 Subject: [PATCH 27/34] remove default fallback implementation --- src/arrayinterface.jl | 11 +++++++---- src/interface_function.jl | 10 ++-------- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/src/arrayinterface.jl b/src/arrayinterface.jl index fb53b8e..fff9661 100644 --- a/src/arrayinterface.jl +++ b/src/arrayinterface.jl @@ -12,10 +12,9 @@ array, but hasn't defined a specialized interface. In the absence of overrides f `AbstractArrayInterface` arguments, this results in non-overdubbed function calls. """ struct DefaultArrayInterface{N} <: AbstractArrayInterface{N} end - -# avoid emitting warnings in fallback `call` definition -# TODO: this does not work and leads to infinite recursion -call(::DefaultArrayInterface, f, args...; kwargs...) = f(args...; kwargs...) +# this effectively has almost no implementations, as they are inherited from the supertype +# either explicitly or will throw an error. It is simply a concrete instance to use the +# abstractarrayinterface implementations. using TypeParameterAccessors: parenttype # attempt to figure out interface type from parent @@ -29,6 +28,7 @@ function interface(::Type{B}) where {B<:Broadcast.AbstractArrayStyle} end # Combination rules +# ----------------- function combine_interface_rule( ::DefaultArrayInterface{N}, I::AbstractArrayInterface{N} ) where {N} @@ -44,3 +44,6 @@ function combine_interface_rule( ) where {N} return DefaultArrayInterface{N}() end + +# Fallback implementations +# ------------------------ diff --git a/src/interface_function.jl b/src/interface_function.jl index fce3a4f..62d64b2 100644 --- a/src/interface_function.jl +++ b/src/interface_function.jl @@ -1,7 +1,3 @@ -# noinline trick to make compiler avoid allocating a string -@noinline _warn_no_impl(interface, f, args) = - "The function `$f` does not have a `$interface` implementation for arguments of type `$(typeof(args))`" - """ call(interface, f, args...; kwargs...) @@ -9,10 +5,8 @@ Call the overdubbed function implementing `f(args...; kwargs...)` for a given in See also [`@interface`](@ref). """ -function call(interface, f, args...; kwargs...) - @warn _warn_no_impl(interface, f, args) maxlog = 1 - return f(args...; kwargs...) -end +call(interface, f, args...; kwargs...) = throw(MethodError(interface(f), args)) +# TODO: do we want to methoderror for `call` instead? """ struct InterfaceFunction{I,F} <: Function From a011cbd4cb61131f00e79585841609d2a3d05189 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 26 Feb 2025 12:57:51 -0500 Subject: [PATCH 28/34] Add dedicated `zero!` --- src/DerivableInterfaces.jl | 3 +++ src/concatenate.jl | 4 +--- src/zero.jl | 8 ++++++++ 3 files changed, 12 insertions(+), 3 deletions(-) create mode 100644 src/zero.jl diff --git a/src/DerivableInterfaces.jl b/src/DerivableInterfaces.jl index 08e1c6b..0bb21a1 100644 --- a/src/DerivableInterfaces.jl +++ b/src/DerivableInterfaces.jl @@ -1,5 +1,6 @@ module DerivableInterfaces +export zero! include("interface_function.jl") include("abstractinterface.jl") include("derive_macro.jl") @@ -8,6 +9,8 @@ include("wrappedarrays.jl") include("arrayinterface.jl") # Specific AbstractArray alternatives + include("concatenate.jl") +include("zero.jl") end diff --git a/src/concatenate.jl b/src/concatenate.jl index a0dc296..287a6d1 100644 --- a/src/concatenate.jl +++ b/src/concatenate.jl @@ -32,7 +32,7 @@ using Compat: @compat @compat public Concatenated using Base: promote_eltypeof -using ..DerivableInterfaces: DerivableInterfaces, AbstractInterface, interface +using ..DerivableInterfaces: DerivableInterfaces, AbstractInterface, interface, zero! """ Concatenated{Interface,Dims,Args<:Tuple} @@ -158,6 +158,4 @@ end # copy_or_fill!(A, inds, x) = fill!(view(A, inds...), x) # copy_or_fill!(A, inds, x::AbstractArray) = (A[inds...] = x) -zero!(x::AbstractArray) = fill!(x, zero(eltype(x))) - end diff --git a/src/zero.jl b/src/zero.jl new file mode 100644 index 0000000..dae9277 --- /dev/null +++ b/src/zero.jl @@ -0,0 +1,8 @@ +""" + zero!(x::AbstractArray) + +In-place function for zero-ing out an array. +""" +zero!(x::AbstractArray) = @interface interface(x) zero!(x) + +@interface ::AbstractArrayInterface zero!(x::AbstractArray) = fill!(x, zero(eltype(x))) From d5378a74b7a51a478c237b83e3e303fcd2d3cfdf Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 26 Feb 2025 12:57:59 -0500 Subject: [PATCH 29/34] Add exports --- src/DerivableInterfaces.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/DerivableInterfaces.jl b/src/DerivableInterfaces.jl index 0bb21a1..5ba6f2c 100644 --- a/src/DerivableInterfaces.jl +++ b/src/DerivableInterfaces.jl @@ -1,6 +1,9 @@ module DerivableInterfaces +export @derive, @interface +export interface, AbstractInterface, AbstractArrayInterface export zero! + include("interface_function.jl") include("abstractinterface.jl") include("derive_macro.jl") From ced87b646f6c4409220fd9b7115ed50fb5ac5b00 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 26 Feb 2025 12:58:13 -0500 Subject: [PATCH 30/34] Fix linenumbers --- src/derive_macro.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/derive_macro.jl b/src/derive_macro.jl index eda8a5b..0f16cc2 100644 --- a/src/derive_macro.jl +++ b/src/derive_macro.jl @@ -10,7 +10,7 @@ argname(i::Int) = Symbol(:arg, i) function rmlines(expr) return @match expr begin e::Expr => Expr(e.head, filter(!isnothing, map(rmlines, e.args))...) - _::LineNumberNode => nothing + # _::LineNumberNode => nothing a => a end end @@ -130,6 +130,8 @@ function derive_func(interface_or_types::Union{Symbol,Expr}, func::Expr) return derive_interface_func(interface, func) end +derive_func(::Union{Symbol,Expr}, l::LineNumberNode) = l + #= ```julia @derive (T=SparseArrayDOK,) Base.getindex(::T, ::Int...) From 67f3a3d55df9e8f5307ee3d0f6ae884601fb792b Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 26 Feb 2025 13:54:03 -0500 Subject: [PATCH 31/34] Fix types --- src/derive_macro.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/derive_macro.jl b/src/derive_macro.jl index 0f16cc2..2799b67 100644 --- a/src/derive_macro.jl +++ b/src/derive_macro.jl @@ -96,7 +96,7 @@ function replace_typevars(types::Expr, func::Expr) :($x = $y) => (x, y) end # TODO: Handle type parameters in other positions besides the first one. - new_args = map(args) do arg + new_args = map(new_args) do arg return @match arg begin :(::$Type{<:$T}) => T == typevar ? :(::$Type{<:$type}) : :(::$Type{<:$T}) :(::$T...) => T == typevar ? :(::$type...) : :(::$T...) From 857d86d3d6d5da8dd38976380b589e770176ea33 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 26 Feb 2025 13:54:12 -0500 Subject: [PATCH 32/34] remove `ndims` type parameter --- src/arrayinterface.jl | 28 ++++++++++------------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/src/arrayinterface.jl b/src/arrayinterface.jl index fff9661..bd65465 100644 --- a/src/arrayinterface.jl +++ b/src/arrayinterface.jl @@ -1,17 +1,15 @@ """ -`AbstractArrayInterface{N} <: AbstractInterface` is the abstract supertype for any interface +`AbstractArrayInterface <: AbstractInterface` is the abstract supertype for any interface associated with an `AbstractArray` type. -The `N` parameter is the dimensionality, which can be handy for array types that only support -specific dimensionalities. """ -abstract type AbstractArrayInterface{N} <: AbstractInterface end +abstract type AbstractArrayInterface <: AbstractInterface end """ -`DefaultArrayInterface{N}()` is the interface indicating that an object behaves as an `N`-dimensional +`DefaultArrayInterface()` is the interface indicating that an object behaves as an array, but hasn't defined a specialized interface. In the absence of overrides from other `AbstractArrayInterface` arguments, this results in non-overdubbed function calls. """ -struct DefaultArrayInterface{N} <: AbstractArrayInterface{N} end +struct DefaultArrayInterface <: AbstractArrayInterface end # this effectively has almost no implementations, as they are inherited from the supertype # either explicitly or will throw an error. It is simply a concrete instance to use the # abstractarrayinterface implementations. @@ -20,29 +18,23 @@ using TypeParameterAccessors: parenttype # attempt to figure out interface type from parent function interface(::Type{A}) where {A<:AbstractArray} pA = parenttype(A) - return pA === A ? DefaultArrayInterface{ndims(A)}() : interface(pA) + return pA === A ? DefaultArrayInterface() : interface(pA) end function interface(::Type{B}) where {B<:Broadcast.AbstractArrayStyle} - return DefaultArrayInterface{ndims(B)}() + return DefaultArrayInterface() end # Combination rules # ----------------- -function combine_interface_rule( - ::DefaultArrayInterface{N}, I::AbstractArrayInterface{N} -) where {N} +function combine_interface_rule(::DefaultArrayInterface, I::AbstractArrayInterface) return I end -function combine_interface_rule( - I::AbstractArrayInterface{N}, ::DefaultArrayInterface{N} -) where {N} +function combine_interface_rule(I::AbstractArrayInterface, ::DefaultArrayInterface) return I end -function combine_interface_rule( - ::DefaultArrayInterface{N}, ::DefaultArrayInterface{N} -) where {N} - return DefaultArrayInterface{N}() +function combine_interface_rule(::DefaultArrayInterface, ::DefaultArrayInterface) + return DefaultArrayInterface() end # Fallback implementations From fe8031b6730c155e5b9b054333f3a6fb60031058 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 26 Feb 2025 17:18:13 -0500 Subject: [PATCH 33/34] start adding fallbacks --- src/arrayinterface.jl | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/arrayinterface.jl b/src/arrayinterface.jl index bd65465..b03e866 100644 --- a/src/arrayinterface.jl +++ b/src/arrayinterface.jl @@ -1,5 +1,6 @@ """ `AbstractArrayInterface <: AbstractInterface` is the abstract supertype for any interface +using Base: BroadcastStyle associated with an `AbstractArray` type. """ abstract type AbstractArrayInterface <: AbstractInterface end @@ -25,6 +26,10 @@ function interface(::Type{B}) where {B<:Broadcast.AbstractArrayStyle} return DefaultArrayInterface() end +function interface(::Type{B}) where {B<:Broadcast.Broadcasted} + return interface(Broadcast.BroadcastStyle(B)) +end + # Combination rules # ----------------- function combine_interface_rule(::DefaultArrayInterface, I::AbstractArrayInterface) @@ -39,3 +44,28 @@ end # Fallback implementations # ------------------------ +# whenever we want to overload new interface implementations, we better have a fallback that +# sends us back to the default implementation. + +@interface ::AbstractArrayInterface Base.getindex(A::AbstractArray, I...) = ( + @inline; getindex(A, I...) +) +@interface ::AbstractArrayInterface Base.setindex!(A::AbstractArray, v, I...) = ( + @inline; setindex!(A, v, I...) +) + +@interface ::AbstractArrayInterface Base.similar( + A::AbstractArray, ::Type{T}, axes +) where {T} = similar(A, T, axes) + +@interface ::AbstractArrayInterface Base.map(f, A::AbstractArray, As::AbstractArray...) = + map(f, A, As...) +@interface ::AbstractArrayInterface Base.map!(f, A::AbstractArray, As::AbstractArray...) = + map!(f, A, As...) + +@interface ::AbstractArrayInterface Base.reduce( + op, A::AbstractArray, As::AbstractArray... +) = reduce(op, A, As...) +@interface ::AbstractArrayInterface Base.mapreduce( + f, op, A::AbstractArray, As::AbstractArray... +) = mapreduce(f, op, A, As...) From 3c7eee80d163819a44149300d6b661e465013306 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 26 Feb 2025 17:18:19 -0500 Subject: [PATCH 34/34] improve testing --- test/test_basics.jl | 19 ++++++++----------- test/test_defaultarrayinterface.jl | 4 ++-- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/test/test_basics.jl b/test/test_basics.jl index 2d42d3a..282598a 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1,6 +1,6 @@ using Test: @test, @testset, @test_throws using DerivableInterfaces: DerivableInterfaces as DI -using DerivableInterfaces: @derive, @interface +using DerivableInterfaces: @derive, @interface, AbstractArrayInterface # Test setup # ---------- @@ -12,8 +12,8 @@ Base.parent(A::MyArray) = A.parent @derive (T=MyArray,) Base.getindex(::T, ::Int) # Interfacetype -struct MyInterface{N} <: DI.AbstractArrayInterface{N} end -DI.interface(::Type{A}) where {A<:MyArray} = MyInterface{ndims(A)}() +struct MyInterface <: DI.AbstractArrayInterface end +DI.interface(::Type{A}) where {A<:MyArray} = MyInterface() const f_ctr = Ref(0) # used to verify if function was actually called @interface ::MyInterface function Base.getindex(A::MyArray, i::Int) @@ -25,11 +25,9 @@ end return getindex(A::AbstractArray, i::Int) end -f(A, B) = -1 -for N in 1:3 - @eval @interface ::MyInterface{$N} f(A, B) = $N -end -@derive (T=AbstractArray,) f(::T, ::T) +@derive (TA=Any, TB=Any) f(::TA, ::TB) +@interface ::AbstractArrayInterface f(A, B) = -1 +@interface ::MyInterface f(A, B) = 1 # Tests # ----- @@ -45,7 +43,7 @@ end @testset "using @interface functions for non-derived types" begin ctr = f_ctr[] A = zeros(Int, 3) - @test A[1] == @interface MyInterface{1}() A[1] + @test A[1] == @interface MyInterface() A[1] @test f_ctr[] == ctr + 1 end @@ -55,10 +53,9 @@ end @test f(zeros(1), zeros(1, 1)) == -1 # MyInterface @test f(MyArray(zeros(1)), MyArray(zeros(1))) == 1 - @test f(MyArray(zeros(1, 1)), MyArray(zeros(1, 1))) == 2 # Mix @test f(MyArray(zeros(1)), zeros(1)) == 1 @test f(zeros(1), MyArray(zeros(1))) == 1 # undefined mix - @test f((1,), zeros(1)) == 1 + @test_throws ErrorException f((1,), zeros(1)) == -1 end diff --git a/test/test_defaultarrayinterface.jl b/test/test_defaultarrayinterface.jl index d12bade..55be97c 100644 --- a/test/test_defaultarrayinterface.jl +++ b/test/test_defaultarrayinterface.jl @@ -33,6 +33,6 @@ end @testset "Broadcast.DefaultArrayStyle" begin @test interface(Broadcast.DefaultArrayStyle) == DefaultArrayInterface() - @test interface(Broadcast.Broadcasted(nothing, +, (randn(2), randn(2)))) == - DefaultArrayInterface() + @test interface(Broadcast.broadcasted(+, randn(2), randn(2))) == + DefaultArrayInterface() end