Skip to content

Commit 2f2caf3

Browse files
authored
Towards truncated block sparse factorizations (#28)
1 parent 6ff9f93 commit 2f2caf3

File tree

8 files changed

+149
-43
lines changed

8 files changed

+149
-43
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "KroneckerArrays"
22
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.21"
4+
version = "0.1.22"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -23,7 +23,7 @@ KroneckerArraysBlockSparseArraysExt = ["BlockArrays", "BlockSparseArrays"]
2323
[compat]
2424
Adapt = "4.3.0"
2525
BlockArrays = "1.6"
26-
BlockSparseArrays = "0.7.21"
26+
BlockSparseArrays = "0.7.22"
2727
DerivableInterfaces = "0.5.0"
2828
DiagonalArrays = "0.3.5"
2929
FillArrays = "1.13.0"

ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,4 +99,10 @@ function (f::GetUnstoredBlock)(
9999
return error("Not implemented.")
100100
end
101101

102+
using BlockSparseArrays: BlockSparseArrays
103+
using KroneckerArrays: KroneckerArrays, KroneckerVector
104+
function BlockSparseArrays.to_truncated_indices(values::KroneckerVector, I)
105+
return KroneckerArrays.to_truncated_indices(values, I)
106+
end
107+
102108
end

src/cartesianproduct.jl

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,16 @@ arguments(a::CartesianProduct, n::Int) = arguments(a)[n]
2626
arg1(a::CartesianProduct) = a.a
2727
arg2(a::CartesianProduct) = a.b
2828

29+
Base.copy(a::CartesianProduct) = copy(arg1(a)) × copy(arg2(a))
30+
2931
function Base.show(io::IO, a::CartesianProduct)
3032
print(io, a.a, " × ", a.b)
3133
return nothing
3234
end
35+
function Base.show(io::IO, ::MIME"text/plain", a::CartesianProduct)
36+
show(io, a)
37+
return nothing
38+
end
3339

3440
×(a::AbstractVector, b::AbstractVector) = CartesianProduct(a, b)
3541
Base.length(a::CartesianProduct) = length(a.a) * length(a.b)
@@ -42,8 +48,38 @@ function Base.getindex(a::CartesianProduct, i::CartesianPair)
4248
return arg1(a)[arg1(i)] × arg2(a)[arg2(i)]
4349
end
4450
function Base.getindex(a::CartesianProduct, i::Int)
45-
I = Tuple(CartesianIndices((length(arg1(a)), length(arg2(a))))[i])
46-
return a[I[1] × I[2]]
51+
I = Tuple(CartesianIndices((length(arg2(a)), length(arg1(a))))[i])
52+
return a[I[2] × I[1]]
53+
end
54+
55+
struct CartesianProductVector{T,P<:CartesianProduct,V<:AbstractVector{T}} <:
56+
AbstractVector{T}
57+
product::P
58+
values::V
59+
end
60+
cartesianproduct(r::CartesianProductVector) = getfield(r, :product)
61+
unproduct(r::CartesianProductVector) = getfield(r, :values)
62+
Base.length(a::CartesianProductVector) = length(unproduct(a))
63+
Base.size(a::CartesianProductVector) = (length(a),)
64+
function Base.axes(r::CartesianProductVector)
65+
return (CartesianProductUnitRange(cartesianproduct(r), only(axes(unproduct(r)))),)
66+
end
67+
function Base.copy(a::CartesianProductVector)
68+
return CartesianProductVector(copy(cartesianproduct(a)), copy(unproduct(a)))
69+
end
70+
function Base.getindex(r::CartesianProductVector, i::Integer)
71+
return unproduct(r)[i]
72+
end
73+
74+
function Base.show(io::IO, a::CartesianProductVector)
75+
show(io, unproduct(a))
76+
return nothing
77+
end
78+
function Base.show(io::IO, mime::MIME"text/plain", a::CartesianProductVector)
79+
show(io, mime, cartesianproduct(a))
80+
println(io)
81+
show(io, mime, unproduct(a))
82+
return nothing
4783
end
4884

4985
struct CartesianProductUnitRange{T,P<:CartesianProduct,R<:AbstractUnitRange{T}} <:
@@ -60,13 +96,24 @@ unproduct(r::CartesianProductUnitRange) = getfield(r, :range)
6096
arg1(a::CartesianProductUnitRange) = arg1(cartesianproduct(a))
6197
arg2(a::CartesianProductUnitRange) = arg2(cartesianproduct(a))
6298

99+
function Base.show(io::IO, a::CartesianProductUnitRange)
100+
show(io, unproduct(a))
101+
return nothing
102+
end
103+
function Base.show(io::IO, mime::MIME"text/plain", a::CartesianProductUnitRange)
104+
show(io, mime, cartesianproduct(a))
105+
println(io)
106+
show(io, mime, unproduct(a))
107+
return nothing
108+
end
109+
63110
function CartesianProductUnitRange(p::CartesianProduct)
64111
return CartesianProductUnitRange(p, Base.OneTo(length(p)))
65112
end
66113
function CartesianProductUnitRange(a, b)
67114
return CartesianProductUnitRange(a × b)
68115
end
69-
to_product_indices(a::AbstractUnitRange) = a
116+
to_product_indices(a::AbstractVector) = a
70117
to_product_indices(i::Integer) = Base.OneTo(i)
71118
cartesianrange(a, b) = cartesianrange(to_product_indices(a) × to_product_indices(b))
72119
function cartesianrange(p::CartesianPair)
@@ -94,10 +141,16 @@ function Base.checkindex(::Type{Bool}, inds::CartesianProductUnitRange, i::Carte
94141
return checkindex(Bool, arg1(inds), arg1(i)) && checkindex(Bool, arg2(inds), arg2(i))
95142
end
96143

144+
function Base.getindex(a::CartesianProductUnitRange, I::CartesianProduct)
145+
prod = cartesianproduct(a)
146+
prod_I = arg1(prod)[arg1(I)] × arg2(prod)[arg2(I)]
147+
return CartesianProductVector(prod_I, map(Base.Fix1(getindex, a), I))
148+
end
149+
97150
# Reverse map from CartesianPair to linear index in the range.
98151
function Base.getindex(inds::CartesianProductUnitRange, i::CartesianPair)
99-
i′ = (findfirst(==(arg1(i)), arg1(inds)), findfirst(==(arg2(i)), arg2(inds)))
100-
return inds[LinearIndices((length(arg1(inds)), length(arg2(inds))))[i′...]]
152+
i′ = (findfirst(==(arg2(i)), arg2(inds)), findfirst(==(arg1(i)), arg1(inds)))
153+
return inds[LinearIndices((length(arg2(inds)), length(arg1(inds))))[i′...]]
101154
end
102155

103156
using Base.Broadcast: DefaultArrayStyle

src/fillarrays/kroneckerarray.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,13 @@ const KroneckerSquareEye{T,A<:AbstractMatrix{T},B<:SquareEye{T}} = KroneckerMatr
2222
const SquareEyeSquareEye{T,A<:SquareEye{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B}
2323

2424
_getindex(a::Eye, I1::Colon, I2::Colon) = a
25+
_getindex(a::Eye, I1::Base.Slice, I2::Base.Slice) = a
26+
_getindex(a::Eye, I1::Base.Slice, I2::Colon) = a
27+
_getindex(a::Eye, I1::Colon, I2::Base.Slice) = a
2528
_view(a::Eye, I1::Colon, I2::Colon) = a
29+
_view(a::Eye, I1::Base.Slice, I2::Base.Slice) = a
30+
_view(a::Eye, I1::Base.Slice, I2::Colon) = a
31+
_view(a::Eye, I1::Colon, I2::Base.Slice) = a
2632

2733
# Like `adapt` but preserves `Eye`.
2834
_adapt(to, a::Eye) = a

src/fillarrays/matrixalgebrakit_truncate.jl

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,34 +20,40 @@ const OnesKroneckerVector{T,A<:OnesVector{T},B<:AbstractVector{T}} = KroneckerVe
2020
const KroneckerOnesVector{T,A<:AbstractVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B}
2121
const OnesVectorOnesVector{T,A<:OnesVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B}
2222

23-
function MatrixAlgebraKit.findtruncated(
24-
values::OnesKroneckerVector, strategy::KroneckerTruncationStrategy
25-
)
26-
I = findtruncated(Vector(values), strategy.strategy)
27-
prods = collect(cartesianproduct(only(axes(values))))[I]
28-
I_data = unique(map(arg1, prods))
23+
axis(a) = only(axes(a))
24+
25+
# Convert indices determined with a generic call to `findtruncated` to indices
26+
# more suited for a KroneckerVector.
27+
function to_truncated_indices(values::OnesKroneckerVector, I)
28+
prods = cartesianproduct(axis(values))[I]
29+
I_id = only(to_indices(arg1(values), (:,)))
30+
I_data = unique(arg2.(prods))
2931
# Drop truncations that occur within the identity.
3032
I_data = filter(I_data) do i
31-
return count(x -> arg1(x) == i, prods) == length(arg1(values))
33+
return count(x -> arg2(x) == i, prods) == length(arg2(values))
3234
end
33-
return (:) × I_data
35+
return I_id × I_data
3436
end
35-
function MatrixAlgebraKit.findtruncated(
36-
values::KroneckerOnesVector, strategy::KroneckerTruncationStrategy
37-
)
38-
I = findtruncated(Vector(values), strategy.strategy)
39-
prods = collect(cartesianproduct(only(axes(values))))[I]
40-
I_data = unique(map(x -> arg2(x), prods))
37+
function to_truncated_indices(values::KroneckerOnesVector, I)
38+
#I = findtruncated(Vector(values), strategy.strategy)
39+
prods = cartesianproduct(axis(values))[I]
40+
I_data = unique(arg1.(prods))
4141
# Drop truncations that occur within the identity.
4242
I_data = filter(I_data) do i
43-
return count(x -> arg2(x) == i, prods) == length(arg2(values))
43+
return count(x -> arg1(x) == i, prods) == length(arg2(values))
4444
end
45-
return I_data × (:)
45+
I_id = only(to_indices(arg2(values), (:,)))
46+
return I_data × I_id
47+
end
48+
function to_truncated_indices(values::OnesVectorOnesVector, I)
49+
return throw(ArgumentError("Can't truncate Eye ⊗ Eye."))
4650
end
51+
4752
function MatrixAlgebraKit.findtruncated(
48-
values::OnesVectorOnesVector, strategy::KroneckerTruncationStrategy
53+
values::KroneckerVector, strategy::KroneckerTruncationStrategy
4954
)
50-
return throw(ArgumentError("Can't truncate Eye ⊗ Eye."))
55+
I = findtruncated(Vector(values), strategy.strategy)
56+
return to_truncated_indices(values, I)
5157
end
5258

5359
for f in [:eig_trunc!, :eigh_trunc!]

src/kroneckerarray.jl

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -167,13 +167,22 @@ function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{Integer,N}) where {
167167
return a[I′...]
168168
end
169169

170+
# Indexing logic.
171+
function Base.to_indices(
172+
a::KroneckerArray, inds, I::Tuple{Union{CartesianPair,CartesianProduct},Vararg}
173+
)
174+
I1 = to_indices(arg1(a), arg1.(inds), arg1.(I))
175+
I2 = to_indices(arg2(a), arg2.(inds), arg2.(I))
176+
return I1 I2
177+
end
178+
170179
# Allow customizing for `FillArrays.Eye`.
171180
_getindex(a::AbstractArray, I...) = a[I...]
172-
function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianProduct,N}) where {N}
173-
return _getindex(arg1(a), arg1.(I)...) _getindex(arg2(a), arg2.(I)...)
174-
end
175-
function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianPair,N}) where {N}
176-
return _getindex(arg1(a), arg1.(I)...) _getindex(arg2(a), arg2.(I)...)
181+
function Base.getindex(
182+
a::KroneckerArray{<:Any,N}, I::Vararg{Union{CartesianPair,CartesianProduct},N}
183+
) where {N}
184+
I′ = to_indices(a, I)
185+
return _getindex(arg1(a), arg1.(I)...) _getindex(arg2(a), arg2.(I)...)
177186
end
178187
# Fix ambigiuity error.
179188
Base.getindex(a::KroneckerArray{<:Any,0}) = arg1(a)[] * arg2(a)[]

