Skip to content

More customization points in Concatenate #32

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DerivableInterfaces"
uuid = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.4.0"
version = "0.4.1"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -13,9 +13,16 @@ MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261"
TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"

[weakdeps]
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"

[extensions]
DerivableInterfacesBlockArraysExt = "BlockArrays"

[compat]
Adapt = "4.1.1"
ArrayLayouts = "1.11.0"
ArrayLayouts = "1.11"
BlockArrays = "1.4"
Compat = "3.47,4.10"
ExproniconLite = "0.10.13"
LinearAlgebra = "1.10"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
module DerivableInterfacesBlockArraysExt

using BlockArrays: BlockedOneTo, blockedrange, blocklengths
using DerivableInterfaces.Concatenate: Concatenate

function Concatenate.cat_axis(a1::BlockedOneTo, a2::BlockedOneTo)
return blockedrange([blocklengths(a1); blocklengths(a2)])

Check warning on line 7 in ext/DerivableInterfacesBlockArraysExt/DerivableInterfacesBlockArraysExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/DerivableInterfacesBlockArraysExt/DerivableInterfacesBlockArraysExt.jl#L6-L7

Added lines #L6 - L7 were not covered by tests
end

end
121 changes: 93 additions & 28 deletions src/concatenate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@
using ..DerivableInterfaces:
DerivableInterfaces, AbstractInterface, interface, zero!, arraytype

unval(x) = x
unval(::Val{x}) where {x} = x

Check warning on line 35 in src/concatenate.jl

View check run for this annotation

Codecov / codecov/patch

src/concatenate.jl#L34-L35

Added lines #L34 - L35 were not covered by tests

function _Concatenated end

