Skip to content

Update to BlockSparseArrays v0.9 #36

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Aug 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "KroneckerArrays"
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.1.26"
version = "0.1.27"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -16,19 +16,22 @@ MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
[weakdeps]
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d"

[extensions]
KroneckerArraysBlockSparseArraysExt = ["BlockArrays", "BlockSparseArrays"]
KroneckerArraysTensorProductsExt = "TensorProducts"

[compat]
Adapt = "4.3"
BlockArrays = "1.6"
BlockSparseArrays = "0.8.1"
BlockSparseArrays = "0.9"
DerivableInterfaces = "0.5"
DiagonalArrays = "0.3.5"
FillArrays = "1.13"
GPUArraysCore = "0.2"
LinearAlgebra = "1.10"
MapBroadcast = "0.1.9"
MatrixAlgebraKit = "0.2"
TensorProducts = "0.1.7"
julia = "1.10"
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
module KroneckerArraysTensorProductsExt

using KroneckerArrays: CartesianProductOneTo, ×, arg1, arg2, cartesianrange, unproduct
using TensorProducts: TensorProducts, tensor_product
function TensorProducts.tensor_product(a1::CartesianProductOneTo, a2::CartesianProductOneTo)
prod = tensor_product(arg1(a1), arg1(a2)) × tensor_product(arg2(a1), arg2(a2))
range = tensor_product(unproduct(a1), unproduct(a2))
return cartesianrange(prod, range)
end

end
4 changes: 3 additions & 1 deletion src/cartesianproduct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ unproduct(r::CartesianProductVector) = getfield(r, :values)
Base.length(a::CartesianProductVector) = length(unproduct(a))
Base.size(a::CartesianProductVector) = (length(a),)
function Base.axes(r::CartesianProductVector)
return (CartesianProductUnitRange(cartesianproduct(r), only(axes(unproduct(r)))),)
prod = cartesianproduct(r)
prod_ax = only(axes(arg1(prod))) × only(axes(arg2(prod)))
return (CartesianProductUnitRange(prod_ax, only(axes(unproduct(r)))),)
end
function Base.copy(a::CartesianProductVector)
return CartesianProductVector(copy(cartesianproduct(a)), copy(unproduct(a)))
Expand Down
34 changes: 4 additions & 30 deletions src/fillarrays/kroneckerarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,24 +41,11 @@ function _convert(::Type{AbstractMatrix{T}}, a::RectDiagonal) where {T}
RectDiagonal(convert(AbstractVector{T}, _diagview(a)), axes(a))
end

# Like `similar` but preserves `Eye`.
function _similar(a::AbstractArray, elt::Type, ax::Tuple)
return similar(a, elt, ax)
# Like `similar` but preserves `Eye`, `Ones`, etc.
using FillArrays: Ones
function _similar(arrayt::Type{<:Ones}, axs::Tuple)
return Ones{eltype(arrayt)}(axs)
end
function _similar(A::Type{<:AbstractArray}, ax::Tuple)
return similar(A, ax)
end
function _similar(a::AbstractArray, ax::Tuple)
return _similar(a, eltype(a), ax)
end
function _similar(a::AbstractArray, elt::Type)
return _similar(a, elt, axes(a))
end
function _similar(a::AbstractArray)
return _similar(a, eltype(a), axes(a))
end

# Like `similar` but preserves `Eye`.
function _similar(a::Eye, elt::Type, axs::NTuple{2,AbstractUnitRange})
return Eye{elt}(axs)
end
Expand All @@ -77,19 +64,6 @@ end
# Like `copy` but preserves `Eye`.
_copy(a::Eye) = a

using DerivableInterfaces: DerivableInterfaces, zero!
function DerivableInterfaces.zero!(a::EyeKronecker)
zero!(a.b)
return a
end
function DerivableInterfaces.zero!(a::KroneckerEye)
zero!(a.a)
return a
end
function DerivableInterfaces.zero!(a::EyeEye)
return throw(ArgumentError("Can't zero out `Eye ⊗ Eye`."))
end

