From a35afc3999690a1787db171b8b203d7823d5fe1a Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 16 Jun 2025 15:32:09 -0400 Subject: [PATCH 1/6] Block sparse SVD --- Project.toml | 4 +- .../KroneckerArraysBlockSparseArraysExt.jl | 77 ++++++++++++++++++- src/cartesianproduct.jl | 6 ++ src/fillarrays/kroneckerarray.jl | 17 +++- src/kroneckerarray.jl | 10 ++- 5 files changed, 108 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index 6562906..0e2b55c 100644 --- a/Project.toml +++ b/Project.toml @@ -13,13 +13,15 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4" [weakdeps] +BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" [extensions] -KroneckerArraysBlockSparseArraysExt = "BlockSparseArrays" +KroneckerArraysBlockSparseArraysExt = ["BlockArrays", "BlockSparseArrays"] [compat] Adapt = "4.3.0" +BlockArrays = "1.6" BlockSparseArrays = "0.7.9" DerivableInterfaces = "0.5.0" DiagonalArrays = "0.3.5" diff --git a/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl b/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl index 85ba21d..91f579f 100644 --- a/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl +++ b/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl @@ -2,9 +2,84 @@ module KroneckerArraysBlockSparseArraysExt using BlockSparseArrays: BlockSparseArrays, blockrange using KroneckerArrays: CartesianProduct, cartesianrange - function BlockSparseArrays.blockrange(bs::Vector{<:CartesianProduct}) return blockrange(map(cartesianrange, bs)) end +using BlockArrays: AbstractBlockedUnitRange +using BlockSparseArrays: Block, GetUnstoredBlock, eachblockaxis, mortar_axis +using DerivableInterfaces: zero! +using FillArrays: Eye +using KroneckerArrays: + KroneckerArrays, + EyeEye, + EyeKronecker, + KroneckerEye, + KroneckerMatrix, + ⊗, + arg1, + arg2, + _similar + +function KroneckerArrays.arg1(r::AbstractBlockedUnitRange) + return mortar_axis(arg2.(eachblockaxis(r))) +end +function KroneckerArrays.arg2(r::AbstractBlockedUnitRange) + return mortar_axis(arg2.(eachblockaxis(r))) +end + +function block_axes( + ax::NTuple{N,AbstractUnitRange{<:Integer}}, I::Vararg{Block{1},N} +) where {N} + return ntuple(N) do d + return only(axes(ax[d][I[d]])) + end +end +function block_axes(ax::NTuple{N,AbstractUnitRange{<:Integer}}, I::Block{N}) where {N} + return block_axes(ax, Tuple(I)...) +end + +function (f::GetUnstoredBlock)( + ::Type{<:AbstractMatrix{KroneckerMatrix{T,A,B}}}, I::Vararg{Int,2} +) where {T,A<:AbstractMatrix{T},B<:AbstractMatrix{T}} + ax_a = arg1.(f.axes) + f_a = GetUnstoredBlock(ax_a) + a = f_a(AbstractMatrix{A}, I...) + + ax_b = arg2.(f.axes) + f_b = GetUnstoredBlock(ax_b) + b = f_b(AbstractMatrix{B}, I...) + + return a ⊗ b +end +function (f::GetUnstoredBlock)( + ::Type{<:AbstractMatrix{EyeKronecker{T,A,B}}}, I::Vararg{Int,2} +) where {T,A<:Eye{T},B<:AbstractMatrix{T}} + block_ax_a = arg1.(block_axes(f.axes, Block(I))) + a = _similar(A, block_ax_a) + + ax_b = arg2.(f.axes) + f_b = GetUnstoredBlock(ax_b) + b = f_b(AbstractMatrix{B}, I...) + + return a ⊗ b +end +function (f::GetUnstoredBlock)( + ::Type{<:AbstractMatrix{KroneckerEye{T,A,B}}}, I::Vararg{Int,2} +) where {T,A<:AbstractMatrix{T},B<:Eye{T}} + ax_a = arg1.(f.axes) + f_a = GetUnstoredBlock(ax_a) + a = f_a(AbstractMatrix{A}, I...) + + block_ax_b = arg2.(block_axes(f.axes, Block(I))) + b = _similar(B, block_ax_b) + + return a ⊗ b +end +function (f::GetUnstoredBlock)( + ::Type{<:AbstractMatrix{EyeEye{T,A,B}}}, I::Vararg{Int,2} +) where {T,A<:Eye{T},B<:Eye{T}} + return error("Not implemented.") +end + end diff --git a/src/cartesianproduct.jl b/src/cartesianproduct.jl index 021c7c0..c734253 100644 --- a/src/cartesianproduct.jl +++ b/src/cartesianproduct.jl @@ -5,6 +5,9 @@ end arguments(a::CartesianProduct) = (a.a, a.b) arguments(a::CartesianProduct, n::Int) = arguments(a)[n] +arg1(a::CartesianProduct) = a.a +arg2(a::CartesianProduct) = a.b + function Base.show(io::IO, a::CartesianProduct) print(io, a.a, " × ", a.b) return nothing @@ -32,6 +35,9 @@ Base.last(r::CartesianProductUnitRange) = last(r.range) cartesianproduct(r::CartesianProductUnitRange) = getfield(r, :product) unproduct(r::CartesianProductUnitRange) = getfield(r, :range) +arg1(a::CartesianProductUnitRange) = arg1(cartesianproduct(a)) +arg2(a::CartesianProductUnitRange) = arg2(cartesianproduct(a)) + function CartesianProductUnitRange(p::CartesianProduct) return CartesianProductUnitRange(p, Base.OneTo(length(p))) end diff --git a/src/fillarrays/kroneckerarray.jl b/src/fillarrays/kroneckerarray.jl index 8dcd431..689f1e4 100644 --- a/src/fillarrays/kroneckerarray.jl +++ b/src/fillarrays/kroneckerarray.jl @@ -1,3 +1,6 @@ +using FillArrays: RectDiagonal, OnesVector +const RectEye{T,V<:OnesVector{T},Axes} = RectDiagonal{T,V,Axes} + using FillArrays: Eye const EyeKronecker{T,A<:Eye{T},B<:AbstractMatrix{T}} = KroneckerMatrix{T,A,B} const KroneckerEye{T,A<:AbstractMatrix{T},B<:Eye{T}} = KroneckerMatrix{T,A,B} @@ -11,6 +14,14 @@ const SquareEyeSquareEye{T,A<:SquareEye{T},B<:SquareEye{T}} = KroneckerMatrix{T, # Like `adapt` but preserves `Eye`. _adapt(to, a::Eye) = a +# Allows customizing for `FillArrays.Eye`. +function _convert(::Type{AbstractArray{T}}, a::RectDiagonal) where {T} + _convert(AbstractMatrix{T}, a) +end +function _convert(::Type{AbstractMatrix{T}}, a::RectDiagonal) where {T} + RectDiagonal(convert(AbstractVector{T}, _diagview(a)), axes(a)) +end + # Like `similar` but preserves `Eye`. function _similar(a::AbstractArray, elt::Type, ax::Tuple) return similar(a, elt, ax) @@ -124,15 +135,15 @@ for op in (:+, :-) end end -function Base.map!(f::typeof(identity), dest::EyeKronecker, a::EyeKronecker) +function Base.map!(f::typeof(identity), dest::EyeKronecker, src::EyeKronecker) map!(f, dest.b, src.b) return dest end -function Base.map!(f::typeof(identity), dest::KroneckerEye, a::KroneckerEye) +function Base.map!(f::typeof(identity), dest::KroneckerEye, src::KroneckerEye) map!(f, dest.a, src.a) return dest end -function Base.map!(::typeof(identity), dest::EyeEye, a::EyeEye) +function Base.map!(::typeof(identity), dest::EyeEye, src::EyeEye) return error("Can't write in-place.") end for f in [:+, :-] diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index e2e770e..e4530c4 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -1,3 +1,8 @@ +# Allows customizing for `FillArrays.Eye`. +function _convert(A::Type{<:AbstractArray}, a::AbstractArray) + return convert(A, a) +end + struct KroneckerArray{T,N,A<:AbstractArray{T,N},B<:AbstractArray{T,N}} <: AbstractArray{T,N} a::A b::B @@ -9,11 +14,14 @@ function KroneckerArray(a::AbstractArray, b::AbstractArray) ) end elt = promote_type(eltype(a), eltype(b)) - return KroneckerArray(convert(AbstractArray{elt}, a), convert(AbstractArray{elt}, b)) + return KroneckerArray(_convert(AbstractArray{elt}, a), _convert(AbstractArray{elt}, b)) end const KroneckerMatrix{T,A<:AbstractMatrix{T},B<:AbstractMatrix{T}} = KroneckerArray{T,2,A,B} const KroneckerVector{T,A<:AbstractVector{T},B<:AbstractVector{T}} = KroneckerArray{T,1,A,B} +arg1(a::KroneckerArray) = a.a +arg2(a::KroneckerArray) = a.b + using Adapt: Adapt, adapt _adapt(to, a::AbstractArray) = adapt(to, a) Adapt.adapt_structure(to, a::KroneckerArray) = _adapt(to, a.a) ⊗ _adapt(to, a.b) From d430d27e31e7185b894c8bc8dbf18ec66d524f07 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 16 Jun 2025 16:59:14 -0400 Subject: [PATCH 2/6] Bump dep version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 0e2b55c..36cab8c 100644 --- a/Project.toml +++ b/Project.toml @@ -22,7 +22,7 @@ KroneckerArraysBlockSparseArraysExt = ["BlockArrays", "BlockSparseArrays"] [compat] Adapt = "4.3.0" BlockArrays = "1.6" -BlockSparseArrays = "0.7.9" +BlockSparseArrays = "0.7.13" DerivableInterfaces = "0.5.0" DiagonalArrays = "0.3.5" FillArrays = "1.13.0" From 5e2fb53c705e8838012f57994d75d50e729abc5c Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 17 Jun 2025 09:04:32 -0400 Subject: [PATCH 3/6] Fix tests --- test/test_blocksparsearrays.jl | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/test/test_blocksparsearrays.jl b/test/test_blocksparsearrays.jl index 5ff50cb..5b02783 100644 --- a/test/test_blocksparsearrays.jl +++ b/test/test_blocksparsearrays.jl @@ -64,9 +64,16 @@ arrayts = (Array, JLArray) @test_broken inv(a) end + if arrayt === Array && elt <: Real + u, s, v = svd_compact(a) + @test Array(u * s * v) ≈ Array(a) + else + # Broken on GPU and for complex, investigate. + @test_broken svd_compact(a) + end + # Broken operations @test_broken exp(a) - @test_broken svd_compact(a) @test_broken a[Block.(1:2), Block(2)] end @@ -129,8 +136,12 @@ end b = @constinferred exp(a) @test Array(b) ≈ exp(Array(a)) + u, s, v = svd_compact(a) + @test u * s * v ≈ a + @test blocktype(u) === blocktype(a) + @test blocktype(v) === blocktype(a) + # Broken operations @test_broken inv(a) - @test_broken svd_compact(a) @test_broken a[Block.(1:2), Block(2)] end From 51e7e915f3552497a78a82ee08cfbe7a8d9ba1f2 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 17 Jun 2025 09:22:22 -0400 Subject: [PATCH 4/6] Reenable more tests --- src/kroneckerarray.jl | 3 ++- test/test_blocksparsearrays.jl | 30 ++++++++++++++++-------------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index e4530c4..5bbae8c 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -114,7 +114,8 @@ end kron_nd(a::AbstractMatrix, b::AbstractMatrix) = kron(a, b) kron_nd(a::AbstractVector, b::AbstractVector) = kron(a, b) -Base.collect(a::KroneckerArray) = kron_nd(a.a, a.b) +# Eagerly collect arguments to make more general on GPU. +Base.collect(a::KroneckerArray) = kron_nd(collect(a.a), collect(a.b)) function Base.Array{T,N}(a::KroneckerArray{S,N}) where {T,S,N} return convert(Array{T,N}, collect(a)) diff --git a/test/test_blocksparsearrays.jl b/test/test_blocksparsearrays.jl index 5b02783..905b4f2 100644 --- a/test/test_blocksparsearrays.jl +++ b/test/test_blocksparsearrays.jl @@ -81,17 +81,11 @@ end arrayts, elt in elts - if arrayt == JLArray - # TODO: Collecting to `Array` is broken for GPU arrays so a lot of tests - # are broken, look into fixing that. - continue - end - dev = adapt(arrayt) r = @constinferred blockrange([2 × 2, 3 × 3]) d = Dict( - Block(1, 1) => Eye{elt}(2, 2) ⊗ randn(elt, 2, 2), - Block(2, 2) => Eye{elt}(3, 3) ⊗ randn(elt, 3, 3), + Block(1, 1) => Eye{elt}(2, 2) ⊗ dev(randn(elt, 2, 2)), + Block(2, 2) => Eye{elt}(3, 3) ⊗ dev(randn(elt, 3, 3)), ) a = @constinferred dev(blocksparse(d, r, r)) @test sprint(show, a) == sprint(show, Array(a)) @@ -133,13 +127,21 @@ end @test @constinferred(norm(a)) ≈ norm(Array(a)) - b = @constinferred exp(a) - @test Array(b) ≈ exp(Array(a)) + if arrayt === Array + b = @constinferred exp(a) + @test Array(b) ≈ exp(Array(a)) + else + @test_broken exp(a) + end - u, s, v = svd_compact(a) - @test u * s * v ≈ a - @test blocktype(u) === blocktype(a) - @test blocktype(v) === blocktype(a) + if arrayt === Array + u, s, v = svd_compact(a) + @test u * s * v ≈ a + @test blocktype(u) === blocktype(a) + @test blocktype(v) === blocktype(a) + else + @test_broken svd_compact(a) + end # Broken operations @test_broken inv(a) From 75ab024f777fe61c3f9af1be8db4e28b0628cb64 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 17 Jun 2025 09:23:43 -0400 Subject: [PATCH 5/6] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 36cab8c..c790519 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" authors = ["ITensor developers and contributors"] -version = "0.1.13" +version = "0.1.14" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" From 0c0a1bce6f26f1246448882ebaa52116fa499a49 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 17 Jun 2025 09:44:14 -0400 Subject: [PATCH 6/6] Fix tests in Julia v1.10 --- test/test_blocksparsearrays.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/test/test_blocksparsearrays.jl b/test/test_blocksparsearrays.jl index 905b4f2..0f2ea3a 100644 --- a/test/test_blocksparsearrays.jl +++ b/test/test_blocksparsearrays.jl @@ -64,7 +64,8 @@ arrayts = (Array, JLArray) @test_broken inv(a) end - if arrayt === Array && elt <: Real + if (VERSION ≤ v"1.11-" && arrayt === Array && elt <: Complex) || + (arrayt === Array && elt <: Real) u, s, v = svd_compact(a) @test Array(u * s * v) ≈ Array(a) else @@ -134,7 +135,10 @@ end @test_broken exp(a) end - if arrayt === Array + if VERSION < v"1.11-" && elt <: Complex + # Broken because of type stability issue in Julia v1.10. + @test_broken svd_compact(a) + elseif arrayt === Array u, s, v = svd_compact(a) @test u * s * v ≈ a @test blocktype(u) === blocktype(a)