"""
Concatenated{Interface,Dims,Args<:Tuple}

Expand All @@ -41,25 +46,25 @@
interface::Interface
dims::Val{Dims}
args::Args

function Concatenated(
interface::Union{Nothing,AbstractInterface}, dims::Val{Dims}, args::Tuple
) where {Dims}
return new{typeof(interface),Dims,typeof(args)}(interface, dims, args)
end
function Concatenated(dims, args::Tuple)
return Concatenated(interface(args...), dims, args)
end
function Concatenated{Interface}(dims, args) where {Interface}
return Concatenated(Interface(), dims, args)
end
function Concatenated{Interface,Dims}(args) where {Interface,Dims}
return new{Interface,Dims,typeof(args)}(Interface(), Val(Dims), args)
global @inline function _Concatenated(
interface::Interface, dims::Val{Dims}, args::Args
) where {Interface,Dims,Args<:Tuple}
return new{Interface,Dims,Args}(interface, dims, args)
end
end

dims(::Concatenated{A,D}) where {A,D} = D
DerivableInterfaces.interface(concat::Concatenated) = concat.interface
function Concatenated(interface::Union{Nothing,AbstractInterface}, dims::Val, args::Tuple)
return _Concatenated(interface, dims, args)
end
function Concatenated(dims::Val, args::Tuple)
return Concatenated(interface(args...), dims, args)
end
function Concatenated{Interface}(dims::Val, args::Tuple) where {Interface}
return Concatenated(Interface(), dims, args)
end

dims(::Concatenated{<:Any,D}) where {D} = D
DerivableInterfaces.interface(concat::Concatenated) = getfield(concat, :interface)

concatenated(dims, args...) = concatenated(Val(dims), args...)
concatenated(dims::Val, args...) = Concatenated(dims, args)
Expand All @@ -80,13 +85,33 @@
return similar(arraytype(interface(concat), T), ax)
end

Base.eltype(concat::Concatenated) = promote_eltypeof(concat.args...)
function cat_axis(

Check warning on line 88 in src/concatenate.jl

View check run for this annotation

Codecov / codecov/patch

src/concatenate.jl#L88

Added line #L88 was not covered by tests
a1::AbstractUnitRange, a2::AbstractUnitRange, a_rest::AbstractUnitRange...
)
return cat_axis(cat_axis(a1, a2), a_rest...)

Check warning on line 91 in src/concatenate.jl

View check run for this annotation

Codecov / codecov/patch

src/concatenate.jl#L91

Added line #L91 was not covered by tests
end
cat_axis(a1::AbstractUnitRange, a2::AbstractUnitRange) = Base.OneTo(length(a1) + length(a2))

# For now, simply couple back to base implementation
function Base.axes(concat::Concatenated)
catdims = Base.dims2cat(dims(concat))
return Base.cat_size_shape(catdims, concat.args...)
function cat_ndims(dims, as::AbstractArray...)
return max(maximum(dims), maximum(ndims, as))
end
function cat_ndims(dims::Val, as::AbstractArray...)
return cat_ndims(unval(dims), as...)

Check warning on line 99 in src/concatenate.jl

View check run for this annotation

Codecov / codecov/patch

src/concatenate.jl#L98-L99

Added lines #L98 - L99 were not covered by tests
end

function cat_axes(dims, a::AbstractArray, as::AbstractArray...)
return ntuple(cat_ndims(dims, a, as...)) do dim
return dim in dims ? cat_axis(map(Base.Fix2(axes, dim), (a, as...))...) : axes(a, dim)
end
end
function cat_axes(dims::Val, as::AbstractArray...)
return cat_axes(unval(dims), as...)

Check warning on line 108 in src/concatenate.jl

View check run for this annotation

Codecov / codecov/patch

src/concatenate.jl#L107-L108

Added lines #L107 - L108 were not covered by tests
end

Base.eltype(concat::Concatenated) = promote_eltypeof(concat.args...)
Base.axes(concat::Concatenated) = cat_axes(dims(concat), concat.args...)
Base.size(concat::Concatenated) = length.(axes(concat))
Base.ndims(concat::Concatenated) = length(axes(concat))

Check warning on line 114 in src/concatenate.jl

View check run for this annotation

Codecov / codecov/patch

src/concatenate.jl#L114

Added line #L114 was not covered by tests

# Main logic
# ----------
Expand Down Expand Up @@ -122,19 +147,59 @@

Base.copy(concat::Concatenated) = copyto!(similar(concat), concat)

# The following is largely copied from the Base implementation of `Base.cat`, see:
# https://github.com/JuliaLang/julia/blob/885b1cd875f101f227b345f681cc36879124d80d/base/abstractarray.jl#L1778-L1887
_copy_or_fill!(A, inds, x) = fill!(view(A, inds...), x)

Check warning on line 152 in src/concatenate.jl

View check run for this annotation

Codecov / codecov/patch

src/concatenate.jl#L152

Added line #L152 was not covered by tests
_copy_or_fill!(A, inds, x::AbstractArray) = (A[inds...] = x)

cat_size(A) = (1,)
cat_size(A::AbstractArray) = size(A)
cat_size(A, d) = 1

Check warning on line 157 in src/concatenate.jl

View check run for this annotation

Codecov / codecov/patch

src/concatenate.jl#L155-L157

Added lines #L155 - L157 were not covered by tests
cat_size(A::AbstractArray, d) = size(A, d)

cat_indices(A, d) = Base.OneTo(1)

Check warning on line 160 in src/concatenate.jl

View check run for this annotation

Codecov / codecov/patch

src/concatenate.jl#L160

Added line #L160 was not covered by tests
cat_indices(A::AbstractArray, d) = axes(A, d)

function __cat!(A, shape, catdims, X...)
return __cat_offset!(A, shape, catdims, ntuple(zero, length(shape)), X...)
end
function __cat_offset!(A, shape, catdims, offsets, x, X...)
# splitting the "work" on x from X... may reduce latency (fewer costly specializations)
newoffsets = __cat_offset1!(A, shape, catdims, offsets, x)
return __cat_offset!(A, shape, catdims, newoffsets, X...)
end
__cat_offset!(A, shape, catdims, offsets) = A
function __cat_offset1!(A, shape, catdims, offsets, x)
inds = ntuple(length(offsets)) do i
(i <= length(catdims) && catdims[i]) ? offsets[i] .+ cat_indices(x, i) : 1:shape[i]
end
_copy_or_fill!(A, inds, x)
newoffsets = ntuple(length(offsets)) do i
(i <= length(catdims) && catdims[i]) ? offsets[i] + cat_size(x, i) : offsets[i]
end
return newoffsets
end

dims2cat(dims::Val) = dims2cat(unval(dims))

Check warning on line 183 in src/concatenate.jl

View check run for this annotation

Codecov / codecov/patch

src/concatenate.jl#L183

Added line #L183 was not covered by tests
function dims2cat(dims)
if any(≤(0), dims)
throw(ArgumentError("All cat dimensions must be positive integers, but got $dims"))

Check warning on line 186 in src/concatenate.jl

View check run for this annotation

Codecov / codecov/patch

src/concatenate.jl#L186

Added line #L186 was not covered by tests
end
return ntuple(in(dims), maximum(dims))
end

# default falls back to replacing interface with Nothing
# this permits specializing on typeof(dest) without ambiguities
# Note: this needs to be defined for AbstractArray specifically to avoid ambiguities with Base.
@inline Base.copyto!(dest::AbstractArray, concat::Concatenated) =
copyto!(dest, convert(Concatenated{Nothing}, concat))
@inline function Base.copyto!(dest::AbstractArray, concat::Concatenated)
return copyto!(dest, convert(Concatenated{Nothing}, concat))
end

# couple back to Base implementation if no specialization exists:
# https://github.com/JuliaLang/julia/blob/29da86bb983066dd076439c2c7bc5e28dbd611bb/base/abstractarray.jl#L1852
function Base.copyto!(dest::AbstractArray, concat::Concatenated{Nothing})
catdims = Base.dims2cat(dims(concat))
shape = Base.cat_size_shape(catdims, concat.args...)
catdims = dims2cat(dims(concat))
shape = size(concat)
count(!iszero, catdims)::Int > 1 && zero!(dest)
return Base.__cat(dest, shape, catdims, concat.args...)
return __cat!(dest, shape, catdims, concat.args...)
end

end
4 changes: 4 additions & 0 deletions src/defaultarrayinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,7 @@ end
)
return Base.mapreduce(f, op, as...; kwargs...)
end

function arraytype(::DefaultArrayInterface, T::Type)
return Array{T}
end
4 changes: 4 additions & 0 deletions src/zero.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,7 @@
In-place version of `Base.zero`.
"""
function zero! end

