Skip to content

Block sparse SVD #19

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "KroneckerArrays"
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.1.13"
version = "0.1.14"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -13,14 +13,16 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"

[weakdeps]
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"

[extensions]
KroneckerArraysBlockSparseArraysExt = "BlockSparseArrays"
KroneckerArraysBlockSparseArraysExt = ["BlockArrays", "BlockSparseArrays"]

[compat]
Adapt = "4.3.0"
BlockSparseArrays = "0.7.9"
BlockArrays = "1.6"
BlockSparseArrays = "0.7.13"
DerivableInterfaces = "0.5.0"
DiagonalArrays = "0.3.5"
FillArrays = "1.13.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,84 @@ module KroneckerArraysBlockSparseArraysExt

using BlockSparseArrays: BlockSparseArrays, blockrange
using KroneckerArrays: CartesianProduct, cartesianrange

function BlockSparseArrays.blockrange(bs::Vector{<:CartesianProduct})
return blockrange(map(cartesianrange, bs))
end

using BlockArrays: AbstractBlockedUnitRange
using BlockSparseArrays: Block, GetUnstoredBlock, eachblockaxis, mortar_axis
using DerivableInterfaces: zero!
using FillArrays: Eye
using KroneckerArrays:
KroneckerArrays,
EyeEye,
EyeKronecker,
KroneckerEye,
KroneckerMatrix,
⊗,
arg1,
arg2,
_similar

function KroneckerArrays.arg1(r::AbstractBlockedUnitRange)
return mortar_axis(arg2.(eachblockaxis(r)))
end
function KroneckerArrays.arg2(r::AbstractBlockedUnitRange)
return mortar_axis(arg2.(eachblockaxis(r)))
end

function block_axes(
ax::NTuple{N,AbstractUnitRange{<:Integer}}, I::Vararg{Block{1},N}
) where {N}
return ntuple(N) do d
return only(axes(ax[d][I[d]]))
end
end
function block_axes(ax::NTuple{N,AbstractUnitRange{<:Integer}}, I::Block{N}) where {N}
return block_axes(ax, Tuple(I)...)
end

function (f::GetUnstoredBlock)(
::Type{<:AbstractMatrix{KroneckerMatrix{T,A,B}}}, I::Vararg{Int,2}
) where {T,A<:AbstractMatrix{T},B<:AbstractMatrix{T}}
ax_a = arg1.(f.axes)
f_a = GetUnstoredBlock(ax_a)
a = f_a(AbstractMatrix{A}, I...)

ax_b = arg2.(f.axes)
f_b = GetUnstoredBlock(ax_b)
b = f_b(AbstractMatrix{B}, I...)

return a ⊗ b
end
function (f::GetUnstoredBlock)(
::Type{<:AbstractMatrix{EyeKronecker{T,A,B}}}, I::Vararg{Int,2}
) where {T,A<:Eye{T},B<:AbstractMatrix{T}}
block_ax_a = arg1.(block_axes(f.axes, Block(I)))
a = _similar(A, block_ax_a)

ax_b = arg2.(f.axes)
f_b = GetUnstoredBlock(ax_b)
b = f_b(AbstractMatrix{B}, I...)

return a ⊗ b
end
function (f::GetUnstoredBlock)(
::Type{<:AbstractMatrix{KroneckerEye{T,A,B}}}, I::Vararg{Int,2}
) where {T,A<:AbstractMatrix{T},B<:Eye{T}}
ax_a = arg1.(f.axes)
f_a = GetUnstoredBlock(ax_a)
a = f_a(AbstractMatrix{A}, I...)

block_ax_b = arg2.(block_axes(f.axes, Block(I)))
b = _similar(B, block_ax_b)

return a ⊗ b
end
function (f::GetUnstoredBlock)(
::Type{<:AbstractMatrix{EyeEye{T,A,B}}}, I::Vararg{Int,2}
) where {T,A<:Eye{T},B<:Eye{T}}
return error("Not implemented.")
end

