Skip to content

Towards truncated block sparse factorizations #28

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "KroneckerArrays"
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.1.21"
version = "0.1.22"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -23,7 +23,7 @@ KroneckerArraysBlockSparseArraysExt = ["BlockArrays", "BlockSparseArrays"]
[compat]
Adapt = "4.3.0"
BlockArrays = "1.6"
BlockSparseArrays = "0.7.21"
BlockSparseArrays = "0.7.22"
DerivableInterfaces = "0.5.0"
DiagonalArrays = "0.3.5"
FillArrays = "1.13.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,4 +99,10 @@ function (f::GetUnstoredBlock)(
return error("Not implemented.")
end

using BlockSparseArrays: BlockSparseArrays
using KroneckerArrays: KroneckerArrays, KroneckerVector
function BlockSparseArrays.to_truncated_indices(values::KroneckerVector, I)
return KroneckerArrays.to_truncated_indices(values, I)
end

end
63 changes: 58 additions & 5 deletions src/cartesianproduct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,16 @@ arguments(a::CartesianProduct, n::Int) = arguments(a)[n]
arg1(a::CartesianProduct) = a.a
arg2(a::CartesianProduct) = a.b

Base.copy(a::CartesianProduct) = copy(arg1(a)) × copy(arg2(a))

function Base.show(io::IO, a::CartesianProduct)
print(io, a.a, " × ", a.b)
return nothing
end
function Base.show(io::IO, ::MIME"text/plain", a::CartesianProduct)
show(io, a)
return nothing
end

×(a::AbstractVector, b::AbstractVector) = CartesianProduct(a, b)
Base.length(a::CartesianProduct) = length(a.a) * length(a.b)
Expand All @@ -42,8 +48,38 @@ function Base.getindex(a::CartesianProduct, i::CartesianPair)
return arg1(a)[arg1(i)] × arg2(a)[arg2(i)]
end
function Base.getindex(a::CartesianProduct, i::Int)
I = Tuple(CartesianIndices((length(arg1(a)), length(arg2(a))))[i])
return a[I[1] × I[2]]
I = Tuple(CartesianIndices((length(arg2(a)), length(arg1(a))))[i])
return a[I[2] × I[1]]
end

struct CartesianProductVector{T,P<:CartesianProduct,V<:AbstractVector{T}} <:
AbstractVector{T}
product::P
values::V
end
cartesianproduct(r::CartesianProductVector) = getfield(r, :product)
unproduct(r::CartesianProductVector) = getfield(r, :values)
Base.length(a::CartesianProductVector) = length(unproduct(a))
Base.size(a::CartesianProductVector) = (length(a),)
function Base.axes(r::CartesianProductVector)
return (CartesianProductUnitRange(cartesianproduct(r), only(axes(unproduct(r)))),)
end
function Base.copy(a::CartesianProductVector)
return CartesianProductVector(copy(cartesianproduct(a)), copy(unproduct(a)))
end
function Base.getindex(r::CartesianProductVector, i::Integer)
return unproduct(r)[i]
end

function Base.show(io::IO, a::CartesianProductVector)
show(io, unproduct(a))
return nothing
end
function Base.show(io::IO, mime::MIME"text/plain", a::CartesianProductVector)
show(io, mime, cartesianproduct(a))
println(io)
show(io, mime, unproduct(a))
return nothing
end

struct CartesianProductUnitRange{T,P<:CartesianProduct,R<:AbstractUnitRange{T}} <:
Expand All @@ -60,13 +96,24 @@ unproduct(r::CartesianProductUnitRange) = getfield(r, :range)
arg1(a::CartesianProductUnitRange) = arg1(cartesianproduct(a))
arg2(a::CartesianProductUnitRange) = arg2(cartesianproduct(a))

function Base.show(io::IO, a::CartesianProductUnitRange)
show(io, unproduct(a))
return nothing
end
function Base.show(io::IO, mime::MIME"text/plain", a::CartesianProductUnitRange)
show(io, mime, cartesianproduct(a))
println(io)
show(io, mime, unproduct(a))
return nothing
end

