Skip to content

More customization points in Concatenate #32

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 12, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DerivableInterfaces"
uuid = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.4.0"
version = "0.4.1"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -13,9 +13,16 @@ MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261"
TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"

[weakdeps]
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"

[extensions]
DerivableInterfacesBlockArraysExt = "BlockArrays"

[compat]
Adapt = "4.1.1"
ArrayLayouts = "1.11.0"
ArrayLayouts = "1.11"
BlockArrays = "1.4"
Compat = "3.47,4.10"
ExproniconLite = "0.10.13"
LinearAlgebra = "1.10"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
module DerivableInterfacesBlockArraysExt

using BlockArrays: BlockedOneTo, blockedrange, blocklengths
using DerivableInterfaces.Concatenate: Concatenate

function Concatenate.cat_axis(a1::BlockedOneTo, a2::BlockedOneTo)
return blockedrange([blocklengths(a1); blocklengths(a2)])

Check warning on line 7 in ext/DerivableInterfacesBlockArraysExt/DerivableInterfacesBlockArraysExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/DerivableInterfacesBlockArraysExt/DerivableInterfacesBlockArraysExt.jl#L6-L7

Added lines #L6 - L7 were not covered by tests
end

end
142 changes: 108 additions & 34 deletions src/concatenate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,45 +31,58 @@
using ..DerivableInterfaces:
DerivableInterfaces, AbstractInterface, interface, zero!, arraytype

unval(x) = x

Check warning on line 34 in src/concatenate.jl

View check run for this annotation

Codecov / codecov/patch

src/concatenate.jl#L34

Added line #L34 was not covered by tests
unval(::Val{x}) where {x} = x

function _Concatenated end

"""
Concatenated{Interface,Dims,Args<:Tuple}
Concatenated{Interface,Dims,Axes,Args<:Tuple}

Lazy representation of the concatenation of various `Args` along `Dims`, in order to provide
hooks to customize the implementation.
"""
struct Concatenated{Interface,Dims,Args<:Tuple}
struct Concatenated{Interface,Dims,Axes,Args<:Tuple}
interface::Interface
dims::Val{Dims}
args::Args

function Concatenated(
interface::Union{Nothing,AbstractInterface}, dims::Val{Dims}, args::Tuple
) where {Dims}
return new{typeof(interface),Dims,typeof(args)}(interface, dims, args)
end
function Concatenated(dims, args::Tuple)
return Concatenated(interface(args...), dims, args)
end
function Concatenated{Interface}(dims, args) where {Interface}
return Concatenated(Interface(), dims, args)
end
function Concatenated{Interface,Dims}(args) where {Interface,Dims}
return new{Interface,Dims,typeof(args)}(Interface(), Val(Dims), args)
axes::Axes
global @inline function _Concatenated(
interface::Interface, dims::Val{Dims}, args::Args, axes::Axes
) where {Interface,Dims,Args<:Tuple,Axes}
return new{Interface,Dims,Axes,Args}(interface, dims, args, axes)
end
end

dims(::Concatenated{A,D}) where {A,D} = D
DerivableInterfaces.interface(concat::Concatenated) = concat.interface
function Concatenated(
interface::Union{Nothing,AbstractInterface},
dims::Val,
args::Tuple,
axes=cat_axes(dims, args...),
)
return _Concatenated(interface, dims, args, axes)
end
function Concatenated(dims::Val, args::Tuple, axes=cat_axes(dims, args...))
return Concatenated(interface(args...), dims, args)
end
function Concatenated{Interface}(
dims::Val, args::Tuple, axes=cat_axes(dims, args...)
) where {Interface}
return Concatenated(Interface(), dims, args)
end

dims(::Concatenated{<:Any,D}) where {D} = D
DerivableInterfaces.interface(concat::Concatenated) = getfield(concat, :interface)

concatenated(dims, args...) = concatenated(Val(dims), args...)
concatenated(dims::Val, args...) = Concatenated(dims, args)