test/test_basics.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ elts = (Float32, Float64, ComplexF32, ComplexF64)
2626
@testset "KroneckerArrays (eltype=$elt)" for elt in elts
2727
p = [1, 2] × [3, 4, 5]
2828
@test length(p) == 6
29-
@test collect(p) == [1 × 3, 2 × 3, 1 × 4, 2 × 4, 1 × 5, 2 × 5]
29+
@test collect(p) == [1 × 3, 1 × 4, 1 × 5, 2 × 3, 2 × 4, 2 × 5]
3030

3131
r = @constinferred cartesianrange(2, 3)
3232
@test r ===
@@ -39,10 +39,10 @@ elts = (Float32, Float64, ComplexF32, ComplexF64)
3939
@test first(r) == 1
4040
@test last(r) == 6
4141
@test r[1 × 1] == 1
42-
@test r[2 × 1] == 2
43-
@test r[1 × 2] == 3
44-
@test r[2 × 2] == 4
45-
@test r[1 × 3] == 5
42+
@test r[1 × 2] == 2
43+
@test r[1 × 3] == 3
44+
@test r[2 × 1] == 4
45+
@test r[2 × 2] == 5
4646
@test r[2 × 3] == 6
4747

4848
r = @constinferred(cartesianrange(2 × 3, 2:7))
@@ -53,10 +53,10 @@ elts = (Float32, Float64, ComplexF32, ComplexF64)
5353
@test first(r) == 2
5454
@test last(r) == 7
5555
@test r[1 × 1] == 2
56-
@test r[2 × 1] == 3
57-
@test r[1 × 2] == 4
58-
@test r[2 × 2] == 5
59-
@test r[1 × 3] == 6
56+
@test r[1 × 2] == 3
57+
@test r[1 × 3] == 4
58+
@test r[2 × 1] == 5
59+
@test r[2 × 2] == 6
6060
@test r[2 × 3] == 7
6161