@derive (T=AbstractArray,) begin
DerivableInterfaces.zero!(::T)
end
31 changes: 31 additions & 0 deletions test/test_concatenate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
using DerivableInterfaces.Concatenate: concatenated
using Test: @test, @testset

@testset "Concatenated" begin
a = randn(Float32, 2, 2)
b = randn(Float64, 2, 2)

concat = concatenated((1, 2), a, b)
@test axes(concat) == Base.OneTo.((4, 4))
@test size(concat) == (4, 4)
@test eltype(concat) === Float64
@test copy(concat) == cat(a, b; dims=(1, 2))

concat = concatenated(1, a, b)
@test axes(concat) == Base.OneTo.((4, 2))
@test size(concat) == (4, 2)
@test eltype(concat) === Float64
@test copy(concat) == cat(a, b; dims=1)

concat = concatenated(3, a, b)
@test axes(concat) == Base.OneTo.((2, 2, 2))
@test size(concat) == (2, 2, 2)
@test eltype(concat) === Float64
@test copy(concat) == cat(a, b; dims=3)

concat = concatenated(4, a, b)
@test axes(concat) == Base.OneTo.((2, 2, 1, 2))
@test size(concat) == (2, 2, 1, 2)
@test eltype(concat) === Float64
@test copy(concat) == cat(a, b; dims=4)
end
Loading