function Base.convert(
::Type{Concatenated{NewInterface}}, concat::Concatenated{<:Any,Dims,Args}
) where {NewInterface,Dims,Args}
::Type{Concatenated{NewInterface}}, concat::Concatenated{<:Any,Dims,Axes,Args}
) where {NewInterface,Dims,Axes,Args}
return Concatenated{NewInterface}(
concat.dims, concat.args
)::Concatenated{NewInterface,Dims,Args}
concat.dims, concat.args, concat.axes
)::Concatenated{NewInterface,Dims,Axes,Args}
end

# allocating the destination container
Expand All @@ -80,14 +93,35 @@
return similar(arraytype(interface(concat), T), ax)
end

Base.eltype(concat::Concatenated) = promote_eltypeof(concat.args...)
function cat_axis(

Check warning on line 96 in src/concatenate.jl

View check run for this annotation

Codecov / codecov/patch

src/concatenate.jl#L96

Added line #L96 was not covered by tests
a1::AbstractUnitRange, a2::AbstractUnitRange, a_rest::AbstractUnitRange...
)
return cat_axis(cat_axis(a1, a2), a_rest...)

Check warning on line 99 in src/concatenate.jl

View check run for this annotation

Codecov / codecov/patch

src/concatenate.jl#L99

Added line #L99 was not covered by tests
end
cat_axis(a1::AbstractUnitRange, a2::AbstractUnitRange) = Base.OneTo(length(a1) + length(a2))

function cat_ndims(dims, as::AbstractArray...)
return max(maximum(dims), maximum(ndims, as))
end
function cat_ndims(dims::Val, as::AbstractArray...)
return cat_ndims(unval(dims), as...)

Check warning on line 107 in src/concatenate.jl

View check run for this annotation

Codecov / codecov/patch

src/concatenate.jl#L106-L107

Added lines #L106 - L107 were not covered by tests
end

# For now, simply couple back to base implementation
function Base.axes(concat::Concatenated)
catdims = Base.dims2cat(dims(concat))
return Base.cat_size_shape(catdims, concat.args...)
function cat_axes(dims, a::AbstractArray, as::AbstractArray...)
return ntuple(cat_ndims(dims, a, as...)) do dim
return dim in dims ? cat_axis(map(Base.Fix2(axes, dim), (a, as...))...) : axes(a, dim)
end
end
function cat_axes(dims::Val, as::AbstractArray...)
return cat_axes(unval(dims), as...)
end

Base.axes(concat::Concatenated) = getfield(concat, :axes)

Base.eltype(concat::Concatenated) = promote_eltypeof(concat.args...)
Base.size(concat::Concatenated) = length.(axes(concat))
Base.ndims(concat::Concatenated) = length(axes(concat))

Check warning on line 123 in src/concatenate.jl

View check run for this annotation

Codecov / codecov/patch

src/concatenate.jl#L123

Added line #L123 was not covered by tests