end
6 changes: 6 additions & 0 deletions src/cartesianproduct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ end
arguments(a::CartesianProduct) = (a.a, a.b)
arguments(a::CartesianProduct, n::Int) = arguments(a)[n]

arg1(a::CartesianProduct) = a.a
arg2(a::CartesianProduct) = a.b

function Base.show(io::IO, a::CartesianProduct)
print(io, a.a, " × ", a.b)
return nothing
Expand Down Expand Up @@ -32,6 +35,9 @@ Base.last(r::CartesianProductUnitRange) = last(r.range)
cartesianproduct(r::CartesianProductUnitRange) = getfield(r, :product)
unproduct(r::CartesianProductUnitRange) = getfield(r, :range)

arg1(a::CartesianProductUnitRange) = arg1(cartesianproduct(a))
arg2(a::CartesianProductUnitRange) = arg2(cartesianproduct(a))

function CartesianProductUnitRange(p::CartesianProduct)
return CartesianProductUnitRange(p, Base.OneTo(length(p)))
end
Expand Down
17 changes: 14 additions & 3 deletions src/fillarrays/kroneckerarray.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
using FillArrays: RectDiagonal, OnesVector
const RectEye{T,V<:OnesVector{T},Axes} = RectDiagonal{T,V,Axes}

using FillArrays: Eye
const EyeKronecker{T,A<:Eye{T},B<:AbstractMatrix{T}} = KroneckerMatrix{T,A,B}
const KroneckerEye{T,A<:AbstractMatrix{T},B<:Eye{T}} = KroneckerMatrix{T,A,B}
Expand All @@ -11,6 +14,14 @@ const SquareEyeSquareEye{T,A<:SquareEye{T},B<:SquareEye{T}} = KroneckerMatrix{T,
# Like `adapt` but preserves `Eye`.
_adapt(to, a::Eye) = a

# Allows customizing for `FillArrays.Eye`.
function _convert(::Type{AbstractArray{T}}, a::RectDiagonal) where {T}
_convert(AbstractMatrix{T}, a)
end
function _convert(::Type{AbstractMatrix{T}}, a::RectDiagonal) where {T}
RectDiagonal(convert(AbstractVector{T}, _diagview(a)), axes(a))
end

# Like `similar` but preserves `Eye`.
function _similar(a::AbstractArray, elt::Type, ax::Tuple)
return similar(a, elt, ax)
Expand Down Expand Up @@ -124,15 +135,15 @@ for op in (:+, :-)
end
end

function Base.map!(f::typeof(identity), dest::EyeKronecker, a::EyeKronecker)
function Base.map!(f::typeof(identity), dest::EyeKronecker, src::EyeKronecker)
map!(f, dest.b, src.b)
return dest
end
function Base.map!(f::typeof(identity), dest::KroneckerEye, a::KroneckerEye)
function Base.map!(f::typeof(identity), dest::KroneckerEye, src::KroneckerEye)
map!(f, dest.a, src.a)
return dest
end
function Base.map!(::typeof(identity), dest::EyeEye, a::EyeEye)
function Base.map!(::typeof(identity), dest::EyeEye, src::EyeEye)
return error("Can't write in-place.")
end
for f in [:+, :-]
Expand Down
13 changes: 11 additions & 2 deletions src/kroneckerarray.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# Allows customizing for `FillArrays.Eye`.
function _convert(A::Type{<:AbstractArray}, a::AbstractArray)
return convert(A, a)
end

struct KroneckerArray{T,N,A<:AbstractArray{T,N},B<:AbstractArray{T,N}} <: AbstractArray{T,N}
a::A
b::B
Expand All @@ -9,11 +14,14 @@ function KroneckerArray(a::AbstractArray, b::AbstractArray)
)
end
elt = promote_type(eltype(a), eltype(b))
return KroneckerArray(convert(AbstractArray{elt}, a), convert(AbstractArray{elt}, b))
return KroneckerArray(_convert(AbstractArray{elt}, a), _convert(AbstractArray{elt}, b))
end
const KroneckerMatrix{T,A<:AbstractMatrix{T},B<:AbstractMatrix{T}} = KroneckerArray{T,2,A,B}
const KroneckerVector{T,A<:AbstractVector{T},B<:AbstractVector{T}} = KroneckerArray{T,1,A,B}

