Skip to content

More customization points in Concatenate #32

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 12, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DerivableInterfaces"
uuid = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.4.0"
version = "0.4.1"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -13,9 +13,16 @@ MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261"
TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"

[weakdeps]
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"

[extensions]
DerivableInterfacesBlockArraysExt = "BlockArrays"

[compat]
Adapt = "4.1.1"
ArrayLayouts = "1.11.0"
ArrayLayouts = "1.11"
BlockArrays = "1.4"
Compat = "3.47,4.10"
ExproniconLite = "0.10.13"
LinearAlgebra = "1.10"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
module DerivableInterfacesBlockArraysExt

using BlockArrays: BlockedOneTo, blockedrange, blocklengths
using DerivableInterfaces.Concatenate: Concatenate

function Concatenate.cat_axis(a1::BlockedOneTo, a2::BlockedOneTo)
return blockedrange([blocklengths(a1); blocklengths(a2)])

Check warning on line 7 in ext/DerivableInterfacesBlockArraysExt/DerivableInterfacesBlockArraysExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/DerivableInterfacesBlockArraysExt/DerivableInterfacesBlockArraysExt.jl#L6-L7

Added lines #L6 - L7 were not covered by tests
end

end
124 changes: 99 additions & 25 deletions src/concatenate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,33 +31,40 @@
using ..DerivableInterfaces:
DerivableInterfaces, AbstractInterface, interface, zero!, arraytype

unval(x) = x

Check warning on line 34 in src/concatenate.jl

View check run for this annotation

Codecov / codecov/patch

src/concatenate.jl#L34

Added line #L34 was not covered by tests
unval(::Val{x}) where {x} = x

function _Concatenated end

"""
Concatenated{Interface,Dims,Args<:Tuple}
Concatenated{Interface,Dims,Axes,Args<:Tuple}

Lazy representation of the concatenation of various `Args` along `Dims`, in order to provide
hooks to customize the implementation.
"""
struct Concatenated{Interface,Dims,Args<:Tuple}
struct Concatenated{Interface,Dims,Axes,Args<:Tuple}
interface::Interface
dims::Val{Dims}
args::Args

function Concatenated(
interface::Union{Nothing,AbstractInterface}, dims::Val{Dims}, args::Tuple
) where {Dims}
return new{typeof(interface),Dims,typeof(args)}(interface, dims, args)
end
function Concatenated(dims, args::Tuple)
return Concatenated(interface(args...), dims, args)
end
function Concatenated{Interface}(dims, args) where {Interface}
return Concatenated(Interface(), dims, args)
end
function Concatenated{Interface,Dims}(args) where {Interface,Dims}
return new{Interface,Dims,typeof(args)}(Interface(), Val(Dims), args)
axes::Axes
global @inline function _Concatenated(
interface::Interface, dims::Val{Dims}, args::Args
) where {Interface,Dims,Args<:Tuple}
ax = cat_axes(dims, args...)
return new{Interface,Dims,typeof(ax),Args}(interface, dims, args, ax)
end
end

function Concatenated(interface::Union{Nothing,AbstractInterface}, dims::Val, args::Tuple)
return _Concatenated(interface, dims, args)
end
function Concatenated(dims::Val, args::Tuple)
return Concatenated(interface(args...), dims, args)
end
function Concatenated{Interface}(dims::Val, args) where {Interface}
return Concatenated(Interface(), dims, args)
end

dims(::Concatenated{A,D}) where {A,D} = D
DerivableInterfaces.interface(concat::Concatenated) = concat.interface

Expand All @@ -80,14 +87,41 @@
return similar(arraytype(interface(concat), T), ax)
end

Base.eltype(concat::Concatenated) = promote_eltypeof(concat.args...)
function cat_axis(

Check warning on line 90 in src/concatenate.jl

View check run for this annotation

Codecov / codecov/patch

src/concatenate.jl#L90

Added line #L90 was not covered by tests
a1::AbstractUnitRange, a2::AbstractUnitRange, a_rest::AbstractUnitRange...
)
return cat_axis(cat_axis(a1, a2), a_rest...)