using Base.Broadcast:
AbstractArrayStyle, AbstractArrayStyle, BroadcastStyle, Broadcasted, broadcasted

Expand Down
27 changes: 24 additions & 3 deletions src/kroneckerarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,19 @@ function Base.convert(::Type{KroneckerArray{T,N,A,B}}, a::KroneckerArray) where
end

# Like `similar` but allows some custom behavior, such as for `FillArrays.Eye`.
function _similar(a::AbstractArray, elt::Type, axs::Tuple{Vararg{AbstractUnitRange}})
function _similar(a::AbstractArray, elt::Type, axs::Tuple)
return similar(a, elt, axs)
end
function _similar(arrayt::Type{<:AbstractArray}, axs::Tuple{Vararg{AbstractUnitRange}})
function _similar(a::AbstractArray, ax::Tuple)
return _similar(a, eltype(a), ax)
end
function _similar(a::AbstractArray, elt::Type)
return _similar(a, elt, axes(a))
end
function _similar(a::AbstractArray)
return _similar(a, eltype(a), axes(a))
end
function _similar(arrayt::Type{<:AbstractArray}, axs::Tuple)
return similar(arrayt, axs)
end

Expand Down Expand Up @@ -130,6 +139,16 @@ Base.collect(a::KroneckerArray) = kron_nd(collect(arg1(a)), collect(arg2(a)))

Base.zero(a::KroneckerArray) = zero(arg1(a)) ⊗ zero(arg2(a))

using DerivableInterfaces: DerivableInterfaces, zero!
function DerivableInterfaces.zero!(a::KroneckerArray)
ismut1 = ismutable(arg1(a))
ismut2 = ismutable(arg2(a))
(ismut1 || ismut2) || throw(ArgumentError("Can't zero out immutable KroneckerArray."))
ismut1 && zero!(arg1(a))
ismut2 && zero!(arg2(a))
return a
end

function Base.Array{T,N}(a::KroneckerArray{S,N}) where {T,S,N}
return convert(Array{T,N}, collect(a))
end
Expand Down Expand Up @@ -372,13 +391,15 @@ _eltype(x) = eltype(x)
_eltype(x::Broadcasted) = Base.promote_op(x.f, _eltype.(x.args)...)

using Base.Broadcast: broadcasted
struct KroneckerBroadcasted{A<:Broadcasted,B<:Broadcasted}
struct KroneckerBroadcasted{A,B}
a::A
b::B
end
arg1(a::KroneckerBroadcasted) = a.a
arg2(a::KroneckerBroadcasted) = a.b
⊗(a::Broadcasted, b::Broadcasted) = KroneckerBroadcasted(a, b)
⊗(a::Broadcasted, b) = KroneckerBroadcasted(a, b)
⊗(a, b::Broadcasted) = KroneckerBroadcasted(a, b)
Broadcast.materialize(a::KroneckerBroadcasted) = copy(a)
Broadcast.materialize!(dest, a::KroneckerBroadcasted) = copyto!(dest, a)
Broadcast.broadcastable(a::KroneckerBroadcasted) = a
Expand Down
7 changes: 0 additions & 7 deletions src/linearalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,3 @@ function LinearAlgebra.lq(a::KroneckerArray)
Fb = lq(a.b)
return KroneckerLQ(Fa.L ⊗ Fb.L, Fa.Q ⊗ Fb.Q)
end

using DerivableInterfaces: DerivableInterfaces, zero!
function DerivableInterfaces.zero!(a::KroneckerArray)
zero!(a.a)
zero!(a.b)
return a
end
4 changes: 3 additions & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@ MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"

[compat]
Adapt = "4"
Aqua = "0.8"
BlockArrays = "1.6"
BlockSparseArrays = "0.8.1"
BlockSparseArrays = "0.9"
DerivableInterfaces = "0.5"
DiagonalArrays = "0.3.7"
FillArrays = "1"
Expand All @@ -33,5 +34,6 @@ MatrixAlgebraKit = "0.2"
SafeTestsets = "0.1"
StableRNGs = "1.0"
Suppressor = "0.2"
TensorProducts = "0.1.7"
Test = "1.10"
TestExtras = "0.3"
9 changes: 9 additions & 0 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using KroneckerArrays:
KroneckerArray,
KroneckerStyle,
CartesianProductUnitRange,
CartesianProductVector,
⊗,
×,
arg1,
Expand Down Expand Up @@ -45,6 +46,14 @@ elts = (Float32, Float64, ComplexF32, ComplexF64)
@test r[2 × 2] == 5
@test r[2 × 3] == 6

