diff --git a/Project.toml b/Project.toml index 7092f35..eb5847e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DerivableInterfaces" uuid = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" authors = ["ITensor developers and contributors"] -version = "0.4.0" +version = "0.4.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -13,9 +13,16 @@ MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261" TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138" +[weakdeps] +BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" + +[extensions] +DerivableInterfacesBlockArraysExt = "BlockArrays" + [compat] Adapt = "4.1.1" -ArrayLayouts = "1.11.0" +ArrayLayouts = "1.11" +BlockArrays = "1.4" Compat = "3.47,4.10" ExproniconLite = "0.10.13" LinearAlgebra = "1.10" diff --git a/ext/DerivableInterfacesBlockArraysExt/DerivableInterfacesBlockArraysExt.jl b/ext/DerivableInterfacesBlockArraysExt/DerivableInterfacesBlockArraysExt.jl new file mode 100644 index 0000000..000fe6b --- /dev/null +++ b/ext/DerivableInterfacesBlockArraysExt/DerivableInterfacesBlockArraysExt.jl @@ -0,0 +1,10 @@ +module DerivableInterfacesBlockArraysExt + +using BlockArrays: BlockedOneTo, blockedrange, blocklengths +using DerivableInterfaces.Concatenate: Concatenate + +function Concatenate.cat_axis(a1::BlockedOneTo, a2::BlockedOneTo) + return blockedrange([blocklengths(a1); blocklengths(a2)]) +end + +end diff --git a/src/concatenate.jl b/src/concatenate.jl index 01229d2..fcefee8 100644 --- a/src/concatenate.jl +++ b/src/concatenate.jl @@ -31,6 +31,11 @@ using Base: promote_eltypeof using ..DerivableInterfaces: DerivableInterfaces, AbstractInterface, interface, zero!, arraytype +unval(x) = x +unval(::Val{x}) where {x} = x + +function _Concatenated end + """ Concatenated{Interface,Dims,Args<:Tuple} @@ -41,25 +46,25 @@ struct Concatenated{Interface,Dims,Args<:Tuple} interface::Interface dims::Val{Dims} args::Args - - function Concatenated( - interface::Union{Nothing,AbstractInterface}, dims::Val{Dims}, args::Tuple - ) where {Dims} - return new{typeof(interface),Dims,typeof(args)}(interface, dims, args) - end - function Concatenated(dims, args::Tuple) - return Concatenated(interface(args...), dims, args) - end - function Concatenated{Interface}(dims, args) where {Interface} - return Concatenated(Interface(), dims, args) - end - function Concatenated{Interface,Dims}(args) where {Interface,Dims} - return new{Interface,Dims,typeof(args)}(Interface(), Val(Dims), args) + global @inline function _Concatenated( + interface::Interface, dims::Val{Dims}, args::Args + ) where {Interface,Dims,Args<:Tuple} + return new{Interface,Dims,Args}(interface, dims, args) end end -dims(::Concatenated{A,D}) where {A,D} = D -DerivableInterfaces.interface(concat::Concatenated) = concat.interface +function Concatenated(interface::Union{Nothing,AbstractInterface}, dims::Val, args::Tuple) + return _Concatenated(interface, dims, args) +end +function Concatenated(dims::Val, args::Tuple) + return Concatenated(interface(args...), dims, args) +end +function Concatenated{Interface}(dims::Val, args::Tuple) where {Interface} + return Concatenated(Interface(), dims, args) +end + +dims(::Concatenated{<:Any,D}) where {D} = D +DerivableInterfaces.interface(concat::Concatenated) = getfield(concat, :interface) concatenated(dims, args...) = concatenated(Val(dims), args...) concatenated(dims::Val, args...) = Concatenated(dims, args) @@ -80,13 +85,33 @@ function Base.similar(concat::Concatenated, ::Type{T}, ax) where {T} return similar(arraytype(interface(concat), T), ax) end -Base.eltype(concat::Concatenated) = promote_eltypeof(concat.args...) +function cat_axis( + a1::AbstractUnitRange, a2::AbstractUnitRange, a_rest::AbstractUnitRange... +) + return cat_axis(cat_axis(a1, a2), a_rest...) +end +cat_axis(a1::AbstractUnitRange, a2::AbstractUnitRange) = Base.OneTo(length(a1) + length(a2)) -# For now, simply couple back to base implementation -function Base.axes(concat::Concatenated) - catdims = Base.dims2cat(dims(concat)) - return Base.cat_size_shape(catdims, concat.args...) +function cat_ndims(dims, as::AbstractArray...) + return max(maximum(dims), maximum(ndims, as)) +end +function cat_ndims(dims::Val, as::AbstractArray...) + return cat_ndims(unval(dims), as...) +end + +function cat_axes(dims, a::AbstractArray, as::AbstractArray...) + return ntuple(cat_ndims(dims, a, as...)) do dim + return dim in dims ? cat_axis(map(Base.Fix2(axes, dim), (a, as...))...) : axes(a, dim) + end end +function cat_axes(dims::Val, as::AbstractArray...) + return cat_axes(unval(dims), as...) +end + +Base.eltype(concat::Concatenated) = promote_eltypeof(concat.args...) +Base.axes(concat::Concatenated) = cat_axes(dims(concat), concat.args...) +Base.size(concat::Concatenated) = length.(axes(concat)) +Base.ndims(concat::Concatenated) = length(axes(concat)) # Main logic # ---------- @@ -122,19 +147,59 @@ Base.materialize!(dest, concat::Concatenated) = copyto!(dest, concat) Base.copy(concat::Concatenated) = copyto!(similar(concat), concat) +# The following is largely copied from the Base implementation of `Base.cat`, see: +# https://github.com/JuliaLang/julia/blob/885b1cd875f101f227b345f681cc36879124d80d/base/abstractarray.jl#L1778-L1887 +_copy_or_fill!(A, inds, x) = fill!(view(A, inds...), x) +_copy_or_fill!(A, inds, x::AbstractArray) = (A[inds...] = x) + +cat_size(A) = (1,) +cat_size(A::AbstractArray) = size(A) +cat_size(A, d) = 1 +cat_size(A::AbstractArray, d) = size(A, d) + +cat_indices(A, d) = Base.OneTo(1) +cat_indices(A::AbstractArray, d) = axes(A, d) + +function __cat!(A, shape, catdims, X...) + return __cat_offset!(A, shape, catdims, ntuple(zero, length(shape)), X...) +end +function __cat_offset!(A, shape, catdims, offsets, x, X...) + # splitting the "work" on x from X... may reduce latency (fewer costly specializations) + newoffsets = __cat_offset1!(A, shape, catdims, offsets, x) + return __cat_offset!(A, shape, catdims, newoffsets, X...) +end +__cat_offset!(A, shape, catdims, offsets) = A +function __cat_offset1!(A, shape, catdims, offsets, x) + inds = ntuple(length(offsets)) do i + (i <= length(catdims) && catdims[i]) ? offsets[i] .+ cat_indices(x, i) : 1:shape[i] + end + _copy_or_fill!(A, inds, x) + newoffsets = ntuple(length(offsets)) do i + (i <= length(catdims) && catdims[i]) ? offsets[i] + cat_size(x, i) : offsets[i] + end + return newoffsets +end + +dims2cat(dims::Val) = dims2cat(unval(dims)) +function dims2cat(dims) + if any(≤(0), dims) + throw(ArgumentError("All cat dimensions must be positive integers, but got $dims")) + end + return ntuple(in(dims), maximum(dims)) +end + # default falls back to replacing interface with Nothing # this permits specializing on typeof(dest) without ambiguities # Note: this needs to be defined for AbstractArray specifically to avoid ambiguities with Base. -@inline Base.copyto!(dest::AbstractArray, concat::Concatenated) = - copyto!(dest, convert(Concatenated{Nothing}, concat)) +@inline function Base.copyto!(dest::AbstractArray, concat::Concatenated) + return copyto!(dest, convert(Concatenated{Nothing}, concat)) +end -# couple back to Base implementation if no specialization exists: -# https://github.com/JuliaLang/julia/blob/29da86bb983066dd076439c2c7bc5e28dbd611bb/base/abstractarray.jl#L1852 function Base.copyto!(dest::AbstractArray, concat::Concatenated{Nothing}) - catdims = Base.dims2cat(dims(concat)) - shape = Base.cat_size_shape(catdims, concat.args...) + catdims = dims2cat(dims(concat)) + shape = size(concat) count(!iszero, catdims)::Int > 1 && zero!(dest) - return Base.__cat(dest, shape, catdims, concat.args...) + return __cat!(dest, shape, catdims, concat.args...) end end diff --git a/src/defaultarrayinterface.jl b/src/defaultarrayinterface.jl index bc2f9b6..538b8e7 100644 --- a/src/defaultarrayinterface.jl +++ b/src/defaultarrayinterface.jl @@ -30,3 +30,7 @@ end ) return Base.mapreduce(f, op, as...; kwargs...) end + +function arraytype(::DefaultArrayInterface, T::Type) + return Array{T} +end diff --git a/src/zero.jl b/src/zero.jl index d19cf35..d5c013c 100644 --- a/src/zero.jl +++ b/src/zero.jl @@ -4,3 +4,7 @@ In-place version of `Base.zero`. """ function zero! end + +@derive (T=AbstractArray,) begin + DerivableInterfaces.zero!(::T) +end diff --git a/test/test_concatenate.jl b/test/test_concatenate.jl new file mode 100644 index 0000000..d35674b --- /dev/null +++ b/test/test_concatenate.jl @@ -0,0 +1,31 @@ +using DerivableInterfaces.Concatenate: concatenated +using Test: @test, @testset + +@testset "Concatenated" begin + a = randn(Float32, 2, 2) + b = randn(Float64, 2, 2) + + concat = concatenated((1, 2), a, b) + @test axes(concat) == Base.OneTo.((4, 4)) + @test size(concat) == (4, 4) + @test eltype(concat) === Float64 + @test copy(concat) == cat(a, b; dims=(1, 2)) + + concat = concatenated(1, a, b) + @test axes(concat) == Base.OneTo.((4, 2)) + @test size(concat) == (4, 2) + @test eltype(concat) === Float64 + @test copy(concat) == cat(a, b; dims=1) + + concat = concatenated(3, a, b) + @test axes(concat) == Base.OneTo.((2, 2, 2)) + @test size(concat) == (2, 2, 2) + @test eltype(concat) === Float64 + @test copy(concat) == cat(a, b; dims=3) + + concat = concatenated(4, a, b) + @test axes(concat) == Base.OneTo.((2, 2, 1, 2)) + @test size(concat) == (2, 2, 1, 2) + @test eltype(concat) === Float64 + @test copy(concat) == cat(a, b; dims=4) +end