Check warning on line 93 in src/concatenate.jl

View check run for this annotation

Codecov / codecov/patch

src/concatenate.jl#L93

Added line #L93 was not covered by tests
end
cat_axis(a1::AbstractUnitRange, a2::AbstractUnitRange) = Base.OneTo(length(a1) + length(a2))

function cat_ndims(dims, as::AbstractArray...)
return max(maximum(dims), maximum(ndims, as))
end
function cat_ndims(dims::Val, as::AbstractArray...)
return cat_ndims(unval(dims), as...)

Check warning on line 101 in src/concatenate.jl

View check run for this annotation

Codecov / codecov/patch

src/concatenate.jl#L100-L101

Added lines #L100 - L101 were not covered by tests
end

function cat_axes(dims, as::AbstractArray...)
return ntuple(cat_ndims(dims, as...)) do dim
if dim ∉ dims
return axes(first(as), dim)
end
return cat_axis(map(ax -> get(ax, dim, Base.OneTo(1)), axes.(as))...)
end
end
function cat_axes(dims::Val, as::AbstractArray...)
return cat_axes(unval(dims), as...)
end

# For now, simply couple back to base implementation
function Base.axes(concat::Concatenated)
catdims = Base.dims2cat(dims(concat))
return Base.cat_size_shape(catdims, concat.args...)
!isnothing(concat.axes) && return concat.axes
return cat_axes(dims(concat), concat.args...)

Check warning on line 118 in src/concatenate.jl

View check run for this annotation

Codecov / codecov/patch

src/concatenate.jl#L118

Added line #L118 was not covered by tests
end

Base.eltype(concat::Concatenated) = promote_eltypeof(concat.args...)
Base.size(concat::Concatenated) = length.(axes(concat))
Base.ndims(concat::Concatenated) = length(axes(concat))

Check warning on line 123 in src/concatenate.jl

View check run for this annotation

Codecov / codecov/patch

src/concatenate.jl#L123

Added line #L123 was not covered by tests