# CartesianProductUnitRange axes
r = cartesianrange((2:3) × (3:4), 2:5)
@test axes(r) ≡ (CartesianProductUnitRange(Base.OneTo(2) × Base.OneTo(2), Base.OneTo(4)),)

# CartesianProductVector axes
r = CartesianProductVector(([2, 4]) × ([3, 5]), [3, 5, 7, 9])
@test axes(r) ≡ (CartesianProductUnitRange(Base.OneTo(2) × Base.OneTo(2), Base.OneTo(4)),)

r = @constinferred(cartesianrange(2 × 3, 2:7))
@test r === cartesianrange(Base.OneTo(2) × Base.OneTo(3), 2:7)
@test cartesianproduct(r) === Base.OneTo(2) × Base.OneTo(3)
Expand Down
31 changes: 18 additions & 13 deletions test/test_blocksparsearrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ arrayts = (Array, JLArray)
Block(1, 1) => dev(randn(elt, 2, 2) ⊗ randn(elt, 2, 2)),
Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)),
)
a = dev(blocksparse(d, r, r))
a = dev(blocksparse(d, (r, r)))
@test sprint(show, a) isa String
@test sprint(show, MIME("text/plain"), a) isa String
@test blocktype(a) === valtype(d)
Expand All @@ -45,7 +45,7 @@ arrayts = (Array, JLArray)
Block(1, 1) => dev(randn(elt, 2, 2) ⊗ randn(elt, 2, 2)),
Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)),
)
a = dev(blocksparse(d, r, r))
a = dev(blocksparse(d, (r, r)))
@test a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] ==
a[Block(2, 2)][(2:3) × (2:3), (2:3) × (2:3)]
@test a[Block(2, 2)[(:) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(:) × (2:3), (:) × (2:3)]
Expand All @@ -68,7 +68,7 @@ arrayts = (Array, JLArray)
Block(1, 1) => dev(randn(elt, 2, 2) ⊗ randn(elt, 2, 2)),
Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)),
)
a = dev(blocksparse(d, r, r))
a = dev(blocksparse(d, (r, r)))
i1 = Block(1)[(1:2) × (1:2)]
i2 = Block(2)[(2:3) × (2:3)]
I = mortar([i1, i2])
Expand All @@ -83,7 +83,7 @@ arrayts = (Array, JLArray)
Block(1, 1) => dev(randn(elt, 2, 2) ⊗ randn(elt, 2, 2)),
Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)),
)
a = dev(blocksparse(d, r, r))
a = dev(blocksparse(d, (r, r)))
i1 = Block(1)[(1:2) × (1:2)]
i2 = Block(2)[(2:3) × (2:3)]
I = [i1, i2]
Expand Down Expand Up @@ -130,9 +130,12 @@ arrayts = (Array, JLArray)
@test_broken svd_compact(a)
end

b = a[Block.(1:2), Block(2)]
@test b[Block(1)] == a[Block(1, 2)]
@test b[Block(2)] == a[Block(2, 2)]

# Broken operations
@test_broken exp(a)
@test_broken a[Block.(1:2), Block(2)]
end