function CartesianProductUnitRange(p::CartesianProduct)
return CartesianProductUnitRange(p, Base.OneTo(length(p)))
end
function CartesianProductUnitRange(a, b)
return CartesianProductUnitRange(a × b)
end
to_product_indices(a::AbstractUnitRange) = a
to_product_indices(a::AbstractVector) = a
to_product_indices(i::Integer) = Base.OneTo(i)
cartesianrange(a, b) = cartesianrange(to_product_indices(a) × to_product_indices(b))
function cartesianrange(p::CartesianPair)
Expand Down Expand Up @@ -94,10 +141,16 @@ function Base.checkindex(::Type{Bool}, inds::CartesianProductUnitRange, i::Carte
return checkindex(Bool, arg1(inds), arg1(i)) && checkindex(Bool, arg2(inds), arg2(i))
end

function Base.getindex(a::CartesianProductUnitRange, I::CartesianProduct)
prod = cartesianproduct(a)
prod_I = arg1(prod)[arg1(I)] × arg2(prod)[arg2(I)]
return CartesianProductVector(prod_I, map(Base.Fix1(getindex, a), I))
end

# Reverse map from CartesianPair to linear index in the range.
function Base.getindex(inds::CartesianProductUnitRange, i::CartesianPair)
i′ = (findfirst(==(arg1(i)), arg1(inds)), findfirst(==(arg2(i)), arg2(inds)))
return inds[LinearIndices((length(arg1(inds)), length(arg2(inds))))[i′...]]
i′ = (findfirst(==(arg2(i)), arg2(inds)), findfirst(==(arg1(i)), arg1(inds)))
return inds[LinearIndices((length(arg2(inds)), length(arg1(inds))))[i′...]]
end

using Base.Broadcast: DefaultArrayStyle
Expand Down
6 changes: 6 additions & 0 deletions src/fillarrays/kroneckerarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@ const KroneckerSquareEye{T,A<:AbstractMatrix{T},B<:SquareEye{T}} = KroneckerMatr
const SquareEyeSquareEye{T,A<:SquareEye{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B}

_getindex(a::Eye, I1::Colon, I2::Colon) = a
_getindex(a::Eye, I1::Base.Slice, I2::Base.Slice) = a
_getindex(a::Eye, I1::Base.Slice, I2::Colon) = a
_getindex(a::Eye, I1::Colon, I2::Base.Slice) = a
_view(a::Eye, I1::Colon, I2::Colon) = a
_view(a::Eye, I1::Base.Slice, I2::Base.Slice) = a
_view(a::Eye, I1::Base.Slice, I2::Colon) = a
_view(a::Eye, I1::Colon, I2::Base.Slice) = a

# Like `adapt` but preserves `Eye`.
_adapt(to, a::Eye) = a
Expand Down
42 changes: 24 additions & 18 deletions src/fillarrays/matrixalgebrakit_truncate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,34 +20,40 @@ const OnesKroneckerVector{T,A<:OnesVector{T},B<:AbstractVector{T}} = KroneckerVe
const KroneckerOnesVector{T,A<:AbstractVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B}
const OnesVectorOnesVector{T,A<:OnesVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B}

function MatrixAlgebraKit.findtruncated(
values::OnesKroneckerVector, strategy::KroneckerTruncationStrategy
)
I = findtruncated(Vector(values), strategy.strategy)
prods = collect(cartesianproduct(only(axes(values))))[I]
I_data = unique(map(arg1, prods))
axis(a) = only(axes(a))

# Convert indices determined with a generic call to `findtruncated` to indices
# more suited for a KroneckerVector.
function to_truncated_indices(values::OnesKroneckerVector, I)
prods = cartesianproduct(axis(values))[I]
I_id = only(to_indices(arg1(values), (:,)))
I_data = unique(arg2.(prods))
# Drop truncations that occur within the identity.
I_data = filter(I_data) do i
return count(x -> arg1(x) == i, prods) == length(arg1(values))
return count(x -> arg2(x) == i, prods) == length(arg2(values))
end
return (:) × I_data
return I_id × I_data
end
function MatrixAlgebraKit.findtruncated(
values::KroneckerOnesVector, strategy::KroneckerTruncationStrategy
)
I = findtruncated(Vector(values), strategy.strategy)
prods = collect(cartesianproduct(only(axes(values))))[I]
I_data = unique(map(x -> arg2(x), prods))
function to_truncated_indices(values::KroneckerOnesVector, I)
#I = findtruncated(Vector(values), strategy.strategy)
prods = cartesianproduct(axis(values))[I]
I_data = unique(arg1.(prods))
# Drop truncations that occur within the identity.
I_data = filter(I_data) do i
return count(x -> arg2(x) == i, prods) == length(arg2(values))
return count(x -> arg1(x) == i, prods) == length(arg2(values))
end
return I_data × (:)
I_id = only(to_indices(arg2(values), (:,)))
return I_data × I_id
end
function to_truncated_indices(values::OnesVectorOnesVector, I)
return throw(ArgumentError("Can't truncate Eye ⊗ Eye."))
end

function MatrixAlgebraKit.findtruncated(
values::OnesVectorOnesVector, strategy::KroneckerTruncationStrategy
values::KroneckerVector, strategy::KroneckerTruncationStrategy
)
return throw(ArgumentError("Can't truncate Eye ⊗ Eye."))
I = findtruncated(Vector(values), strategy.strategy)
return to_truncated_indices(values, I)
end

for f in [:eig_trunc!, :eigh_trunc!]
Expand Down
19 changes: 14 additions & 5 deletions src/kroneckerarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,13 +167,22 @@ function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{Integer,N}) where {
return a[I′...]
end

# Indexing logic.
function Base.to_indices(
a::KroneckerArray, inds, I::Tuple{Union{CartesianPair,CartesianProduct},Vararg}
)
I1 = to_indices(arg1(a), arg1.(inds), arg1.(I))
I2 = to_indices(arg2(a), arg2.(inds), arg2.(I))
return I1 .× I2
end

# Allow customizing for `FillArrays.Eye`.
_getindex(a::AbstractArray, I...) = a[I...]
function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianProduct,N}) where {N}
return _getindex(arg1(a), arg1.(I)...) ⊗ _getindex(arg2(a), arg2.(I)...)
end
function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianPair,N}) where {N}
return _getindex(arg1(a), arg1.(I)...) ⊗ _getindex(arg2(a), arg2.(I)...)
function Base.getindex(
a::KroneckerArray{<:Any,N}, I::Vararg{Union{CartesianPair,CartesianProduct},N}
) where {N}
I′ = to_indices(a, I)
return _getindex(arg1(a), arg1.(I)...) ⊗ _getindex(arg2(a), arg2.(I)...)
end
# Fix ambigiuity error.
Base.getindex(a::KroneckerArray{<:Any,0}) = arg1(a)[] * arg2(a)[]
Expand Down
18 changes: 9 additions & 9 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ elts = (Float32, Float64, ComplexF32, ComplexF64)
@testset "KroneckerArrays (eltype=$elt)" for elt in elts
p = [1, 2] × [3, 4, 5]
@test length(p) == 6
@test collect(p) == [1 × 3, 2 × 3, 1 × 4, 2 × 4, 1 × 5, 2 × 5]
@test collect(p) == [1 × 3, 1 × 4, 1 × 5, 2 × 3, 2 × 4, 2 × 5]