# Main logic
# ----------
"""
Expand Down Expand Up @@ -125,16 +159,56 @@
# default falls back to replacing interface with Nothing
# this permits specializing on typeof(dest) without ambiguities
# Note: this needs to be defined for AbstractArray specifically to avoid ambiguities with Base.
@inline Base.copyto!(dest::AbstractArray, concat::Concatenated) =
copyto!(dest, convert(Concatenated{Nothing}, concat))
@inline function Base.copyto!(dest::AbstractArray, concat::Concatenated)
return copyto!(dest, convert(Concatenated{Nothing}, concat))
end

# The following is largely copied from the Base implementation of `Base.cat`, see:
# https://github.com/JuliaLang/julia/blob/885b1cd875f101f227b345f681cc36879124d80d/base/abstractarray.jl#L1778-L1887
_copy_or_fill!(A, inds, x) = fill!(view(A, inds...), x)

Check warning on line 168 in src/concatenate.jl

View check run for this annotation

Codecov / codecov/patch

src/concatenate.jl#L168

Added line #L168 was not covered by tests
_copy_or_fill!(A, inds, x::AbstractArray) = (A[inds...] = x)

cat_size(A) = (1,)
cat_size(A::AbstractArray) = size(A)
cat_size(A, d) = 1

Check warning on line 173 in src/concatenate.jl

View check run for this annotation

Codecov / codecov/patch

src/concatenate.jl#L171-L173

Added lines #L171 - L173 were not covered by tests
cat_size(A::AbstractArray, d) = size(A, d)

cat_indices(A, d) = Base.OneTo(1)

Check warning on line 176 in src/concatenate.jl

View check run for this annotation

Codecov / codecov/patch

src/concatenate.jl#L176

Added line #L176 was not covered by tests
cat_indices(A::AbstractArray, d) = axes(A, d)

function __cat!(A, shape, catdims, X...)
return __cat_offset!(A, shape, catdims, ntuple(zero, length(shape)), X...)
end
function __cat_offset!(A, shape, catdims, offsets, x, X...)
# splitting the "work" on x from X... may reduce latency (fewer costly specializations)
newoffsets = __cat_offset1!(A, shape, catdims, offsets, x)
return __cat_offset!(A, shape, catdims, newoffsets, X...)
end
__cat_offset!(A, shape, catdims, offsets) = A
function __cat_offset1!(A, shape, catdims, offsets, x)
inds = ntuple(length(offsets)) do i
(i <= length(catdims) && catdims[i]) ? offsets[i] .+ cat_indices(x, i) : 1:shape[i]
end
_copy_or_fill!(A, inds, x)
newoffsets = ntuple(length(offsets)) do i
(i <= length(catdims) && catdims[i]) ? offsets[i] + cat_size(x, i) : offsets[i]
end
return newoffsets
end

dims2cat(dims::Val) = dims2cat(unval(dims))

Check warning on line 199 in src/concatenate.jl

View check run for this annotation

Codecov / codecov/patch

src/concatenate.jl#L199

Added line #L199 was not covered by tests
function dims2cat(dims)
if any(≤(0), dims)
throw(ArgumentError("All cat dimensions must be positive integers, but got $dims"))

Check warning on line 202 in src/concatenate.jl

View check run for this annotation

Codecov / codecov/patch

src/concatenate.jl#L202

Added line #L202 was not covered by tests
end
return ntuple(in(dims), maximum(dims))
end

# couple back to Base implementation if no specialization exists:
# https://github.com/JuliaLang/julia/blob/29da86bb983066dd076439c2c7bc5e28dbd611bb/base/abstractarray.jl#L1852
function Base.copyto!(dest::AbstractArray, concat::Concatenated{Nothing})
catdims = Base.dims2cat(dims(concat))
shape = Base.cat_size_shape(catdims, concat.args...)
catdims = dims2cat(dims(concat))
shape = size(concat)
count(!iszero, catdims)::Int > 1 && zero!(dest)
return Base.__cat(dest, shape, catdims, concat.args...)
return __cat!(dest, shape, catdims, concat.args...)
end

end
4 changes: 4 additions & 0 deletions src/defaultarrayinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,7 @@ end
)
return Base.mapreduce(f, op, as...; kwargs...)
end

function arraytype(::DefaultArrayInterface, T::Type)
return Array{T}
end
4 changes: 4 additions & 0 deletions src/zero.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,7 @@
In-place version of `Base.zero`.
"""
function zero! end

@derive (T=AbstractArray,) begin
DerivableInterfaces.zero!(::T)
end
31 changes: 31 additions & 0 deletions test/test_concatenate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
using DerivableInterfaces.Concatenate: concatenated
using Test: @test, @testset

@testset "Concatenated" begin
a = randn(Float32, 2, 2)
b = randn(Float64, 2, 2)

concat = concatenated((1, 2), a, b)
@test axes(concat) == Base.OneTo.((4, 4))
@test size(concat) == (4, 4)
@test eltype(concat) === Float64
@test copy(concat) == cat(a, b; dims=(1, 2))

concat = concatenated(1, a, b)
@test axes(concat) == Base.OneTo.((4, 2))
@test size(concat) == (4, 2)
@test eltype(concat) === Float64
@test copy(concat) == cat(a, b; dims=1)

concat = concatenated(3, a, b)
@test axes(concat) == Base.OneTo.((2, 2, 2))
@test size(concat) == (2, 2, 2)
@test eltype(concat) === Float64
@test copy(concat) == cat(a, b; dims=3)

concat = concatenated(4, a, b)
@test axes(concat) == Base.OneTo.((2, 2, 1, 2))
@test size(concat) == (2, 2, 1, 2)
@test eltype(concat) === Float64
@test copy(concat) == cat(a, b; dims=4)
end
Loading