6262
# Test high-dimensional materialization.

test/test_blocksparsearrays.jl

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
using Adapt: adapt
22
using BlockArrays: Block, BlockRange, mortar
33
using BlockSparseArrays:
4-
BlockSparseArray, BlockSparseMatrix, blockrange, blocksparse, blocktype
4+
BlockIndexVector, BlockSparseArray, BlockSparseMatrix, blockrange, blocksparse, blocktype
55
using FillArrays: Eye, SquareEye
66
using JLArrays: JLArray
7-
using KroneckerArrays: KroneckerArray, , ×
7+
using KroneckerArrays: KroneckerArray, , ×, arg1, arg2
88
using LinearAlgebra: norm
99
using MatrixAlgebraKit: svd_compact
1010
using Test: @test, @test_broken, @testset
@@ -48,7 +48,18 @@ arrayts = (Array, JLArray)
4848
@test a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] ==
4949
a[Block(2, 2)][(2:3) × (2:3), (2:3) × (2:3)]
5050
@test a[Block(2, 2)[(:) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(:) × (2:3), (:) × (2:3)]
51-
@test_broken a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)]
51+
@test a[Block(2, 2)[(1:2) × (2:3), (:) × (2:3)]] ==
52+
a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)]
53+
54+
# Blockwise slicing, shows up in truncated block sparse matrix factorizations.
55+
I1 = BlockIndexVector(Block(1), Base.Slice(Base.OneTo(2)) × [1])
56+
I2 = BlockIndexVector(Block(2), Base.Slice(Base.OneTo(3)) × [1, 3])
57+
I = [I1, I2]
58+
b = a[I, I]
59+
@test b[Block(1, 1)] == a[Block(1, 1)[(1:2) × [1], (1:2) × [1]]]
60+
@test iszero(b[Block(2, 1)])
61+
@test iszero(b[Block(1, 2)])
62+
@test b[Block(2, 2)] == a[Block(2, 2)[(1:3) × [1, 3], (1:3) × [1, 3]]]
5263