arg1(a::KroneckerArray) = a.a
arg2(a::KroneckerArray) = a.b

using Adapt: Adapt, adapt
_adapt(to, a::AbstractArray) = adapt(to, a)
Adapt.adapt_structure(to, a::KroneckerArray) = _adapt(to, a.a) ⊗ _adapt(to, a.b)
Expand Down Expand Up @@ -106,7 +114,8 @@ end
kron_nd(a::AbstractMatrix, b::AbstractMatrix) = kron(a, b)
kron_nd(a::AbstractVector, b::AbstractVector) = kron(a, b)

Base.collect(a::KroneckerArray) = kron_nd(a.a, a.b)
# Eagerly collect arguments to make more general on GPU.
Base.collect(a::KroneckerArray) = kron_nd(collect(a.a), collect(a.b))

function Base.Array{T,N}(a::KroneckerArray{S,N}) where {T,S,N}
return convert(Array{T,N}, collect(a))
Expand Down
41 changes: 29 additions & 12 deletions test/test_blocksparsearrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,27 +64,29 @@ arrayts = (Array, JLArray)
@test_broken inv(a)
end

if (VERSION ≤ v"1.11-" && arrayt === Array && elt <: Complex) ||
(arrayt === Array && elt <: Real)
u, s, v = svd_compact(a)
@test Array(u * s * v) ≈ Array(a)
else
# Broken on GPU and for complex, investigate.
@test_broken svd_compact(a)
end

# Broken operations
@test_broken exp(a)
@test_broken svd_compact(a)
@test_broken a[Block.(1:2), Block(2)]
end

@testset "BlockSparseArraysExt, EyeKronecker blocks (arraytype=$arrayt, eltype=$elt)" for arrayt in
arrayts,
elt in elts

if arrayt == JLArray
# TODO: Collecting to `Array` is broken for GPU arrays so a lot of tests
# are broken, look into fixing that.
continue
end

dev = adapt(arrayt)
r = @constinferred blockrange([2 × 2, 3 × 3])
d = Dict(
Block(1, 1) => Eye{elt}(2, 2) ⊗ randn(elt, 2, 2),
Block(2, 2) => Eye{elt}(3, 3) ⊗ randn(elt, 3, 3),
Block(1, 1) => Eye{elt}(2, 2) ⊗ dev(randn(elt, 2, 2)),
Block(2, 2) => Eye{elt}(3, 3) ⊗ dev(randn(elt, 3, 3)),
)
a = @constinferred dev(blocksparse(d, r, r))
@test sprint(show, a) == sprint(show, Array(a))
Expand Down Expand Up @@ -126,11 +128,26 @@ end

@test @constinferred(norm(a)) ≈ norm(Array(a))

b = @constinferred exp(a)
@test Array(b) ≈ exp(Array(a))
if arrayt === Array
b = @constinferred exp(a)
@test Array(b) ≈ exp(Array(a))
else
@test_broken exp(a)
end

if VERSION < v"1.11-" && elt <: Complex
# Broken because of type stability issue in Julia v1.10.
@test_broken svd_compact(a)
elseif arrayt === Array
u, s, v = svd_compact(a)
@test u * s * v ≈ a
@test blocktype(u) === blocktype(a)
@test blocktype(v) === blocktype(a)
else
@test_broken svd_compact(a)
end

# Broken operations
@test_broken inv(a)
@test_broken svd_compact(a)
@test_broken a[Block.(1:2), Block(2)]
end
Loading