# Main logic
# ----------
"""
Expand Down Expand Up @@ -125,16 +159,56 @@
# default falls back to replacing interface with Nothing
# this permits specializing on typeof(dest) without ambiguities
# Note: this needs to be defined for AbstractArray specifically to avoid ambiguities with Base.
@inline Base.copyto!(dest::AbstractArray, concat::Concatenated) =
copyto!(dest, convert(Concatenated{Nothing}, concat))
@inline function Base.copyto!(dest::AbstractArray, concat::Concatenated)
return copyto!(dest, convert(Concatenated{Nothing}, concat))
end

_copy_or_fill!(A, inds, x) = fill!(view(A, inds...), x)

Check warning on line 166 in src/concatenate.jl

View check run for this annotation

Codecov / codecov/patch

src/concatenate.jl#L166

Added line #L166 was not covered by tests
_copy_or_fill!(A, inds, x::AbstractArray) = (A[inds...] = x)

cat_size(A) = (1,)
cat_size(A::AbstractArray) = size(A)
cat_size(A, d) = 1

Check warning on line 171 in src/concatenate.jl

View check run for this annotation

Codecov / codecov/patch

src/concatenate.jl#L169-L171

Added lines #L169 - L171 were not covered by tests
cat_size(A::AbstractArray, d) = size(A, d)

cat_indices(A, d) = Base.OneTo(1)

Check warning on line 174 in src/concatenate.jl

View check run for this annotation

Codecov / codecov/patch

src/concatenate.jl#L174

Added line #L174 was not covered by tests
cat_indices(A::AbstractArray, d) = axes(A, d)

function __cat!(A, shape, catdims, X...)
return __cat_offset!(A, shape, catdims, ntuple(zero, length(shape)), X...)
end
function __cat_offset!(A, shape, catdims, offsets, x, X...)
# splitting the "work" on x from X... may reduce latency (fewer costly specializations)
newoffsets = __cat_offset1!(A, shape, catdims, offsets, x)
return __cat_offset!(A, shape, catdims, newoffsets, X...)
end
__cat_offset!(A, shape, catdims, offsets) = A
function __cat_offset1!(A, shape, catdims, offsets, x)
inds = ntuple(length(offsets)) do i
(i <= length(catdims) && catdims[i]) ? offsets[i] .+ cat_indices(x, i) : 1:shape[i]
end
_copy_or_fill!(A, inds, x)
newoffsets = ntuple(length(offsets)) do i
(i <= length(catdims) && catdims[i]) ? offsets[i] + cat_size(x, i) : offsets[i]
end
return newoffsets
end

dims2cat(dims::Val) = dims2cat(unval(dims))

Check warning on line 197 in src/concatenate.jl

View check run for this annotation

Codecov / codecov/patch

src/concatenate.jl#L197

Added line #L197 was not covered by tests
function dims2cat(dims)
if any(≤(0), dims)
throw(ArgumentError("All cat dimensions must be positive integers, but got $dims"))

Check warning on line 200 in src/concatenate.jl

View check run for this annotation

Codecov / codecov/patch

src/concatenate.jl#L200

Added line #L200 was not covered by tests
end
return ntuple(in(dims), maximum(dims))
end

# couple back to Base implementation if no specialization exists:
# https://github.com/JuliaLang/julia/blob/29da86bb983066dd076439c2c7bc5e28dbd611bb/base/abstractarray.jl#L1852
function Base.copyto!(dest::AbstractArray, concat::Concatenated{Nothing})
catdims = Base.dims2cat(dims(concat))
shape = Base.cat_size_shape(catdims, concat.args...)
catdims = dims2cat(dims(concat))
shape = size(concat)
count(!iszero, catdims)::Int > 1 && zero!(dest)
return Base.__cat(dest, shape, catdims, concat.args...)
return __cat!(dest, shape, catdims, concat.args...)
end

end
4 changes: 4 additions & 0 deletions src/defaultarrayinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,7 @@ end
)
return Base.mapreduce(f, op, as...; kwargs...)
end

function arraytype(::DefaultArrayInterface, T::Type)
return Array{T}
end
4 changes: 4 additions & 0 deletions src/zero.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,7 @@
In-place version of `Base.zero`.
"""
function zero! end

@derive (T=AbstractArray,) begin
DerivableInterfaces.zero!(::T)
end
31 changes: 31 additions & 0 deletions test/test_concatenate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
using DerivableInterfaces.Concatenate: concatenated
using Test: @test, @testset

@testset "Concatenated" begin
a = randn(Float32, 2, 2)
b = randn(Float64, 2, 2)

concat = concatenated((1, 2), a, b)
@test axes(concat) == Base.OneTo.((4, 4))
@test size(concat) == (4, 4)
@test eltype(concat) === Float64
@test copy(concat) == cat(a, b; dims=(1, 2))

concat = concatenated(1, a, b)
@test axes(concat) == Base.OneTo.((4, 2))
@test size(concat) == (4, 2)
@test eltype(concat) === Float64
@test copy(concat) == cat(a, b; dims=1)

concat = concatenated(3, a, b)
@test axes(concat) == Base.OneTo.((2, 2, 2))
@test size(concat) == (2, 2, 2)
@test eltype(concat) === Float64
@test copy(concat) == cat(a, b; dims=3)

concat = concatenated(4, a, b)
@test axes(concat) == Base.OneTo.((2, 2, 1, 2))
@test size(concat) == (2, 2, 1, 2)
@test eltype(concat) === Float64
@test copy(concat) == cat(a, b; dims=4)
end
Loading