5364
# Slicing
5465
r = blockrange([2 × 2, 3 × 3])
@@ -159,7 +170,22 @@ end
159170
@test a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] ==
160171
a[Block(2, 2)][(2:3) × (2:3), (2:3) × (2:3)]
161172
@test a[Block(2, 2)[(:) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(:) × (2:3), (:) × (2:3)]
162-
@test_broken a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)]
173+
@test a[Block(2, 2)[(1:2) × (2:3), (:) × (2:3)]] ==
174+
a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)]
175+
176+
# Blockwise slicing, shows up in truncated block sparse matrix factorizations.
177+
I1 = BlockIndexVector(Block(1), Base.Slice(Base.OneTo(2)) × [1])
178+
I2 = BlockIndexVector(Block(2), Base.Slice(Base.OneTo(3)) × [1, 3])
179+
I = [I1, I2]
180+
b = a[I, I]
181+
@test b[Block(1, 1)] == a[Block(1, 1)[(1:2) × [1], (1:2) × [1]]]
182+
@test arg1(b[Block(1, 1)]) isa Eye
183+
@test iszero(b[Block(2, 1)])
184+
@test arg1(b[Block(2, 1)]) isa Eye
185+
@test iszero(b[Block(1, 2)])
186+
@test arg1(b[Block(1, 2)]) isa Eye
187+
@test b[Block(2, 2)] == a[Block(2, 2)[(1:3) × [1, 3], (1:3) × [1, 3]]]
188+
@test arg1(b[Block(2, 2)]) isa Eye
163189

164190
# Slicing
165191
r = blockrange([2 × 2, 3 × 3])

0 commit comments

Comments
 (0)