r = @constinferred cartesianrange(2, 3)
@test r ===
Expand All @@ -39,10 +39,10 @@ elts = (Float32, Float64, ComplexF32, ComplexF64)
@test first(r) == 1
@test last(r) == 6
@test r[1 × 1] == 1
@test r[2 × 1] == 2
@test r[1 × 2] == 3
@test r[2 × 2] == 4
@test r[1 × 3] == 5
@test r[1 × 2] == 2
@test r[1 × 3] == 3
@test r[2 × 1] == 4
@test r[2 × 2] == 5
@test r[2 × 3] == 6

r = @constinferred(cartesianrange(2 × 3, 2:7))
Expand All @@ -53,10 +53,10 @@ elts = (Float32, Float64, ComplexF32, ComplexF64)
@test first(r) == 2
@test last(r) == 7
@test r[1 × 1] == 2
@test r[2 × 1] == 3
@test r[1 × 2] == 4
@test r[2 × 2] == 5
@test r[1 × 3] == 6
@test r[1 × 2] == 3
@test r[1 × 3] == 4
@test r[2 × 1] == 5
@test r[2 × 2] == 6
@test r[2 × 3] == 7

# Test high-dimensional materialization.
Expand Down
34 changes: 30 additions & 4 deletions test/test_blocksparsearrays.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
using Adapt: adapt
using BlockArrays: Block, BlockRange, mortar
using BlockSparseArrays:
BlockSparseArray, BlockSparseMatrix, blockrange, blocksparse, blocktype
BlockIndexVector, BlockSparseArray, BlockSparseMatrix, blockrange, blocksparse, blocktype
using FillArrays: Eye, SquareEye
using JLArrays: JLArray
using KroneckerArrays: KroneckerArray, ⊗, ×
using KroneckerArrays: KroneckerArray, ⊗, ×, arg1, arg2
using LinearAlgebra: norm
using MatrixAlgebraKit: svd_compact
using Test: @test, @test_broken, @testset
Expand Down Expand Up @@ -48,7 +48,18 @@ arrayts = (Array, JLArray)
@test a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] ==
a[Block(2, 2)][(2:3) × (2:3), (2:3) × (2:3)]
@test a[Block(2, 2)[(:) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(:) × (2:3), (:) × (2:3)]
@test_broken a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)]
@test a[Block(2, 2)[(1:2) × (2:3), (:) × (2:3)]] ==
a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)]

