From 671d9b963491cf3a46541e83bf3486f6aff8b0e8 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 10 Mar 2025 22:49:09 -0400 Subject: [PATCH 1/5] More customization points in Concatenate --- Project.toml | 11 ++- .../DerivableInterfacesBlockArraysExt.jl | 10 +++ src/concatenate.jl | 74 ++++++++++++++----- src/defaultarrayinterface.jl | 4 + src/zero.jl | 4 + test/test_concatenate.jl | 31 ++++++++ 6 files changed, 112 insertions(+), 22 deletions(-) create mode 100644 ext/DerivableInterfacesBlockArraysExt/DerivableInterfacesBlockArraysExt.jl create mode 100644 test/test_concatenate.jl diff --git a/Project.toml b/Project.toml index 7092f35..eb5847e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DerivableInterfaces" uuid = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" authors = ["ITensor developers and contributors"] -version = "0.4.0" +version = "0.4.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -13,9 +13,16 @@ MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261" TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138" +[weakdeps] +BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" + +[extensions] +DerivableInterfacesBlockArraysExt = "BlockArrays" + [compat] Adapt = "4.1.1" -ArrayLayouts = "1.11.0" +ArrayLayouts = "1.11" +BlockArrays = "1.4" Compat = "3.47,4.10" ExproniconLite = "0.10.13" LinearAlgebra = "1.10" diff --git a/ext/DerivableInterfacesBlockArraysExt/DerivableInterfacesBlockArraysExt.jl b/ext/DerivableInterfacesBlockArraysExt/DerivableInterfacesBlockArraysExt.jl new file mode 100644 index 0000000..000fe6b --- /dev/null +++ b/ext/DerivableInterfacesBlockArraysExt/DerivableInterfacesBlockArraysExt.jl @@ -0,0 +1,10 @@ +module DerivableInterfacesBlockArraysExt + +using BlockArrays: BlockedOneTo, blockedrange, blocklengths +using DerivableInterfaces.Concatenate: Concatenate + +function Concatenate.cat_axis(a1::BlockedOneTo, a2::BlockedOneTo) + return blockedrange([blocklengths(a1); blocklengths(a2)]) +end + +end diff --git a/src/concatenate.jl b/src/concatenate.jl index 01229d2..3c6b9e5 100644 --- a/src/concatenate.jl +++ b/src/concatenate.jl @@ -31,33 +31,40 @@ using Base: promote_eltypeof using ..DerivableInterfaces: DerivableInterfaces, AbstractInterface, interface, zero!, arraytype +unval(x) = x +unval(::Val{x}) where {x} = x + +function _Concatenated end + """ - Concatenated{Interface,Dims,Args<:Tuple} + Concatenated{Interface,Dims,Axes,Args<:Tuple} Lazy representation of the concatenation of various `Args` along `Dims`, in order to provide hooks to customize the implementation. """ -struct Concatenated{Interface,Dims,Args<:Tuple} +struct Concatenated{Interface,Dims,Axes,Args<:Tuple} interface::Interface dims::Val{Dims} args::Args - - function Concatenated( - interface::Union{Nothing,AbstractInterface}, dims::Val{Dims}, args::Tuple - ) where {Dims} - return new{typeof(interface),Dims,typeof(args)}(interface, dims, args) - end - function Concatenated(dims, args::Tuple) - return Concatenated(interface(args...), dims, args) - end - function Concatenated{Interface}(dims, args) where {Interface} - return Concatenated(Interface(), dims, args) - end - function Concatenated{Interface,Dims}(args) where {Interface,Dims} - return new{Interface,Dims,typeof(args)}(Interface(), Val(Dims), args) + axes::Axes + global @inline function _Concatenated( + interface::Interface, dims::Val{Dims}, args::Args + ) where {Interface,Dims,Args<:Tuple} + ax = cat_axes(dims, args...) + return new{Interface,Dims,typeof(ax),Args}(interface, dims, args, ax) end end +function Concatenated(interface::Union{Nothing,AbstractInterface}, dims::Val, args::Tuple) + return _Concatenated(interface, dims, args) +end +function Concatenated(dims::Val, args::Tuple) + return Concatenated(interface(args...), dims, args) +end +function Concatenated{Interface}(dims::Val, args) where {Interface} + return Concatenated(Interface(), dims, args) +end + dims(::Concatenated{A,D}) where {A,D} = D DerivableInterfaces.interface(concat::Concatenated) = concat.interface @@ -80,14 +87,41 @@ function Base.similar(concat::Concatenated, ::Type{T}, ax) where {T} return similar(arraytype(interface(concat), T), ax) end -Base.eltype(concat::Concatenated) = promote_eltypeof(concat.args...) +function cat_axis( + a1::AbstractUnitRange, a2::AbstractUnitRange, a_rest::AbstractUnitRange... +) + return cat_axis(cat_axis(a1, a2), a_rest...) +end +cat_axis(a1::AbstractUnitRange, a2::AbstractUnitRange) = Base.OneTo(length(a1) + length(a2)) + +function cat_ndims(dims, as::AbstractArray...) + return max(maximum(dims), maximum(ndims, as)) +end +function cat_ndims(dims::Val, as::AbstractArray...) + return cat_ndims(unval(dims), as...) +end + +function cat_axes(dims, as::AbstractArray...) + return ntuple(cat_ndims(dims, as...)) do dim + if dim ∉ dims + return axes(first(as), dim) + end + return cat_axis(map(ax -> get(ax, dim, Base.OneTo(1)), axes.(as))...) + end +end +function cat_axes(dims::Val, as::AbstractArray...) + return cat_axes(unval(dims), as...) +end -# For now, simply couple back to base implementation function Base.axes(concat::Concatenated) - catdims = Base.dims2cat(dims(concat)) - return Base.cat_size_shape(catdims, concat.args...) + !isnothing(concat.axes) && return concat.axes + return cat_axes(dims(concat), concat.args...) end +Base.eltype(concat::Concatenated) = promote_eltypeof(concat.args...) +Base.size(concat::Concatenated) = length.(axes(concat)) +Base.ndims(concat::Concatenated) = length(axes(concat)) + # Main logic # ---------- """ diff --git a/src/defaultarrayinterface.jl b/src/defaultarrayinterface.jl index bc2f9b6..538b8e7 100644 --- a/src/defaultarrayinterface.jl +++ b/src/defaultarrayinterface.jl @@ -30,3 +30,7 @@ end ) return Base.mapreduce(f, op, as...; kwargs...) end + +function arraytype(::DefaultArrayInterface, T::Type) + return Array{T} +end diff --git a/src/zero.jl b/src/zero.jl index d19cf35..d5c013c 100644 --- a/src/zero.jl +++ b/src/zero.jl @@ -4,3 +4,7 @@ In-place version of `Base.zero`. """ function zero! end + +@derive (T=AbstractArray,) begin + DerivableInterfaces.zero!(::T) +end diff --git a/test/test_concatenate.jl b/test/test_concatenate.jl new file mode 100644 index 0000000..d35674b --- /dev/null +++ b/test/test_concatenate.jl @@ -0,0 +1,31 @@ +using DerivableInterfaces.Concatenate: concatenated +using Test: @test, @testset + +@testset "Concatenated" begin + a = randn(Float32, 2, 2) + b = randn(Float64, 2, 2) + + concat = concatenated((1, 2), a, b) + @test axes(concat) == Base.OneTo.((4, 4)) + @test size(concat) == (4, 4) + @test eltype(concat) === Float64 + @test copy(concat) == cat(a, b; dims=(1, 2)) + + concat = concatenated(1, a, b) + @test axes(concat) == Base.OneTo.((4, 2)) + @test size(concat) == (4, 2) + @test eltype(concat) === Float64 + @test copy(concat) == cat(a, b; dims=1) + + concat = concatenated(3, a, b) + @test axes(concat) == Base.OneTo.((2, 2, 2)) + @test size(concat) == (2, 2, 2) + @test eltype(concat) === Float64 + @test copy(concat) == cat(a, b; dims=3) + + concat = concatenated(4, a, b) + @test axes(concat) == Base.OneTo.((2, 2, 1, 2)) + @test size(concat) == (2, 2, 1, 2) + @test eltype(concat) === Float64 + @test copy(concat) == cat(a, b; dims=4) +end From c36090fc18a5a57f3817111ad7f8ba511172592f Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 11 Mar 2025 08:47:36 -0400 Subject: [PATCH 2/5] Avoid calling private Base functions --- src/concatenate.jl | 50 +++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 45 insertions(+), 5 deletions(-) diff --git a/src/concatenate.jl b/src/concatenate.jl index 3c6b9e5..33f70d7 100644 --- a/src/concatenate.jl +++ b/src/concatenate.jl @@ -159,16 +159,56 @@ Base.copy(concat::Concatenated) = copyto!(similar(concat), concat) # default falls back to replacing interface with Nothing # this permits specializing on typeof(dest) without ambiguities # Note: this needs to be defined for AbstractArray specifically to avoid ambiguities with Base. -@inline Base.copyto!(dest::AbstractArray, concat::Concatenated) = - copyto!(dest, convert(Concatenated{Nothing}, concat)) +@inline function Base.copyto!(dest::AbstractArray, concat::Concatenated) + return copyto!(dest, convert(Concatenated{Nothing}, concat)) +end + +_copy_or_fill!(A, inds, x) = fill!(view(A, inds...), x) +_copy_or_fill!(A, inds, x::AbstractArray) = (A[inds...] = x) + +cat_size(A) = (1,) +cat_size(A::AbstractArray) = size(A) +cat_size(A, d) = 1 +cat_size(A::AbstractArray, d) = size(A, d) + +cat_indices(A, d) = Base.OneTo(1) +cat_indices(A::AbstractArray, d) = axes(A, d) + +function __cat!(A, shape, catdims, X...) + return __cat_offset!(A, shape, catdims, ntuple(zero, length(shape)), X...) +end +function __cat_offset!(A, shape, catdims, offsets, x, X...) + # splitting the "work" on x from X... may reduce latency (fewer costly specializations) + newoffsets = __cat_offset1!(A, shape, catdims, offsets, x) + return __cat_offset!(A, shape, catdims, newoffsets, X...) +end +__cat_offset!(A, shape, catdims, offsets) = A +function __cat_offset1!(A, shape, catdims, offsets, x) + inds = ntuple(length(offsets)) do i + (i <= length(catdims) && catdims[i]) ? offsets[i] .+ cat_indices(x, i) : 1:shape[i] + end + _copy_or_fill!(A, inds, x) + newoffsets = ntuple(length(offsets)) do i + (i <= length(catdims) && catdims[i]) ? offsets[i] + cat_size(x, i) : offsets[i] + end + return newoffsets +end + +dims2cat(dims::Val) = dims2cat(unval(dims)) +function dims2cat(dims) + if any(≤(0), dims) + throw(ArgumentError("All cat dimensions must be positive integers, but got $dims")) + end + return ntuple(in(dims), maximum(dims)) +end # couple back to Base implementation if no specialization exists: # https://github.com/JuliaLang/julia/blob/29da86bb983066dd076439c2c7bc5e28dbd611bb/base/abstractarray.jl#L1852 function Base.copyto!(dest::AbstractArray, concat::Concatenated{Nothing}) - catdims = Base.dims2cat(dims(concat)) - shape = Base.cat_size_shape(catdims, concat.args...) + catdims = dims2cat(dims(concat)) + shape = size(concat) count(!iszero, catdims)::Int > 1 && zero!(dest) - return Base.__cat(dest, shape, catdims, concat.args...) + return __cat!(dest, shape, catdims, concat.args...) end end From 4bf6ff4ddb3ed445c4e8cc056009dc6f46d59d2a Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 11 Mar 2025 14:26:22 -0400 Subject: [PATCH 3/5] Address comments from Lukas --- src/concatenate.jl | 52 +++++++++++++++++++++++----------------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/src/concatenate.jl b/src/concatenate.jl index 33f70d7..4fa6bf4 100644 --- a/src/concatenate.jl +++ b/src/concatenate.jl @@ -48,35 +48,41 @@ struct Concatenated{Interface,Dims,Axes,Args<:Tuple} args::Args axes::Axes global @inline function _Concatenated( - interface::Interface, dims::Val{Dims}, args::Args - ) where {Interface,Dims,Args<:Tuple} - ax = cat_axes(dims, args...) - return new{Interface,Dims,typeof(ax),Args}(interface, dims, args, ax) + interface::Interface, dims::Val{Dims}, args::Args, axes::Axes + ) where {Interface,Dims,Args<:Tuple,Axes} + return new{Interface,Dims,Axes,Args}(interface, dims, args, axes) end end -function Concatenated(interface::Union{Nothing,AbstractInterface}, dims::Val, args::Tuple) - return _Concatenated(interface, dims, args) +function Concatenated( + interface::Union{Nothing,AbstractInterface}, + dims::Val, + args::Tuple, + axes=cat_axes(dims, args...), +) + return _Concatenated(interface, dims, args, axes) end -function Concatenated(dims::Val, args::Tuple) +function Concatenated(dims::Val, args::Tuple, axes=cat_axes(dims, args...)) return Concatenated(interface(args...), dims, args) end -function Concatenated{Interface}(dims::Val, args) where {Interface} +function Concatenated{Interface}( + dims::Val, args::Tuple, axes=cat_axes(dims, args...) +) where {Interface} return Concatenated(Interface(), dims, args) end -dims(::Concatenated{A,D}) where {A,D} = D -DerivableInterfaces.interface(concat::Concatenated) = concat.interface +dims(::Concatenated{<:Any,D}) where {D} = D +DerivableInterfaces.interface(concat::Concatenated) = getfield(concat, :interface) concatenated(dims, args...) = concatenated(Val(dims), args...) concatenated(dims::Val, args...) = Concatenated(dims, args) function Base.convert( - ::Type{Concatenated{NewInterface}}, concat::Concatenated{<:Any,Dims,Args} -) where {NewInterface,Dims,Args} + ::Type{Concatenated{NewInterface}}, concat::Concatenated{<:Any,Dims,Axes,Args} +) where {NewInterface,Dims,Axes,Args} return Concatenated{NewInterface}( - concat.dims, concat.args - )::Concatenated{NewInterface,Dims,Args} + concat.dims, concat.args, concat.axes + )::Concatenated{NewInterface,Dims,Axes,Args} end # allocating the destination container @@ -101,22 +107,16 @@ function cat_ndims(dims::Val, as::AbstractArray...) return cat_ndims(unval(dims), as...) end -function cat_axes(dims, as::AbstractArray...) - return ntuple(cat_ndims(dims, as...)) do dim - if dim ∉ dims - return axes(first(as), dim) - end - return cat_axis(map(ax -> get(ax, dim, Base.OneTo(1)), axes.(as))...) +function cat_axes(dims, a::AbstractArray, as::AbstractArray...) + return ntuple(cat_ndims(dims, a, as...)) do dim + return dim in dims ? cat_axis(map(Base.Fix2(axes, dim), (a, as...))...) : axes(a, dim) end end function cat_axes(dims::Val, as::AbstractArray...) return cat_axes(unval(dims), as...) end -function Base.axes(concat::Concatenated) - !isnothing(concat.axes) && return concat.axes - return cat_axes(dims(concat), concat.args...) -end +Base.axes(concat::Concatenated) = getfield(concat, :axes) Base.eltype(concat::Concatenated) = promote_eltypeof(concat.args...) Base.size(concat::Concatenated) = length.(axes(concat)) @@ -163,6 +163,8 @@ Base.copy(concat::Concatenated) = copyto!(similar(concat), concat) return copyto!(dest, convert(Concatenated{Nothing}, concat)) end +# The following is largely copied from the Base implementation of `Base.cat`, see: +# https://github.com/JuliaLang/julia/blob/885b1cd875f101f227b345f681cc36879124d80d/base/abstractarray.jl#L1778-L1887 _copy_or_fill!(A, inds, x) = fill!(view(A, inds...), x) _copy_or_fill!(A, inds, x::AbstractArray) = (A[inds...] = x) @@ -202,8 +204,6 @@ function dims2cat(dims) return ntuple(in(dims), maximum(dims)) end -# couple back to Base implementation if no specialization exists: -# https://github.com/JuliaLang/julia/blob/29da86bb983066dd076439c2c7bc5e28dbd611bb/base/abstractarray.jl#L1852 function Base.copyto!(dest::AbstractArray, concat::Concatenated{Nothing}) catdims = dims2cat(dims(concat)) shape = size(concat) From 66b9fc9d9c8ee2992dbe09c1996d5b8aad1583ac Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 11 Mar 2025 21:52:20 -0400 Subject: [PATCH 4/5] Don't store axes, reorganize code --- src/concatenate.jl | 43 ++++++++++++++++++++----------------------- 1 file changed, 20 insertions(+), 23 deletions(-) diff --git a/src/concatenate.jl b/src/concatenate.jl index 4fa6bf4..c059c15 100644 --- a/src/concatenate.jl +++ b/src/concatenate.jl @@ -37,20 +37,19 @@ unval(::Val{x}) where {x} = x function _Concatenated end """ - Concatenated{Interface,Dims,Axes,Args<:Tuple} + Concatenated{Interface,Dims,Args<:Tuple} Lazy representation of the concatenation of various `Args` along `Dims`, in order to provide hooks to customize the implementation. """ -struct Concatenated{Interface,Dims,Axes,Args<:Tuple} +struct Concatenated{Interface,Dims,Args<:Tuple} interface::Interface dims::Val{Dims} args::Args - axes::Axes global @inline function _Concatenated( - interface::Interface, dims::Val{Dims}, args::Args, axes::Axes - ) where {Interface,Dims,Args<:Tuple,Axes} - return new{Interface,Dims,Axes,Args}(interface, dims, args, axes) + interface::Interface, dims::Val{Dims}, args::Args + ) where {Interface,Dims,Args<:Tuple} + return new{Interface,Dims,Args}(interface, dims, args) end end @@ -58,15 +57,14 @@ function Concatenated( interface::Union{Nothing,AbstractInterface}, dims::Val, args::Tuple, - axes=cat_axes(dims, args...), ) - return _Concatenated(interface, dims, args, axes) + return _Concatenated(interface, dims, args) end -function Concatenated(dims::Val, args::Tuple, axes=cat_axes(dims, args...)) +function Concatenated(dims::Val, args::Tuple) return Concatenated(interface(args...), dims, args) end function Concatenated{Interface}( - dims::Val, args::Tuple, axes=cat_axes(dims, args...) + dims::Val, args::Tuple ) where {Interface} return Concatenated(Interface(), dims, args) end @@ -78,11 +76,11 @@ concatenated(dims, args...) = concatenated(Val(dims), args...) concatenated(dims::Val, args...) = Concatenated(dims, args) function Base.convert( - ::Type{Concatenated{NewInterface}}, concat::Concatenated{<:Any,Dims,Axes,Args} -) where {NewInterface,Dims,Axes,Args} + ::Type{Concatenated{NewInterface}}, concat::Concatenated{<:Any,Dims,Args} +) where {NewInterface,Dims,Args} return Concatenated{NewInterface}( - concat.dims, concat.args, concat.axes - )::Concatenated{NewInterface,Dims,Axes,Args} + concat.dims, concat.args + )::Concatenated{NewInterface,Dims,Args} end # allocating the destination container @@ -116,9 +114,8 @@ function cat_axes(dims::Val, as::AbstractArray...) return cat_axes(unval(dims), as...) end -Base.axes(concat::Concatenated) = getfield(concat, :axes) - Base.eltype(concat::Concatenated) = promote_eltypeof(concat.args...) +Base.axes(concat::Concatenated) = cat_axes(dims(concat), concat.args...) Base.size(concat::Concatenated) = length.(axes(concat)) Base.ndims(concat::Concatenated) = length(axes(concat)) @@ -156,13 +153,6 @@ Base.materialize!(dest, concat::Concatenated) = copyto!(dest, concat) Base.copy(concat::Concatenated) = copyto!(similar(concat), concat) -# default falls back to replacing interface with Nothing -# this permits specializing on typeof(dest) without ambiguities -# Note: this needs to be defined for AbstractArray specifically to avoid ambiguities with Base. -@inline function Base.copyto!(dest::AbstractArray, concat::Concatenated) - return copyto!(dest, convert(Concatenated{Nothing}, concat)) -end - # The following is largely copied from the Base implementation of `Base.cat`, see: # https://github.com/JuliaLang/julia/blob/885b1cd875f101f227b345f681cc36879124d80d/base/abstractarray.jl#L1778-L1887 _copy_or_fill!(A, inds, x) = fill!(view(A, inds...), x) @@ -204,6 +194,13 @@ function dims2cat(dims) return ntuple(in(dims), maximum(dims)) end +# default falls back to replacing interface with Nothing +# this permits specializing on typeof(dest) without ambiguities +# Note: this needs to be defined for AbstractArray specifically to avoid ambiguities with Base. +@inline function Base.copyto!(dest::AbstractArray, concat::Concatenated) + return copyto!(dest, convert(Concatenated{Nothing}, concat)) +end + function Base.copyto!(dest::AbstractArray, concat::Concatenated{Nothing}) catdims = dims2cat(dims(concat)) shape = size(concat) From 9453585fd26176a9922e909166c4329cecce4451 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 11 Mar 2025 22:08:21 -0400 Subject: [PATCH 5/5] Format --- src/concatenate.jl | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/concatenate.jl b/src/concatenate.jl index c059c15..fcefee8 100644 --- a/src/concatenate.jl +++ b/src/concatenate.jl @@ -53,19 +53,13 @@ struct Concatenated{Interface,Dims,Args<:Tuple} end end -function Concatenated( - interface::Union{Nothing,AbstractInterface}, - dims::Val, - args::Tuple, -) +function Concatenated(interface::Union{Nothing,AbstractInterface}, dims::Val, args::Tuple) return _Concatenated(interface, dims, args) end function Concatenated(dims::Val, args::Tuple) return Concatenated(interface(args...), dims, args) end -function Concatenated{Interface}( - dims::Val, args::Tuple -) where {Interface} +function Concatenated{Interface}(dims::Val, args::Tuple) where {Interface} return Concatenated(Interface(), dims, args) end