@testset "BlockSparseArraysExt, EyeKronecker blocks (arraytype=$arrayt, eltype=$elt)" for arrayt in
Expand All @@ -145,7 +148,7 @@ end
Block(1, 1) => Eye{elt}(2, 2) ⊗ dev(randn(elt, 2, 2)),
Block(2, 2) => Eye{elt}(3, 3) ⊗ dev(randn(elt, 3, 3)),
)
a = @constinferred dev(blocksparse(d, r, r))
a = @constinferred dev(blocksparse(d, (r, r)))
@test sprint(show, a) == sprint(show, Array(a))
@test sprint(show, MIME("text/plain"), a) isa String
@test @constinferred(blocktype(a)) === valtype(d)
Expand All @@ -167,7 +170,7 @@ end
Block(1, 1) => dev(Eye{elt}(2, 2) ⊗ randn(elt, 2, 2)),
Block(2, 2) => dev(Eye{elt}(3, 3) ⊗ randn(elt, 3, 3)),
)
a = dev(blocksparse(d, r, r))
a = dev(blocksparse(d, (r, r)))
@test a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] ==
a[Block(2, 2)][(2:3) × (2:3), (2:3) × (2:3)]
@test a[Block(2, 2)[(:) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(:) × (2:3), (:) × (2:3)]
Expand All @@ -194,7 +197,7 @@ end
Block(1, 1) => dev(Eye{elt}(2, 2) ⊗ randn(elt, 2, 2)),
Block(2, 2) => dev(Eye{elt}(3, 3) ⊗ randn(elt, 3, 3)),
)
a = dev(blocksparse(d, r, r))
a = dev(blocksparse(d, (r, r)))
i1 = Block(1)[(1:2) × (1:2)]
i2 = Block(2)[(2:3) × (2:3)]
I = mortar([i1, i2])
Expand All @@ -209,7 +212,7 @@ end
Block(1, 1) => dev(Eye{elt}(2, 2) ⊗ randn(elt, 2, 2)),
Block(2, 2) => dev(Eye{elt}(3, 3) ⊗ randn(elt, 3, 3)),
)
a = dev(blocksparse(d, r, r))
a = dev(blocksparse(d, (r, r)))
i1 = Block(1)[(1:2) × (1:2)]
i2 = Block(2)[(2:3) × (2:3)]
I = [i1, i2]
Expand Down Expand Up @@ -272,7 +275,9 @@ end
end

# Broken operations
@test_broken a[Block.(1:2), Block(2)]
b = a[Block.(1:2), Block(2)]
@test b[Block(1)] == a[Block(1, 2)]
@test b[Block(2)] == a[Block(2, 2)]

# svd_trunc
dev = adapt(arrayt)
Expand All @@ -282,7 +287,7 @@ end
Block(1, 1) => Eye{elt}(2, 2) ⊗ randn(rng, elt, 2, 2),
Block(2, 2) => Eye{elt}(3, 3) ⊗ randn(rng, elt, 3, 3),
)
a = @constinferred dev(blocksparse(d, r, r))
a = @constinferred dev(blocksparse(d, (r, r)))
if arrayt === Array
u, s, v = svd_trunc(a; trunc=(; maxrank=6))
u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=5))
Expand All @@ -293,10 +298,10 @@ end

@testset "Block deficient" begin
da = Dict(Block(1, 1) => Eye{elt}(2, 2) ⊗ dev(randn(elt, 2, 2)))
a = @constinferred dev(blocksparse(da, r, r))
a = @constinferred dev(blocksparse(da, (r, r)))

db = Dict(Block(2, 2) => Eye{elt}(3, 3) ⊗ dev(randn(elt, 3, 3)))
b = @constinferred dev(blocksparse(db, r, r))
b = @constinferred dev(blocksparse(db, (r, r)))

@test Array(a + b) ≈ Array(a) + Array(b)
@test Array(2a) ≈ 2Array(a)
Expand Down
13 changes: 13 additions & 0 deletions test/test_tensorproducts.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
using KroneckerArrays: ×, arg1, arg2, cartesianrange, unproduct
using TensorProducts: tensor_product
using Test: @test, @testset

@testset "KroneckerArraysTensorProductsExt" begin
r1 = cartesianrange(2, 3)
r2 = cartesianrange(4, 5)
r = tensor_product(r1, r2)
@test r ≡ cartesianrange(8, 15)
@test arg1(r) ≡ Base.OneTo(8)
@test arg2(r) ≡ Base.OneTo(15)
@test unproduct(r) ≡ Base.OneTo(120)
end
Loading