# Blockwise slicing, shows up in truncated block sparse matrix factorizations.
I1 = BlockIndexVector(Block(1), Base.Slice(Base.OneTo(2)) × [1])
I2 = BlockIndexVector(Block(2), Base.Slice(Base.OneTo(3)) × [1, 3])
I = [I1, I2]
b = a[I, I]
@test b[Block(1, 1)] == a[Block(1, 1)[(1:2) × [1], (1:2) × [1]]]
@test iszero(b[Block(2, 1)])
@test iszero(b[Block(1, 2)])
@test b[Block(2, 2)] == a[Block(2, 2)[(1:3) × [1, 3], (1:3) × [1, 3]]]

# Slicing
r = blockrange([2 × 2, 3 × 3])
Expand Down Expand Up @@ -159,7 +170,22 @@ end
@test a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] ==
a[Block(2, 2)][(2:3) × (2:3), (2:3) × (2:3)]
@test a[Block(2, 2)[(:) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(:) × (2:3), (:) × (2:3)]
@test_broken a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)]
@test a[Block(2, 2)[(1:2) × (2:3), (:) × (2:3)]] ==
a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)]

# Blockwise slicing, shows up in truncated block sparse matrix factorizations.
I1 = BlockIndexVector(Block(1), Base.Slice(Base.OneTo(2)) × [1])
I2 = BlockIndexVector(Block(2), Base.Slice(Base.OneTo(3)) × [1, 3])
I = [I1, I2]
b = a[I, I]
@test b[Block(1, 1)] == a[Block(1, 1)[(1:2) × [1], (1:2) × [1]]]
@test arg1(b[Block(1, 1)]) isa Eye
@test iszero(b[Block(2, 1)])
@test arg1(b[Block(2, 1)]) isa Eye
@test iszero(b[Block(1, 2)])
@test arg1(b[Block(1, 2)]) isa Eye
@test b[Block(2, 2)] == a[Block(2, 2)[(1:3) × [1, 3], (1:3) × [1, 3]]]
@test arg1(b[Block(2, 2)]) isa Eye

# Slicing
r = blockrange([2 × 2, 3 × 3])
Expand Down
Loading