From 051264713673b983aa86c8347bd5456ace2b31ea Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 1 Jul 2025 19:32:17 -0400 Subject: [PATCH 1/9] [WIP] Towards truncated block sparse factorizations --- Project.toml | 4 +- .../KroneckerArraysBlockSparseArraysExt.jl | 6 ++ src/cartesianproduct.jl | 63 +++++++++++++++++-- src/fillarrays/kroneckerarray.jl | 1 + src/fillarrays/matrixalgebrakit_truncate.jl | 42 +++++++------ test/test_basics.jl | 18 +++--- 6 files changed, 100 insertions(+), 34 deletions(-) diff --git a/Project.toml b/Project.toml index d1c03f7..dbbe6bb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" authors = ["ITensor developers and contributors"] -version = "0.1.21" +version = "0.1.22" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -23,7 +23,7 @@ KroneckerArraysBlockSparseArraysExt = ["BlockArrays", "BlockSparseArrays"] [compat] Adapt = "4.3.0" BlockArrays = "1.6" -BlockSparseArrays = "0.7.21" +BlockSparseArrays = "0.7.22" DerivableInterfaces = "0.5.0" DiagonalArrays = "0.3.5" FillArrays = "1.13.0" diff --git a/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl b/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl index 9fa1018..b23c869 100644 --- a/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl +++ b/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl @@ -99,4 +99,10 @@ function (f::GetUnstoredBlock)( return error("Not implemented.") end +using BlockSparseArrays: BlockSparseArrays +using KroneckerArrays: KroneckerArrays, KroneckerVector +function BlockSparseArrays.to_truncated_indices(values::KroneckerVector, I) + return KroneckerArrays.to_truncated_indices(values, I) +end + end diff --git a/src/cartesianproduct.jl b/src/cartesianproduct.jl index e3eb2fe..e0138f3 100644 --- a/src/cartesianproduct.jl +++ b/src/cartesianproduct.jl @@ -26,10 +26,16 @@ arguments(a::CartesianProduct, n::Int) = arguments(a)[n] arg1(a::CartesianProduct) = a.a arg2(a::CartesianProduct) = a.b +Base.copy(a::CartesianProduct) = copy(arg1(a)) × copy(arg2(a)) + function Base.show(io::IO, a::CartesianProduct) print(io, a.a, " × ", a.b) return nothing end +function Base.show(io::IO, ::MIME"text/plain", a::CartesianProduct) + show(io, a) + return nothing +end ×(a::AbstractVector, b::AbstractVector) = CartesianProduct(a, b) Base.length(a::CartesianProduct) = length(a.a) * length(a.b) @@ -42,8 +48,38 @@ function Base.getindex(a::CartesianProduct, i::CartesianPair) return arg1(a)[arg1(i)] × arg2(a)[arg2(i)] end function Base.getindex(a::CartesianProduct, i::Int) - I = Tuple(CartesianIndices((length(arg1(a)), length(arg2(a))))[i]) - return a[I[1] × I[2]] + I = Tuple(CartesianIndices((length(arg2(a)), length(arg1(a))))[i]) + return a[I[2] × I[1]] +end + +struct CartesianProductVector{T,P<:CartesianProduct,V<:AbstractVector{T}} <: + AbstractVector{T} + product::P + values::V +end +cartesianproduct(r::CartesianProductVector) = getfield(r, :product) +unproduct(r::CartesianProductVector) = getfield(r, :values) +Base.length(a::CartesianProductVector) = length(unproduct(a)) +Base.size(a::CartesianProductVector) = (length(a),) +function Base.axes(r::CartesianProductVector) + return (CartesianProductUnitRange(cartesianproduct(r), only(axes(unproduct(r)))),) +end +function Base.copy(a::CartesianProductVector) + return CartesianProductVector(copy(cartesianproduct(a)), copy(unproduct(a))) +end +function Base.getindex(r::CartesianProductVector, i::Integer) + return unproduct(r)[i] +end + +function Base.show(io::IO, a::CartesianProductVector) + show(io, unproduct(a)) + return nothing +end +function Base.show(io::IO, mime::MIME"text/plain", a::CartesianProductVector) + show(io, mime, cartesianproduct(a)) + println(io) + show(io, mime, unproduct(a)) + return nothing end struct CartesianProductUnitRange{T,P<:CartesianProduct,R<:AbstractUnitRange{T}} <: @@ -60,13 +96,24 @@ unproduct(r::CartesianProductUnitRange) = getfield(r, :range) arg1(a::CartesianProductUnitRange) = arg1(cartesianproduct(a)) arg2(a::CartesianProductUnitRange) = arg2(cartesianproduct(a)) +function Base.show(io::IO, a::CartesianProductUnitRange) + show(io, unproduct(a)) + return nothing +end +function Base.show(io::IO, mime::MIME"text/plain", a::CartesianProductUnitRange) + show(io, mime, cartesianproduct(a)) + println(io) + show(io, mime, unproduct(a)) + return nothing +end + function CartesianProductUnitRange(p::CartesianProduct) return CartesianProductUnitRange(p, Base.OneTo(length(p))) end function CartesianProductUnitRange(a, b) return CartesianProductUnitRange(a × b) end -to_product_indices(a::AbstractUnitRange) = a +to_product_indices(a::AbstractVector) = a to_product_indices(i::Integer) = Base.OneTo(i) cartesianrange(a, b) = cartesianrange(to_product_indices(a) × to_product_indices(b)) function cartesianrange(p::CartesianPair) @@ -94,10 +141,16 @@ function Base.checkindex(::Type{Bool}, inds::CartesianProductUnitRange, i::Carte return checkindex(Bool, arg1(inds), arg1(i)) && checkindex(Bool, arg2(inds), arg2(i)) end +function Base.getindex(a::CartesianProductUnitRange, I::CartesianProduct) + prod = cartesianproduct(a) + prod_I = arg1(prod)[arg1(I)] × arg2(prod)[arg2(I)] + return CartesianProductVector(prod_I, map(Base.Fix1(getindex, a), I)) +end + # Reverse map from CartesianPair to linear index in the range. function Base.getindex(inds::CartesianProductUnitRange, i::CartesianPair) - i′ = (findfirst(==(arg1(i)), arg1(inds)), findfirst(==(arg2(i)), arg2(inds))) - return inds[LinearIndices((length(arg1(inds)), length(arg2(inds))))[i′...]] + i′ = (findfirst(==(arg2(i)), arg2(inds)), findfirst(==(arg1(i)), arg1(inds))) + return inds[LinearIndices((length(arg2(inds)), length(arg1(inds))))[i′...]] end using Base.Broadcast: DefaultArrayStyle diff --git a/src/fillarrays/kroneckerarray.jl b/src/fillarrays/kroneckerarray.jl index fc40ef7..c8ee8e6 100644 --- a/src/fillarrays/kroneckerarray.jl +++ b/src/fillarrays/kroneckerarray.jl @@ -23,6 +23,7 @@ const SquareEyeSquareEye{T,A<:SquareEye{T},B<:SquareEye{T}} = KroneckerMatrix{T, _getindex(a::Eye, I1::Colon, I2::Colon) = a _view(a::Eye, I1::Colon, I2::Colon) = a +_view(a::Eye, I1::Base.Slice, I2::Base.Slice) = a # Like `adapt` but preserves `Eye`. _adapt(to, a::Eye) = a diff --git a/src/fillarrays/matrixalgebrakit_truncate.jl b/src/fillarrays/matrixalgebrakit_truncate.jl index e505cf2..1bf6ed2 100644 --- a/src/fillarrays/matrixalgebrakit_truncate.jl +++ b/src/fillarrays/matrixalgebrakit_truncate.jl @@ -20,34 +20,40 @@ const OnesKroneckerVector{T,A<:OnesVector{T},B<:AbstractVector{T}} = KroneckerVe const KroneckerOnesVector{T,A<:AbstractVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B} const OnesVectorOnesVector{T,A<:OnesVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B} -function MatrixAlgebraKit.findtruncated( - values::OnesKroneckerVector, strategy::KroneckerTruncationStrategy -) - I = findtruncated(Vector(values), strategy.strategy) - prods = collect(cartesianproduct(only(axes(values))))[I] - I_data = unique(map(arg1, prods)) +axis(a) = only(axes(a)) + +# Convert indices determined with a generic call to `findtruncated` to indices +# more suited for a KroneckerVector. +function to_truncated_indices(values::OnesKroneckerVector, I) + prods = cartesianproduct(axis(values))[I] + I_id = only(to_indices(arg1(values), (:,))) + I_data = unique(arg2.(prods)) # Drop truncations that occur within the identity. I_data = filter(I_data) do i - return count(x -> arg1(x) == i, prods) == length(arg1(values)) + return count(x -> arg2(x) == i, prods) == length(arg2(values)) end - return (:) × I_data + return I_id × I_data end -function MatrixAlgebraKit.findtruncated( - values::KroneckerOnesVector, strategy::KroneckerTruncationStrategy -) - I = findtruncated(Vector(values), strategy.strategy) - prods = collect(cartesianproduct(only(axes(values))))[I] - I_data = unique(map(x -> arg2(x), prods)) +function to_truncated_indices(values::KroneckerOnesVector, I) + #I = findtruncated(Vector(values), strategy.strategy) + prods = cartesianproduct(axis(values))[I] + I_data = unique(arg1.(prods)) # Drop truncations that occur within the identity. I_data = filter(I_data) do i - return count(x -> arg2(x) == i, prods) == length(arg2(values)) + return count(x -> arg1(x) == i, prods) == length(arg2(values)) end - return I_data × (:) + I_id = only(to_indices(arg2(values), (:,))) + return I_data × I_id +end +function to_truncated_indices(values::OnesVectorOnesVector, I) + return throw(ArgumentError("Can't truncate Eye ⊗ Eye.")) end + function MatrixAlgebraKit.findtruncated( - values::OnesVectorOnesVector, strategy::KroneckerTruncationStrategy + values::KroneckerVector, strategy::KroneckerTruncationStrategy ) - return throw(ArgumentError("Can't truncate Eye ⊗ Eye.")) + I = findtruncated(Vector(values), strategy.strategy) + return to_truncated_indices(values, I) end for f in [:eig_trunc!, :eigh_trunc!] diff --git a/test/test_basics.jl b/test/test_basics.jl index 0fe2ef3..ee1ec6c 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -26,7 +26,7 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) @testset "KroneckerArrays (eltype=$elt)" for elt in elts p = [1, 2] × [3, 4, 5] @test length(p) == 6 - @test collect(p) == [1 × 3, 2 × 3, 1 × 4, 2 × 4, 1 × 5, 2 × 5] + @test collect(p) == [1 × 3, 1 × 4, 1 × 5, 2 × 3, 2 × 4, 2 × 5] r = @constinferred cartesianrange(2, 3) @test r === @@ -39,10 +39,10 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) @test first(r) == 1 @test last(r) == 6 @test r[1 × 1] == 1 - @test r[2 × 1] == 2 - @test r[1 × 2] == 3 - @test r[2 × 2] == 4 - @test r[1 × 3] == 5 + @test r[1 × 2] == 2 + @test r[1 × 3] == 3 + @test r[2 × 1] == 4 + @test r[2 × 2] == 5 @test r[2 × 3] == 6 r = @constinferred(cartesianrange(2 × 3, 2:7)) @@ -53,10 +53,10 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) @test first(r) == 2 @test last(r) == 7 @test r[1 × 1] == 2 - @test r[2 × 1] == 3 - @test r[1 × 2] == 4 - @test r[2 × 2] == 5 - @test r[1 × 3] == 6 + @test r[1 × 2] == 3 + @test r[1 × 3] == 4 + @test r[2 × 1] == 5 + @test r[2 × 2] == 6 @test r[2 × 3] == 7 # Test high-dimensional materialization. From 443ac708ce7584823756c3dd3606ce4a8a45f3c9 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 2 Jul 2025 12:03:33 -0400 Subject: [PATCH 2/9] Fix indexing order, test blockwise slicing --- test/test_blocksparsearrays.jl | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/test/test_blocksparsearrays.jl b/test/test_blocksparsearrays.jl index bd88d9d..6fe2b82 100644 --- a/test/test_blocksparsearrays.jl +++ b/test/test_blocksparsearrays.jl @@ -1,10 +1,10 @@ using Adapt: adapt using BlockArrays: Block, BlockRange, mortar using BlockSparseArrays: - BlockSparseArray, BlockSparseMatrix, blockrange, blocksparse, blocktype + BlockIndexVector, BlockSparseArray, BlockSparseMatrix, blockrange, blocksparse, blocktype using FillArrays: Eye, SquareEye using JLArrays: JLArray -using KroneckerArrays: KroneckerArray, ⊗, × +using KroneckerArrays: KroneckerArray, ⊗, ×, arg1, arg2 using LinearAlgebra: norm using MatrixAlgebraKit: svd_compact using Test: @test, @test_broken, @testset @@ -50,6 +50,16 @@ arrayts = (Array, JLArray) @test a[Block(2, 2)[(:) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(:) × (2:3), (:) × (2:3)] @test_broken a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)] + # Blockwise slicing, shows up in truncated block sparse matrix factorizations. + I1 = BlockIndexVector(Block(1), Base.Slice(Base.OneTo(2)) × [1]) + I2 = BlockIndexVector(Block(2), Base.Slice(Base.OneTo(3)) × [1, 3]) + I = [I1, I2] + b = a[I, I] + @test b[Block(1, 1)] == a[Block(1, 1)[(1:2) × [1], (1:2) × [1]]] + @test iszero(b[Block(2, 1)]) + @test iszero(b[Block(1, 2)]) + @test b[Block(2, 2)] == a[Block(2, 2)[(1:3) × [1, 3], (1:3) × [1, 3]]] + # Slicing r = blockrange([2 × 2, 3 × 3]) d = Dict( @@ -161,6 +171,20 @@ end @test a[Block(2, 2)[(:) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(:) × (2:3), (:) × (2:3)] @test_broken a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)] + # Blockwise slicing, shows up in truncated block sparse matrix factorizations. + I1 = BlockIndexVector(Block(1), Base.Slice(Base.OneTo(2)) × [1]) + I2 = BlockIndexVector(Block(2), Base.Slice(Base.OneTo(3)) × [1, 3]) + I = [I1, I2] + b = a[I, I] + @test b[Block(1, 1)] == a[Block(1, 1)[(1:2) × [1], (1:2) × [1]]] + @test arg1(b[Block(1, 1)]) isa Eye + @test iszero(b[Block(2, 1)]) + @test arg1(b[Block(2, 1)]) isa Eye + @test iszero(b[Block(1, 2)]) + @test arg1(b[Block(1, 2)]) isa Eye + @test b[Block(2, 2)] == a[Block(2, 2)[(1:3) × [1, 3], (1:3) × [1, 3]]] + @test arg1(b[Block(2, 2)]) isa Eye + # Slicing r = blockrange([2 × 2, 3 × 3]) d = Dict( From 42413dd1cfcf2cc739596ff30cef8d672e6be6cc Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 2 Jul 2025 16:48:59 -0400 Subject: [PATCH 3/9] Fix some broken slicing operations --- src/fillarrays/kroneckerarray.jl | 5 +++++ src/kroneckerarray.jl | 15 ++++++++++----- test/test_blocksparsearrays.jl | 4 ++-- 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/src/fillarrays/kroneckerarray.jl b/src/fillarrays/kroneckerarray.jl index c8ee8e6..f943ce0 100644 --- a/src/fillarrays/kroneckerarray.jl +++ b/src/fillarrays/kroneckerarray.jl @@ -22,8 +22,13 @@ const KroneckerSquareEye{T,A<:AbstractMatrix{T},B<:SquareEye{T}} = KroneckerMatr const SquareEyeSquareEye{T,A<:SquareEye{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B} _getindex(a::Eye, I1::Colon, I2::Colon) = a +_getindex(a::Eye, I1::Base.Slice, I2::Base.Slice) = a +_getindex(a::Eye, I1::Base.Slice, I2::Colon) = a +_getindex(a::Eye, I1::Colon, I2::Base.Slice) = a _view(a::Eye, I1::Colon, I2::Colon) = a _view(a::Eye, I1::Base.Slice, I2::Base.Slice) = a +_view(a::Eye, I1::Base.Slice, I2::Colon) = a +_view(a::Eye, I1::Colon, I2::Base.Slice) = a # Like `adapt` but preserves `Eye`. _adapt(to, a::Eye) = a diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index 94d601a..a9c928b 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -167,13 +167,18 @@ function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{Integer,N}) where { return a[I′...] end +# Indexing logic. +function Base.to_indices(a::KroneckerArray, inds, I::Tuple{Union{CartesianPair,CartesianProduct},Vararg}) + I1 = to_indices(arg1(a), arg1.(inds), arg1.(I)) + I2 = to_indices(arg2(a), arg2.(inds), arg2.(I)) + return I1 .× I2 +end + # Allow customizing for `FillArrays.Eye`. _getindex(a::AbstractArray, I...) = a[I...] -function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianProduct,N}) where {N} - return _getindex(arg1(a), arg1.(I)...) ⊗ _getindex(arg2(a), arg2.(I)...) -end -function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianPair,N}) where {N} - return _getindex(arg1(a), arg1.(I)...) ⊗ _getindex(arg2(a), arg2.(I)...) +function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{Union{CartesianPair,CartesianProduct},N}) where {N} + I′ = to_indices(a, I) + return _getindex(arg1(a), arg1.(I′)...) ⊗ _getindex(arg2(a), arg2.(I′)...) end # Fix ambigiuity error. Base.getindex(a::KroneckerArray{<:Any,0}) = arg1(a)[] * arg2(a)[] diff --git a/test/test_blocksparsearrays.jl b/test/test_blocksparsearrays.jl index 6fe2b82..37ead77 100644 --- a/test/test_blocksparsearrays.jl +++ b/test/test_blocksparsearrays.jl @@ -48,7 +48,7 @@ arrayts = (Array, JLArray) @test a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] == a[Block(2, 2)][(2:3) × (2:3), (2:3) × (2:3)] @test a[Block(2, 2)[(:) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(:) × (2:3), (:) × (2:3)] - @test_broken a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)] + @test a[Block(2, 2)[(1:2) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)] # Blockwise slicing, shows up in truncated block sparse matrix factorizations. I1 = BlockIndexVector(Block(1), Base.Slice(Base.OneTo(2)) × [1]) @@ -169,7 +169,7 @@ end @test a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] == a[Block(2, 2)][(2:3) × (2:3), (2:3) × (2:3)] @test a[Block(2, 2)[(:) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(:) × (2:3), (:) × (2:3)] - @test_broken a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)] + @test a[Block(2, 2)[(1:2) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)] # Blockwise slicing, shows up in truncated block sparse matrix factorizations. I1 = BlockIndexVector(Block(1), Base.Slice(Base.OneTo(2)) × [1]) From f747785e2f4582911553b12dfff8ecbdb5b98b4f Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 2 Jul 2025 16:51:10 -0400 Subject: [PATCH 4/9] Format --- src/kroneckerarray.jl | 8 ++++++-- test/test_blocksparsearrays.jl | 6 ++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index a9c928b..96167f3 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -168,7 +168,9 @@ function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{Integer,N}) where { end # Indexing logic. -function Base.to_indices(a::KroneckerArray, inds, I::Tuple{Union{CartesianPair,CartesianProduct},Vararg}) +function Base.to_indices( + a::KroneckerArray, inds, I::Tuple{Union{CartesianPair,CartesianProduct},Vararg} +) I1 = to_indices(arg1(a), arg1.(inds), arg1.(I)) I2 = to_indices(arg2(a), arg2.(inds), arg2.(I)) return I1 .× I2 @@ -176,7 +178,9 @@ end # Allow customizing for `FillArrays.Eye`. _getindex(a::AbstractArray, I...) = a[I...] -function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{Union{CartesianPair,CartesianProduct},N}) where {N} +function Base.getindex( + a::KroneckerArray{<:Any,N}, I::Vararg{Union{CartesianPair,CartesianProduct},N} +) where {N} I′ = to_indices(a, I) return _getindex(arg1(a), arg1.(I′)...) ⊗ _getindex(arg2(a), arg2.(I′)...) end diff --git a/test/test_blocksparsearrays.jl b/test/test_blocksparsearrays.jl index 37ead77..770e216 100644 --- a/test/test_blocksparsearrays.jl +++ b/test/test_blocksparsearrays.jl @@ -48,7 +48,8 @@ arrayts = (Array, JLArray) @test a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] == a[Block(2, 2)][(2:3) × (2:3), (2:3) × (2:3)] @test a[Block(2, 2)[(:) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(:) × (2:3), (:) × (2:3)] - @test a[Block(2, 2)[(1:2) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)] + @test a[Block(2, 2)[(1:2) × (2:3), (:) × (2:3)]] == + a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)] # Blockwise slicing, shows up in truncated block sparse matrix factorizations. I1 = BlockIndexVector(Block(1), Base.Slice(Base.OneTo(2)) × [1]) @@ -169,7 +170,8 @@ end @test a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] == a[Block(2, 2)][(2:3) × (2:3), (2:3) × (2:3)] @test a[Block(2, 2)[(:) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(:) × (2:3), (:) × (2:3)] - @test a[Block(2, 2)[(1:2) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)] + @test a[Block(2, 2)[(1:2) × (2:3), (:) × (2:3)]] == + a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)] # Blockwise slicing, shows up in truncated block sparse matrix factorizations. I1 = BlockIndexVector(Block(1), Base.Slice(Base.OneTo(2)) × [1]) From 7e0e5f05df80a6d2ea74c47087f25081ec5725a3 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 17 Jul 2025 18:02:05 -0400 Subject: [PATCH 5/9] Small fixes for Kronecker block operations --- Project.toml | 2 +- src/cartesianproduct.jl | 4 +++- src/kroneckerarray.jl | 19 ++++++++++++++----- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index dbbe6bb..46edc19 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" authors = ["ITensor developers and contributors"] -version = "0.1.22" +version = "0.1.23" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/cartesianproduct.jl b/src/cartesianproduct.jl index e0138f3..bada26c 100644 --- a/src/cartesianproduct.jl +++ b/src/cartesianproduct.jl @@ -134,7 +134,9 @@ function cartesianrange(p::CartesianProduct, range::AbstractUnitRange) end function Base.axes(r::CartesianProductUnitRange) - return (CartesianProductUnitRange(cartesianproduct(r), only(axes(unproduct(r)))),) + prod = cartesianproduct(r) + prod_ax = only(axes(arg1(prod))) × only(axes(arg2(prod))) + return (CartesianProductUnitRange(prod_ax, only(axes(unproduct(r)))),) end function Base.checkindex(::Type{Bool}, inds::CartesianProductUnitRange, i::CartesianPair) diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index 96167f3..6298f29 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -38,6 +38,10 @@ function Base.copyto!(dest::KroneckerArray, src::KroneckerArray) return dest end +function Base.convert(::Type{KroneckerArray{T,N,A,B}}, a::KroneckerArray) where {T,N,A,B} + return KroneckerArray(convert(A, arg1(a)), convert(B, arg2(a))) +end + # Like `similar` but allows some custom behavior, such as for `FillArrays.Eye`. function _similar(a::AbstractArray, elt::Type, axs::Tuple{Vararg{AbstractUnitRange}}) return similar(a, elt, axs) @@ -189,7 +193,14 @@ Base.getindex(a::KroneckerArray{<:Any,0}) = arg1(a)[] * arg2(a)[] # Allow customizing for `FillArrays.Eye`. _view(a::AbstractArray, I...) = view(a, I...) -function Base.view(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianProduct,N}) where {N} +arg1(::Colon) = (:) +arg2(::Colon) = (:) +arg1(::Base.Slice) = (:) +arg2(::Base.Slice) = (:) +function Base.view( + a::KroneckerArray{<:Any,N}, + I::Vararg{Union{CartesianProduct,CartesianProductUnitRange,Base.Slice,Colon},N}, +) where {N} return _view(arg1(a), arg1.(I)...) ⊗ _view(arg2(a), arg2.(I)...) end function Base.view(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianPair,N}) where {N} @@ -272,10 +283,8 @@ function Base.BroadcastStyle(style1::KroneckerStyle{N}, style2::KroneckerStyle{N return KroneckerStyle{N}(style_a, style_b) end function Base.similar(bc::Broadcasted{<:KroneckerStyle{N,A,B}}, elt::Type, ax) where {N,A,B} - ax_a = arg1.(ax) - ax_b = arg2.(ax) - bc_a = Broadcasted(A, nothing, (), ax_a) - bc_b = Broadcasted(B, nothing, (), ax_b) + bc_a = Broadcasted(A, bc.f, arg1.(bc.args), arg1.(ax)) + bc_b = Broadcasted(B, bc.f, arg2.(bc.args), arg2.(ax)) a = similar(bc_a, elt) b = similar(bc_b, elt) return a ⊗ b From 8e3284f972f78cd2421b1cafca19b29bead382cd Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 17 Jul 2025 18:16:24 -0400 Subject: [PATCH 6/9] Slice axes --- src/cartesianproduct.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/cartesianproduct.jl b/src/cartesianproduct.jl index bada26c..5adae71 100644 --- a/src/cartesianproduct.jl +++ b/src/cartesianproduct.jl @@ -143,6 +143,11 @@ function Base.checkindex(::Type{Bool}, inds::CartesianProductUnitRange, i::Carte return checkindex(Bool, arg1(inds), arg1(i)) && checkindex(Bool, arg2(inds), arg2(i)) end +const CartesianProductOneTo{T,P<:CartesianProduct,R<:Base.OneTo{T}} = CartesianProductUnitRange{T,P,R} +Base.axes(S::Base.Slice{<:CartesianProductOneTo}) = (S.indices,) +Base.axes1(S::Base.Slice{<:CartesianProductOneTo}) = S.indices +Base.unsafe_indices(S::Base.Slice{<:CartesianProductOneTo}) = (S.indices,) + function Base.getindex(a::CartesianProductUnitRange, I::CartesianProduct) prod = cartesianproduct(a) prod_I = arg1(prod)[arg1(I)] × arg2(prod)[arg2(I)] From b2227f2e0198e195e5e47a83f20de4747d9bcdd6 Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Thu, 17 Jul 2025 18:20:30 -0400 Subject: [PATCH 7/9] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 46edc19..dbbe6bb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" authors = ["ITensor developers and contributors"] -version = "0.1.23" +version = "0.1.22" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" From 3e2fd9e4939b6c1239f56a4dd31801840ec9a124 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 17 Jul 2025 18:25:03 -0400 Subject: [PATCH 8/9] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index dbbe6bb..46edc19 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" authors = ["ITensor developers and contributors"] -version = "0.1.22" +version = "0.1.23" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" From e4ade2cbc0ff94e19d7a1d5d42d55e114cfb23e4 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 17 Jul 2025 18:28:38 -0400 Subject: [PATCH 9/9] Format --- src/cartesianproduct.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/cartesianproduct.jl b/src/cartesianproduct.jl index 5adae71..be1c4fa 100644 --- a/src/cartesianproduct.jl +++ b/src/cartesianproduct.jl @@ -143,7 +143,9 @@ function Base.checkindex(::Type{Bool}, inds::CartesianProductUnitRange, i::Carte return checkindex(Bool, arg1(inds), arg1(i)) && checkindex(Bool, arg2(inds), arg2(i)) end -const CartesianProductOneTo{T,P<:CartesianProduct,R<:Base.OneTo{T}} = CartesianProductUnitRange{T,P,R} +const CartesianProductOneTo{T,P<:CartesianProduct,R<:Base.OneTo{T}} = CartesianProductUnitRange{ + T,P,R +} Base.axes(S::Base.Slice{<:CartesianProductOneTo}) = (S.indices,) Base.axes1(S::Base.Slice{<:CartesianProductOneTo}) = S.indices Base.unsafe_indices(S::Base.Slice{<:CartesianProductOneTo}) = (S.indices,)