Skip to content

Commit 0512647

Browse files
committed
[WIP] Towards truncated block sparse factorizations
1 parent 6ff9f93 commit 0512647

File tree

6 files changed

+100
-34
lines changed

6 files changed

+100
-34
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "KroneckerArrays"
22
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.21"
4+
version = "0.1.22"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -23,7 +23,7 @@ KroneckerArraysBlockSparseArraysExt = ["BlockArrays", "BlockSparseArrays"]
2323
[compat]
2424
Adapt = "4.3.0"
2525
BlockArrays = "1.6"
26-
BlockSparseArrays = "0.7.21"
26+
BlockSparseArrays = "0.7.22"
2727
DerivableInterfaces = "0.5.0"
2828
DiagonalArrays = "0.3.5"
2929
FillArrays = "1.13.0"

ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,4 +99,10 @@ function (f::GetUnstoredBlock)(
9999
return error("Not implemented.")
100100
end
101101

102+
using BlockSparseArrays: BlockSparseArrays
103+
using KroneckerArrays: KroneckerArrays, KroneckerVector
104+
function BlockSparseArrays.to_truncated_indices(values::KroneckerVector, I)
105+
return KroneckerArrays.to_truncated_indices(values, I)
106+
end
107+
102108
end

src/cartesianproduct.jl

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,16 @@ arguments(a::CartesianProduct, n::Int) = arguments(a)[n]
2626
arg1(a::CartesianProduct) = a.a
2727
arg2(a::CartesianProduct) = a.b
2828

29+
Base.copy(a::CartesianProduct) = copy(arg1(a)) × copy(arg2(a))
30+
2931
function Base.show(io::IO, a::CartesianProduct)
3032
print(io, a.a, " × ", a.b)
3133
return nothing
3234
end
35+
function Base.show(io::IO, ::MIME"text/plain", a::CartesianProduct)
36+
show(io, a)
37+
return nothing
38+
end
3339

3440
×(a::AbstractVector, b::AbstractVector) = CartesianProduct(a, b)
3541
Base.length(a::CartesianProduct) = length(a.a) * length(a.b)
@@ -42,8 +48,38 @@ function Base.getindex(a::CartesianProduct, i::CartesianPair)
4248
return arg1(a)[arg1(i)] × arg2(a)[arg2(i)]
4349
end
4450
function Base.getindex(a::CartesianProduct, i::Int)
45-
I = Tuple(CartesianIndices((length(arg1(a)), length(arg2(a))))[i])
46-
return a[I[1] × I[2]]
51+
I = Tuple(CartesianIndices((length(arg2(a)), length(arg1(a))))[i])
52+
return a[I[2] × I[1]]
53+
end
54+
55+
struct CartesianProductVector{T,P<:CartesianProduct,V<:AbstractVector{T}} <:
56+
AbstractVector{T}
57+
product::P
58+
values::V
59+
end
60+
cartesianproduct(r::CartesianProductVector) = getfield(r, :product)
61+
unproduct(r::CartesianProductVector) = getfield(r, :values)
62+
Base.length(a::CartesianProductVector) = length(unproduct(a))
63+
Base.size(a::CartesianProductVector) = (length(a),)
64+
function Base.axes(r::CartesianProductVector)
65+
return (CartesianProductUnitRange(cartesianproduct(r), only(axes(unproduct(r)))),)
66+
end
67+
function Base.copy(a::CartesianProductVector)
68+
return CartesianProductVector(copy(cartesianproduct(a)), copy(unproduct(a)))
69+
end
70+
function Base.getindex(r::CartesianProductVector, i::Integer)
71+
return unproduct(r)[i]
72+
end
73+
74+
function Base.show(io::IO, a::CartesianProductVector)
75+
show(io, unproduct(a))
76+
return nothing
77+
end
78+
function Base.show(io::IO, mime::MIME"text/plain", a::CartesianProductVector)
79+
show(io, mime, cartesianproduct(a))
80+
println(io)
81+
show(io, mime, unproduct(a))
82+
return nothing
4783
end
4884

4985
struct CartesianProductUnitRange{T,P<:CartesianProduct,R<:AbstractUnitRange{T}} <:
@@ -60,13 +96,24 @@ unproduct(r::CartesianProductUnitRange) = getfield(r, :range)
6096
arg1(a::CartesianProductUnitRange) = arg1(cartesianproduct(a))
6197
arg2(a::CartesianProductUnitRange) = arg2(cartesianproduct(a))
6298

99+
function Base.show(io::IO, a::CartesianProductUnitRange)
100+
show(io, unproduct(a))
101+
return nothing
102+
end
103+
function Base.show(io::IO, mime::MIME"text/plain", a::CartesianProductUnitRange)
104+
show(io, mime, cartesianproduct(a))
105+
println(io)
106+
show(io, mime, unproduct(a))
107+
return nothing
108+
end
109+
63110
function CartesianProductUnitRange(p::CartesianProduct)
64111
return CartesianProductUnitRange(p, Base.OneTo(length(p)))
65112
end
66113
function CartesianProductUnitRange(a, b)
67114
return CartesianProductUnitRange(a × b)
68115
end
69-
to_product_indices(a::AbstractUnitRange) = a
116+
to_product_indices(a::AbstractVector) = a
70117
to_product_indices(i::Integer) = Base.OneTo(i)
71118
cartesianrange(a, b) = cartesianrange(to_product_indices(a) × to_product_indices(b))
72119
function cartesianrange(p::CartesianPair)
@@ -94,10 +141,16 @@ function Base.checkindex(::Type{Bool}, inds::CartesianProductUnitRange, i::Carte
94141
return checkindex(Bool, arg1(inds), arg1(i)) && checkindex(Bool, arg2(inds), arg2(i))
95142
end
96143

144+
function Base.getindex(a::CartesianProductUnitRange, I::CartesianProduct)
145+
prod = cartesianproduct(a)
146+
prod_I = arg1(prod)[arg1(I)] × arg2(prod)[arg2(I)]
147+
return CartesianProductVector(prod_I, map(Base.Fix1(getindex, a), I))
148+
end
149+
97150
# Reverse map from CartesianPair to linear index in the range.
98151
function Base.getindex(inds::CartesianProductUnitRange, i::CartesianPair)
99-
i′ = (findfirst(==(arg1(i)), arg1(inds)), findfirst(==(arg2(i)), arg2(inds)))
100-
return inds[LinearIndices((length(arg1(inds)), length(arg2(inds))))[i′...]]
152+
i′ = (findfirst(==(arg2(i)), arg2(inds)), findfirst(==(arg1(i)), arg1(inds)))
153+
return inds[LinearIndices((length(arg2(inds)), length(arg1(inds))))[i′...]]
101154
end
102155

103156
using Base.Broadcast: DefaultArrayStyle

src/fillarrays/kroneckerarray.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ const SquareEyeSquareEye{T,A<:SquareEye{T},B<:SquareEye{T}} = KroneckerMatrix{T,
2323

2424
_getindex(a::Eye, I1::Colon, I2::Colon) = a
2525
_view(a::Eye, I1::Colon, I2::Colon) = a
26+
_view(a::Eye, I1::Base.Slice, I2::Base.Slice) = a
2627

2728
# Like `adapt` but preserves `Eye`.
2829
_adapt(to, a::Eye) = a

src/fillarrays/matrixalgebrakit_truncate.jl

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,34 +20,40 @@ const OnesKroneckerVector{T,A<:OnesVector{T},B<:AbstractVector{T}} = KroneckerVe
2020
const KroneckerOnesVector{T,A<:AbstractVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B}
2121
const OnesVectorOnesVector{T,A<:OnesVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B}
2222

23-
function MatrixAlgebraKit.findtruncated(
24-
values::OnesKroneckerVector, strategy::KroneckerTruncationStrategy
25-
)
26-
I = findtruncated(Vector(values), strategy.strategy)
27-
prods = collect(cartesianproduct(only(axes(values))))[I]
28-
I_data = unique(map(arg1, prods))
23+
axis(a) = only(axes(a))
24+
25+
# Convert indices determined with a generic call to `findtruncated` to indices
26+
# more suited for a KroneckerVector.
27+
function to_truncated_indices(values::OnesKroneckerVector, I)
28+
prods = cartesianproduct(axis(values))[I]
29+
I_id = only(to_indices(arg1(values), (:,)))
30+
I_data = unique(arg2.(prods))
2931
# Drop truncations that occur within the identity.
3032
I_data = filter(I_data) do i
31-
return count(x -> arg1(x) == i, prods) == length(arg1(values))
33+
return count(x -> arg2(x) == i, prods) == length(arg2(values))
3234
end
33-
return (:) × I_data
35+
return I_id × I_data
3436
end
35-
function MatrixAlgebraKit.findtruncated(
36-
values::KroneckerOnesVector, strategy::KroneckerTruncationStrategy
37-
)
38-
I = findtruncated(Vector(values), strategy.strategy)
39-
prods = collect(cartesianproduct(only(axes(values))))[I]
40-
I_data = unique(map(x -> arg2(x), prods))
37+
function to_truncated_indices(values::KroneckerOnesVector, I)
38+
#I = findtruncated(Vector(values), strategy.strategy)
39+
prods = cartesianproduct(axis(values))[I]
40+
I_data = unique(arg1.(prods))
4141
# Drop truncations that occur within the identity.
4242
I_data = filter(I_data) do i
43-
return count(x -> arg2(x) == i, prods) == length(arg2(values))
43+
return count(x -> arg1(x) == i, prods) == length(arg2(values))
4444
end
45-
return I_data × (:)
45+
I_id = only(to_indices(arg2(values), (:,)))
46+
return I_data × I_id
47+
end
48+
function to_truncated_indices(values::OnesVectorOnesVector, I)
49+
return throw(ArgumentError("Can't truncate Eye ⊗ Eye."))
4650
end
51+
4752
function MatrixAlgebraKit.findtruncated(
48-
values::OnesVectorOnesVector, strategy::KroneckerTruncationStrategy
53+
values::KroneckerVector, strategy::KroneckerTruncationStrategy
4954
)
50-
return throw(ArgumentError("Can't truncate Eye ⊗ Eye."))
55+
I = findtruncated(Vector(values), strategy.strategy)
56+
return to_truncated_indices(values, I)
5157
end
5258

5359
for f in [:eig_trunc!, :eigh_trunc!]

test/test_basics.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ elts = (Float32, Float64, ComplexF32, ComplexF64)
2626
@testset "KroneckerArrays (eltype=$elt)" for elt in elts
2727
p = [1, 2] × [3, 4, 5]
2828
@test length(p) == 6
29-
@test collect(p) == [1 × 3, 2 × 3, 1 × 4, 2 × 4, 1 × 5, 2 × 5]
29+
@test collect(p) == [1 × 3, 1 × 4, 1 × 5, 2 × 3, 2 × 4, 2 × 5]
3030

3131
r = @constinferred cartesianrange(2, 3)
3232
@test r ===
@@ -39,10 +39,10 @@ elts = (Float32, Float64, ComplexF32, ComplexF64)
3939
@test first(r) == 1
4040
@test last(r) == 6
4141
@test r[1 × 1] == 1
42-
@test r[2 × 1] == 2
43-
@test r[1 × 2] == 3
44-
@test r[2 × 2] == 4
45-
@test r[1 × 3] == 5
42+
@test r[1 × 2] == 2
43+
@test r[1 × 3] == 3
44+
@test r[2 × 1] == 4
45+
@test r[2 × 2] == 5
4646
@test r[2 × 3] == 6
4747

4848
r = @constinferred(cartesianrange(2 × 3, 2:7))
@@ -53,10 +53,10 @@ elts = (Float32, Float64, ComplexF32, ComplexF64)
5353
@test first(r) == 2
5454
@test last(r) == 7
5555
@test r[1 × 1] == 2
56-
@test r[2 × 1] == 3
57-
@test r[1 × 2] == 4
58-
@test r[2 × 2] == 5
59-
@test r[1 × 3] == 6
56+
@test r[1 × 2] == 3
57+
@test r[1 × 3] == 4
58+
@test r[2 × 1] == 5
59+
@test r[2 × 2] == 6
6060
@test r[2 × 3] == 7
6161

6262
# Test high-dimensional materialization.

0 commit comments

Comments
 (0)