From 09491bd91c09deffa2d49bbe1f9e1387b6f7b713 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Sun, 22 Jun 2025 18:51:47 -0400 Subject: [PATCH 1/5] Introduce CartesianPair --- Project.toml | 4 +- .../KroneckerArraysBlockSparseArraysExt.jl | 15 +++++- src/cartesianproduct.jl | 46 +++++++++++++++---- src/kroneckerarray.jl | 3 ++ 4 files changed, 57 insertions(+), 11 deletions(-) diff --git a/Project.toml b/Project.toml index 2921538..43ecb60 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" authors = ["ITensor developers and contributors"] -version = "0.1.18" +version = "0.1.19" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -23,7 +23,7 @@ KroneckerArraysBlockSparseArraysExt = ["BlockArrays", "BlockSparseArrays"] [compat] Adapt = "4.3.0" BlockArrays = "1.6" -BlockSparseArrays = "0.7.19" +BlockSparseArrays = "0.7.20" DerivableInterfaces = "0.5.0" DiagonalArrays = "0.3.5" FillArrays = "1.13.0" diff --git a/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl b/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl index 91f579f..114674e 100644 --- a/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl +++ b/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl @@ -1,7 +1,20 @@ module KroneckerArraysBlockSparseArraysExt +using BlockArrays: Block +using BlockSparseArrays: BlockIndexVector, GenericBlockIndex +using KroneckerArrays: CartesianPair, CartesianProduct +function Base.getindex(b::Block, I1::CartesianPair, Irest::CartesianPair...) + return GenericBlockIndex(b, (I1, Irest...)) +end +function Base.getindex(b::Block, I1::CartesianProduct, Irest::CartesianProduct...) + return BlockIndexVector(b, (I1, Irest...)) +end + using BlockSparseArrays: BlockSparseArrays, blockrange -using KroneckerArrays: CartesianProduct, cartesianrange +using KroneckerArrays: CartesianPair, CartesianProduct, cartesianrange +function BlockSparseArrays.blockrange(bs::Vector{<:CartesianPair}) + return blockrange(map(cartesianrange, bs)) +end function BlockSparseArrays.blockrange(bs::Vector{<:CartesianProduct}) return blockrange(map(cartesianrange, bs)) end diff --git a/src/cartesianproduct.jl b/src/cartesianproduct.jl index c77a2f5..2382aea 100644 --- a/src/cartesianproduct.jl +++ b/src/cartesianproduct.jl @@ -1,4 +1,22 @@ -struct CartesianProduct{A,B} +struct CartesianPair{A,B} + a::A + b::B +end +arguments(a::CartesianPair) = (a.a, a.b) +arguments(a::CartesianPair, n::Int) = arguments(a)[n] + +arg1(a::CartesianPair) = a.a +arg2(a::CartesianPair) = a.b + +×(a, b) = CartesianPair(a, b) + +function Base.show(io::IO, a::CartesianPair) + print(io, a.a, " × ", a.b) + return nothing +end + +struct CartesianProduct{TA,TB,A<:AbstractVector{TA},B<:AbstractVector{TB}} <: + AbstractVector{CartesianPair{TA,TB}} a::A b::B end @@ -13,15 +31,19 @@ function Base.show(io::IO, a::CartesianProduct) return nothing end -×(a, b) = CartesianProduct(a, b) +×(a::AbstractVector, b::AbstractVector) = CartesianProduct(a, b) Base.length(a::CartesianProduct) = length(a.a) * length(a.b) -Base.getindex(a::CartesianProduct, i::CartesianProduct) = a.a[i.a] × a.b[i.b] +Base.size(a::CartesianProduct) = (length(a),) -function Base.iterate(a::CartesianProduct, state...) - x = iterate(Iterators.product(a.a, a.b), state...) - isnothing(x) && return x - next, new_state = x - return ×(next...), new_state +function Base.getindex(a::CartesianProduct, i::CartesianProduct) + return arg1(a)[arg1(i)] × arg2(a)[arg2(i)] +end +function Base.getindex(a::CartesianProduct, i::CartesianPair) + return arg1(a)[arg1(i)] × arg2(a)[arg2(i)] +end +function Base.getindex(a::CartesianProduct, i::Int) + I = Tuple(CartesianIndices((length(arg1(a)), length(arg2(a))))[i]) + return a[I[1] × I[2]] end struct CartesianProductUnitRange{T,P<:CartesianProduct,R<:AbstractUnitRange{T}} <: @@ -47,10 +69,18 @@ end to_range(a::AbstractUnitRange) = a to_range(i::Integer) = Base.OneTo(i) cartesianrange(a, b) = cartesianrange(to_range(a) × to_range(b)) +function cartesianrange(p::CartesianPair) + p′ = to_range(p.a) × to_range(p.b) + return cartesianrange(p′) +end function cartesianrange(p::CartesianProduct) p′ = to_range(p.a) × to_range(p.b) return cartesianrange(p′, Base.OneTo(length(p′))) end +function cartesianrange(p::CartesianPair, range::AbstractUnitRange) + p′ = to_range(p.a) × to_range(p.b) + return cartesianrange(p′, range) +end function cartesianrange(p::CartesianProduct, range::AbstractUnitRange) p′ = to_range(p.a) × to_range(p.b) return CartesianProductUnitRange(p′, range) diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index 82b4357..b5caff3 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -188,6 +188,9 @@ end function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianProduct,N}) where {N} return a.a[map(Base.Fix2(getfield, :a), I)...] ⊗ a.b[map(Base.Fix2(getfield, :b), I)...] end +function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianPair,N}) where {N} + return a.a[map(Base.Fix2(getfield, :a), I)...] ⊗ a.b[map(Base.Fix2(getfield, :b), I)...] +end # Fix ambigiuity error. Base.getindex(a::KroneckerArray{<:Any,0}) = a.a[] * a.b[] From 24a1371ff48db37f65eb8f2993d9c881f16ba679 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Sun, 22 Jun 2025 18:56:50 -0400 Subject: [PATCH 2/5] Fix tests --- src/fillarrays/kroneckerarray.jl | 2 ++ src/kroneckerarray.jl | 9 ++++----- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/fillarrays/kroneckerarray.jl b/src/fillarrays/kroneckerarray.jl index f132dbf..db51fc9 100644 --- a/src/fillarrays/kroneckerarray.jl +++ b/src/fillarrays/kroneckerarray.jl @@ -21,6 +21,8 @@ const SquareEyeKronecker{T,A<:SquareEye{T},B<:AbstractMatrix{T}} = KroneckerMatr const KroneckerSquareEye{T,A<:AbstractMatrix{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B} const SquareEyeSquareEye{T,A<:SquareEye{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B} +_getindex(a::Eye, I1::Colon, I2::Colon) = a + # Like `adapt` but preserves `Eye`. _adapt(to, a::Eye) = a diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index b5caff3..fe7cbcb 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -182,14 +182,13 @@ function Base.getindex(a::KroneckerVector, i::Integer) return a.a[cld(i, k)] * a.b[(i - 1) % k + 1] end -## function Base.getindex(a::KroneckerVector, i::CartesianProduct) -## return a.a[i.a] ⊗ a.b[i.b] -## end +# Allow customizing for `FillArrays.Eye`. +_getindex(a::AbstractArray, I...) = a[I...] function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianProduct,N}) where {N} - return a.a[map(Base.Fix2(getfield, :a), I)...] ⊗ a.b[map(Base.Fix2(getfield, :b), I)...] + return _getindex(arg1(a), arg1.(I)...) ⊗ _getindex(arg2(a), arg2.(I)...) end function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianPair,N}) where {N} - return a.a[map(Base.Fix2(getfield, :a), I)...] ⊗ a.b[map(Base.Fix2(getfield, :b), I)...] + return _getindex(arg1(a), arg1.(I)...) ⊗ _getindex(arg2(a), arg2.(I)...) end # Fix ambigiuity error. Base.getindex(a::KroneckerArray{<:Any,0}) = a.a[] * a.b[] From 313374cc5eca9db899e587936187a3b88df4ba46 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Sun, 22 Jun 2025 19:20:10 -0400 Subject: [PATCH 3/5] Update style --- src/cartesianproduct.jl | 20 +++--- src/fillarrays/matrixalgebrakit_truncate.jl | 12 ++-- src/kroneckerarray.jl | 80 ++++++++++----------- 3 files changed, 55 insertions(+), 57 deletions(-) diff --git a/src/cartesianproduct.jl b/src/cartesianproduct.jl index 2382aea..3934bd8 100644 --- a/src/cartesianproduct.jl +++ b/src/cartesianproduct.jl @@ -66,28 +66,28 @@ end function CartesianProductUnitRange(a, b) return CartesianProductUnitRange(a × b) end -to_range(a::AbstractUnitRange) = a -to_range(i::Integer) = Base.OneTo(i) -cartesianrange(a, b) = cartesianrange(to_range(a) × to_range(b)) +to_product_indices(a::AbstractUnitRange) = a +to_product_indices(i::Integer) = Base.OneTo(i) +cartesianrange(a, b) = cartesianrange(to_product_indices(a) × to_product_indices(b)) function cartesianrange(p::CartesianPair) - p′ = to_range(p.a) × to_range(p.b) + p′ = to_product_indices(arg1(p)) × to_product_indices(arg2(p)) return cartesianrange(p′) end function cartesianrange(p::CartesianProduct) - p′ = to_range(p.a) × to_range(p.b) + p′ = to_product_indices(arg1(p)) × to_product_indices(arg2(p)) return cartesianrange(p′, Base.OneTo(length(p′))) end function cartesianrange(p::CartesianPair, range::AbstractUnitRange) - p′ = to_range(p.a) × to_range(p.b) + p′ = to_product_indices(arg1(p)) × to_product_indices(arg2(p)) return cartesianrange(p′, range) end function cartesianrange(p::CartesianProduct, range::AbstractUnitRange) - p′ = to_range(p.a) × to_range(p.b) + p′ = to_product_indices(arg1(p)) × to_product_indices(arg2(p)) return CartesianProductUnitRange(p′, range) end function Base.axes(r::CartesianProductUnitRange) - return (CartesianProductUnitRange(r.product, only(axes(r.range))),) + return (CartesianProductUnitRange(cartesianproduct(r), only(axes(unproduct(r)))),) end using Base.Broadcast: DefaultArrayStyle @@ -96,12 +96,12 @@ for f in (:+, :-) function Broadcast.broadcasted( ::DefaultArrayStyle{1}, ::typeof($f), r::CartesianProductUnitRange, x::Integer ) - return CartesianProductUnitRange(r.product, $f.(r.range, x)) + return CartesianProductUnitRange(cartesianproduct(r), $f.(unproduct(r), x)) end function Broadcast.broadcasted( ::DefaultArrayStyle{1}, ::typeof($f), x::Integer, r::CartesianProductUnitRange ) - return CartesianProductUnitRange(r.product, $f.(x, r.range)) + return CartesianProductUnitRange(cartesianproduct(r), $f.(x, unproduct(r))) end end end diff --git a/src/fillarrays/matrixalgebrakit_truncate.jl b/src/fillarrays/matrixalgebrakit_truncate.jl index ae50f27..e505cf2 100644 --- a/src/fillarrays/matrixalgebrakit_truncate.jl +++ b/src/fillarrays/matrixalgebrakit_truncate.jl @@ -24,11 +24,11 @@ function MatrixAlgebraKit.findtruncated( values::OnesKroneckerVector, strategy::KroneckerTruncationStrategy ) I = findtruncated(Vector(values), strategy.strategy) - prods = collect(only(axes(values)).product)[I] - I_data = unique(map(x -> x.a, prods)) + prods = collect(cartesianproduct(only(axes(values))))[I] + I_data = unique(map(arg1, prods)) # Drop truncations that occur within the identity. I_data = filter(I_data) do i - return count(x -> x.a == i, prods) == length(values.a) + return count(x -> arg1(x) == i, prods) == length(arg1(values)) end return (:) × I_data end @@ -36,11 +36,11 @@ function MatrixAlgebraKit.findtruncated( values::KroneckerOnesVector, strategy::KroneckerTruncationStrategy ) I = findtruncated(Vector(values), strategy.strategy) - prods = collect(only(axes(values)).product)[I] - I_data = unique(map(x -> x.b, prods)) + prods = collect(cartesianproduct(only(axes(values))))[I] + I_data = unique(map(x -> arg2(x), prods)) # Drop truncations that occur within the identity. I_data = filter(I_data) do i - return count(x -> x.b == i, prods) == length(values.b) + return count(x -> arg2(x) == i, prods) == length(arg2(values)) end return I_data × (:) end diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index fe7cbcb..ea4248a 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -24,17 +24,17 @@ arg2(a::KroneckerArray) = a.b using Adapt: Adapt, adapt _adapt(to, a::AbstractArray) = adapt(to, a) -Adapt.adapt_structure(to, a::KroneckerArray) = _adapt(to, a.a) ⊗ _adapt(to, a.b) +Adapt.adapt_structure(to, a::KroneckerArray) = _adapt(to, arg1(a)) ⊗ _adapt(to, arg2(a)) # Allows extra customization, like for `FillArrays.Eye`. _copy(a::AbstractArray) = copy(a) function Base.copy(a::KroneckerArray) - return _copy(a.a) ⊗ _copy(a.b) + return _copy(arg1(a)) ⊗ _copy(arg2(a)) end function Base.copyto!(dest::KroneckerArray, src::KroneckerArray) - copyto!(dest.a, src.a) - copyto!(dest.b, src.b) + copyto!(arg1(dest), arg1(src)) + copyto!(arg2(dest), arg2(src)) return dest end @@ -53,8 +53,7 @@ function Base.similar( CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}} }, ) - return _similar(a, elt, map(ax -> ax.product.a, axs)) ⊗ - _similar(a, elt, map(ax -> ax.product.b, axs)) + return _similar(a, elt, map(arg1, axs)) ⊗ _similar(a, elt, map(arg2, axs)) end function Base.similar( a::KroneckerArray, @@ -63,8 +62,7 @@ function Base.similar( CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}} }, ) - return _similar(a.a, elt, map(ax -> ax.product.a, axs)) ⊗ - _similar(a.b, elt, map(ax -> ax.product.b, axs)) + return _similar(arg1(a), elt, map(arg1, axs)) ⊗ _similar(arg2(a), elt, map(arg2, axs)) end function Base.similar( arrayt::Type{<:AbstractArray}, @@ -72,8 +70,7 @@ function Base.similar( CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}} }, ) - return _similar(arrayt, map(ax -> ax.product.a, axs)) ⊗ - _similar(arrayt, map(ax -> ax.product.b, axs)) + return _similar(arrayt, map(arg1, axs)) ⊗ _similar(arrayt, map(arg2, axs)) end function Base.similar( arrayt::Type{<:KroneckerArray{<:Any,<:Any,A,B}}, @@ -81,8 +78,7 @@ function Base.similar( CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}} }, ) where {A,B} - return _similar(A, map(ax -> ax.product.a, axs)) ⊗ - _similar(B, map(ax -> ax.product.b, axs)) + return _similar(A, map(arg1, axs)) ⊗ _similar(B, map(arg2, axs)) end function Base.similar( ::Type{<:KroneckerArray{<:Any,<:Any,A,B}}, sz::Tuple{Int,Vararg{Int}} @@ -115,7 +111,7 @@ kron_nd(a::AbstractMatrix, b::AbstractMatrix) = kron(a, b) kron_nd(a::AbstractVector, b::AbstractVector) = kron(a, b) # Eagerly collect arguments to make more general on GPU. -Base.collect(a::KroneckerArray) = kron_nd(collect(a.a), collect(a.b)) +Base.collect(a::KroneckerArray) = kron_nd(collect(arg1(a)), collect(arg2(a))) Base.zero(a::KroneckerArray) = zero(arg1(a)) ⊗ zero(arg2(a)) @@ -123,31 +119,33 @@ function Base.Array{T,N}(a::KroneckerArray{S,N}) where {T,S,N} return convert(Array{T,N}, collect(a)) end -Base.size(a::KroneckerArray) = ntuple(dim -> size(a.a, dim) * size(a.b, dim), ndims(a)) +function Base.size(a::KroneckerArray) + return ntuple(dim -> size(arg1(a), dim) * size(arg2(a), dim), ndims(a)) +end function Base.axes(a::KroneckerArray) return ntuple(ndims(a)) do dim return CartesianProductUnitRange( - axes(a.a, dim) × axes(a.b, dim), Base.OneTo(size(a, dim)) + axes(arg1(a), dim) × axes(arg2(a), dim), Base.OneTo(size(a, dim)) ) end end -arguments(a::KroneckerArray) = (a.a, a.b) +arguments(a::KroneckerArray) = (arg1(a), arg2(a)) arguments(a::KroneckerArray, n::Int) = arguments(a)[n] argument_types(a::KroneckerArray) = argument_types(typeof(a)) argument_types(::Type{<:KroneckerArray{<:Any,<:Any,A,B}}) where {A,B} = (A, B) function Base.print_array(io::IO, a::KroneckerArray) - Base.print_array(io, a.a) + Base.print_array(io, arg1(a)) println(io, "\n ⊗") - Base.print_array(io, a.b) + Base.print_array(io, arg2(a)) return nothing end function Base.show(io::IO, a::KroneckerArray) - show(io, a.a) + show(io, arg1(a)) print(io, " ⊗ ") - show(io, a.b) + show(io, arg2(a)) return nothing end @@ -172,14 +170,14 @@ function Base.getindex(a::KroneckerMatrix, i1::Integer, i2::Integer) GPUArraysCore.assertscalar("getindex") # Code logic from Kronecker.jl: # https://github.com/MichielStock/Kronecker.jl/blob/v0.5.5/src/base.jl#L101-L105 - k, l = size(a.b) - return a.a[cld(i1, k), cld(i2, l)] * a.b[(i1 - 1) % k + 1, (i2 - 1) % l + 1] + k, l = size(arg2(a)) + return arg1(a)[cld(i1, k), cld(i2, l)] * arg2(a)[(i1 - 1) % k + 1, (i2 - 1) % l + 1] end function Base.getindex(a::KroneckerVector, i::Integer) GPUArraysCore.assertscalar("getindex") - k = length(a.b) - return a.a[cld(i, k)] * a.b[(i - 1) % k + 1] + k = length(arg2(a)) + return arg1(a)[cld(i, k)] * arg2(a)[(i - 1) % k + 1] end # Allow customizing for `FillArrays.Eye`. @@ -191,49 +189,49 @@ function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianPair,N}) w return _getindex(arg1(a), arg1.(I)...) ⊗ _getindex(arg2(a), arg2.(I)...) end # Fix ambigiuity error. -Base.getindex(a::KroneckerArray{<:Any,0}) = a.a[] * a.b[] +Base.getindex(a::KroneckerArray{<:Any,0}) = arg1(a)[] * arg2(a)[] function Base.:(==)(a::KroneckerArray, b::KroneckerArray) - return a.a == b.a && a.b == b.b + return arg1(a) == arg1(b) && arg2(a) == arg2(b) end function Base.isapprox(a::KroneckerArray, b::KroneckerArray; kwargs...) - return isapprox(a.a, b.a; kwargs...) && isapprox(a.b, b.b; kwargs...) + return isapprox(arg1(a), arg1(b); kwargs...) && isapprox(arg2(a), arg2(b); kwargs...) end function Base.iszero(a::KroneckerArray) - return iszero(a.a) || iszero(a.b) + return iszero(arg1(a)) || iszero(arg2(a)) end function Base.isreal(a::KroneckerArray) - return isreal(a.a) && isreal(a.b) + return isreal(arg1(a)) && isreal(arg2(a)) end using DiagonalArrays: DiagonalArrays, diagonal function DiagonalArrays.diagonal(a::KroneckerArray) - return diagonal(a.a) ⊗ diagonal(a.b) + return diagonal(arg1(a)) ⊗ diagonal(arg2(a)) end Base.real(a::KroneckerArray{<:Real}) = a function Base.real(a::KroneckerArray) - if iszero(imag(a.a)) || iszero(imag(a.b)) - return real(a.a) ⊗ real(a.b) - elseif iszero(real(a.a)) || iszero(real(a.b)) - return -imag(a.a) ⊗ imag(a.b) + if iszero(imag(arg1(a))) || iszero(imag(arg2(a))) + return real(arg1(a)) ⊗ real(arg2(a)) + elseif iszero(real(arg1(a))) || iszero(real(arg2(a))) + return -imag(arg1(a)) ⊗ imag(arg2(a)) end - return real(a.a) ⊗ real(a.b) - imag(a.a) ⊗ imag(a.b) + return real(arg1(a)) ⊗ real(arg2(a)) - imag(arg1(a)) ⊗ imag(arg2(a)) end Base.imag(a::KroneckerArray{<:Real}) = zero(a) function Base.imag(a::KroneckerArray) - if iszero(imag(a.a)) || iszero(real(a.b)) - return real(a.a) ⊗ imag(a.b) - elseif iszero(real(a.a)) || iszero(imag(a.b)) - return imag(a.a) ⊗ real(a.b) + if iszero(imag(arg1(a))) || iszero(real(arg2(a))) + return real(arg1(a)) ⊗ imag(arg2(a)) + elseif iszero(real(arg1(a))) || iszero(imag(arg2(a))) + return imag(arg1(a)) ⊗ real(arg2(a)) end - return real(a.a) ⊗ imag(a.b) + imag(a.a) ⊗ real(a.b) + return real(arg1(a)) ⊗ imag(arg2(a)) + imag(arg1(a)) ⊗ real(arg2(a)) end for f in [:transpose, :adjoint, :inv] @eval begin function Base.$f(a::KroneckerArray) - return $f(a.a) ⊗ $f(a.b) + return $f(arg1(a)) ⊗ $f(arg2(a)) end end end From 5acb5a81bf590a308706a028596c6eb8a79aa85d Mon Sep 17 00:00:00 2001 From: mtfishman Date: Sun, 22 Jun 2025 19:32:21 -0400 Subject: [PATCH 4/5] More general scalar indexing --- src/kroneckerarray.jl | 23 +++++------------------ test/Project.toml | 2 ++ test/test_basics.jl | 10 ++++++++++ 3 files changed, 17 insertions(+), 18 deletions(-) diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index ea4248a..636c1cd 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -158,26 +158,13 @@ function Base.getindex(a::KroneckerArray, i::Integer) return a[CartesianIndices(a)[i]] end -# TODO: Use this logic from KroneckerProducts.jl for cartesian indexing -# in the n-dimensional case and use it to replace the matrix and vector cases: -# https://github.com/perrutquist/KroneckerProducts.jl/blob/8c0104caf1f17729eb067259ba1473986121d032/src/KroneckerProducts.jl#L59-L66 -function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{Integer,N}) where {N} - return error("Not implemented.") -end - using GPUArraysCore: GPUArraysCore -function Base.getindex(a::KroneckerMatrix, i1::Integer, i2::Integer) - GPUArraysCore.assertscalar("getindex") - # Code logic from Kronecker.jl: - # https://github.com/MichielStock/Kronecker.jl/blob/v0.5.5/src/base.jl#L101-L105 - k, l = size(arg2(a)) - return arg1(a)[cld(i1, k), cld(i2, l)] * arg2(a)[(i1 - 1) % k + 1, (i2 - 1) % l + 1] -end - -function Base.getindex(a::KroneckerVector, i::Integer) +function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{Integer,N}) where {N} GPUArraysCore.assertscalar("getindex") - k = length(arg2(a)) - return arg1(a)[cld(i, k)] * arg2(a)[(i - 1) % k + 1] + I′ = ntuple(Val(N)) do dim + return cartesianproduct(axes(a, dim))[I[dim]] + end + return a[I′...] end # Allow customizing for `FillArrays.Eye`. diff --git a/test/Project.toml b/test/Project.toml index 23b9bdd..9f9c5a9 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,6 +6,7 @@ BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" KroneckerArrays = "05d0b138-81bc-4ff7-84be-08becefb1ccc" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -24,6 +25,7 @@ BlockSparseArrays = "0.7.19" DerivableInterfaces = "0.5" DiagonalArrays = "0.3.7" FillArrays = "1" +GPUArraysCore = "0.2" JLArrays = "0.2" KroneckerArrays = "0.1" LinearAlgebra = "1.10" diff --git a/test/test_basics.jl b/test/test_basics.jl index 5684936..db58023 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -2,6 +2,7 @@ using Adapt: adapt using Base.Broadcast: BroadcastStyle, Broadcasted, broadcasted using DerivableInterfaces: zero! using DiagonalArrays: diagonal +using GPUArraysCore: @allowscalar using JLArrays: JLArray using KroneckerArrays: KroneckerArrays, @@ -44,6 +45,15 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) @test first(r) == 2 @test last(r) == 7 + # Test high-dimensional materialization. + a = randn(elt, 2, 2, 2) ⊗ randn(elt, 2, 2, 2) + x = Array(a) + y = similar(x) + for I in eachindex(a) + y[I] = @allowscalar x[I] + end + @test x == y + a = @constinferred(randn(elt, 2, 2) ⊗ randn(elt, 3, 3)) b = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) c = a.a ⊗ b.b From 06c3e30eb9925ce336aa395c43144f1baa9eba11 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Sun, 22 Jun 2025 19:41:14 -0400 Subject: [PATCH 5/5] Add checkindex for cartesian product --- src/cartesianproduct.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/cartesianproduct.jl b/src/cartesianproduct.jl index 3934bd8..88806da 100644 --- a/src/cartesianproduct.jl +++ b/src/cartesianproduct.jl @@ -90,6 +90,10 @@ function Base.axes(r::CartesianProductUnitRange) return (CartesianProductUnitRange(cartesianproduct(r), only(axes(unproduct(r)))),) end +function Base.checkindex(::Type{Bool}, inds::CartesianProductUnitRange, i::CartesianPair) + return checkindex(Bool, arg1(inds), arg1(i)) && checkindex(Bool, arg2(inds), arg2(i)) +end + using Base.Broadcast: DefaultArrayStyle for f in (:+, :-) @eval begin