Skip to content

Commit 208fb07

Browse files
committed
Revive truncation
1 parent 7bf3e18 commit 208fb07

File tree

4 files changed

+103
-104
lines changed

4 files changed

+103
-104
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "KroneckerArrays"
22
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.2.0"
4+
version = "0.2.1"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/matrixalgebrakit.jl

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -209,36 +209,40 @@ struct KroneckerTruncationStrategy{T<:TruncationStrategy} <: TruncationStrategy
209209
strategy::T
210210
end
211211

212-
## using FillArrays: OnesVector
213-
## const OnesKroneckerVector{T,A<:OnesVector{T},B<:AbstractVector{T}} = KroneckerVector{T,A,B}
214-
## const KroneckerOnesVector{T,A<:AbstractVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B}
215-
## const OnesVectorOnesVector{T,A<:OnesVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B}
212+
using FillArrays: OnesVector
213+
const OnesKroneckerVector{T,A<:OnesVector{T},B<:AbstractVector{T}} = KroneckerVector{T,A,B}
214+
const KroneckerOnesVector{T,A<:AbstractVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B}
215+
const OnesVectorOnesVector{T,A<:OnesVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B}
216216

217217
axis(a) = only(axes(a))
218218

219-
## # Convert indices determined with a generic call to `findtruncated` to indices
220-
## # more suited for a KroneckerVector.
221-
## function to_truncated_indices(values::OnesKroneckerVector, I)
222-
## prods = cartesianproduct(axis(values))[I]
223-
## I_id = only(to_indices(arg1(values), (:,)))
224-
## I_data = unique(arg2.(prods))
225-
## # Drop truncations that occur within the identity.
226-
## I_data = filter(I_data) do i
227-
## return count(x -> arg2(x) == i, prods) == length(arg2(values))
228-
## end
229-
## return I_id × I_data
230-
## end
231-
## function to_truncated_indices(values::KroneckerOnesVector, I)
232-
## #I = findtruncated(Vector(values), strategy.strategy)
233-
## prods = cartesianproduct(axis(values))[I]
234-
## I_data = unique(arg1.(prods))
235-
## # Drop truncations that occur within the identity.
236-
## I_data = filter(I_data) do i
237-
## return count(x -> arg1(x) == i, prods) == length(arg2(values))
238-
## end
239-
## I_id = only(to_indices(arg2(values), (:,)))
240-
## return I_data × I_id
241-
## end
219+
# Convert indices determined with a generic call to `findtruncated` to indices
220+
# more suited for a KroneckerVector.
221+
function to_truncated_indices(values::OnesKroneckerVector, I)
222+
prods = cartesianproduct(axis(values))[I]
223+
I_id = only(to_indices(arg1(values), (:,)))
224+
I_data = unique(arg2.(prods))
225+
# Drop truncations that occur within the identity.
226+
I_data = filter(I_data) do i
227+
return count(x -> arg2(x) == i, prods) == length(arg2(values))
228+
end
229+
return I_id × I_data
230+
end
231+
function to_truncated_indices(values::KroneckerOnesVector, I)
232+
#I = findtruncated(Vector(values), strategy.strategy)
233+
prods = cartesianproduct(axis(values))[I]
234+
I_data = unique(arg1.(prods))
235+
# Drop truncations that occur within the identity.
236+
I_data = filter(I_data) do i
237+
return count(x -> arg1(x) == i, prods) == length(arg2(values))
238+
end
239+
I_id = only(to_indices(arg2(values), (:,)))
240+
return I_data × I_id
241+
end
242+
# Fix ambiguity error.
243+
function to_truncated_indices(values::OnesVectorOnesVector, I)
244+
return throw(ArgumentError("Not implemented"))
245+
end
242246
function to_truncated_indices(values::KroneckerVector, I)
243247
return throw(ArgumentError("Not implemented"))
244248
end

test/test_blocksparsearrays.jl

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -395,24 +395,22 @@ end
395395
@test b[Block(1)] == a[Block(1, 2)]
396396
@test b[Block(2)] == a[Block(2, 2)]
397397

398-
## TODO: Broken, fix and re-enable.
399-
@test_broken false
400-
## # svd_trunc
401-
## dev = adapt(arrayt)
402-
## r = @constinferred blockrange([2 × 2, 3 × 3])
403-
## rng = StableRNG(1234)
404-
## d = Dict(
405-
## Block(1, 1) => δ(elt, (2, 2)) ⊗ randn(rng, elt, 2, 2),
406-
## Block(2, 2) => δ(elt, (3, 3)) ⊗ randn(rng, elt, 3, 3),
407-
## )
408-
## a = @constinferred dev(blocksparse(d, (r, r)))
409-
## if arrayt === Array
410-
## u, s, v = svd_trunc(a; trunc=(; maxrank=6))
411-
## u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=5))
412-
## @test Matrix(u * s * v) ≈ u′ * s′ * v′
413-
## else
414-
## @test_broken svd_trunc(a; trunc=(; maxrank=6))
415-
## end
398+
# svd_trunc
399+
dev = adapt(arrayt)
400+
r = @constinferred blockrange([2 × 2, 3 × 3])
401+
rng = StableRNG(1234)
402+
d = Dict(
403+
Block(1, 1) => δ(elt, (2, 2)) randn(rng, elt, 2, 2),
404+
Block(2, 2) => δ(elt, (3, 3)) randn(rng, elt, 3, 3),
405+
)
406+
a = @constinferred dev(blocksparse(d, (r, r)))
407+
if arrayt === Array
408+
u, s, v = svd_trunc(a; trunc=(; maxrank=6))
409+
u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=5))
410+
@test Matrix(u * s * v) u′ * s′ * v′
411+
else
412+
@test_broken svd_trunc(a; trunc=(; maxrank=6))
413+
end
416414

417415
@testset "Block deficient" begin
418416
da = Dict(Block(1, 1) => δ(elt, (2, 2)) dev(randn(elt, 2, 2)))

test/test_matrixalgebrakit_delta.jl

Lines changed: 55 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -73,27 +73,26 @@ herm(a) = parent(hermitianpart(a))
7373
@test arguments(v, 2) isa DeltaMatrix{elt}
7474
end
7575

76-
## TODO: Broken, need to fix truncation.
77-
## for f in (eig_trunc, eigh_trunc)
78-
## a = δ(3, 3) ⊗ parent(hermitianpart(randn(3, 3)))
79-
## d, v = f(a; trunc=(; maxrank=7))
80-
## @test a * v ≈ v * d
81-
## @test arguments(d, 1) isa DeltaMatrix
82-
## @test arguments(v, 1) isa DeltaMatrix
83-
## @test size(d) == (6, 6)
84-
## @test size(v) == (9, 6)
76+
for f in (eig_trunc, eigh_trunc)
77+
a = δ(3, 3) parent(hermitianpart(randn(3, 3)))
78+
d, v = f(a; trunc=(; maxrank=7))
79+
@test a * v v * d
80+
@test arguments(d, 1) isa DeltaMatrix
81+
@test arguments(v, 1) isa DeltaMatrix
82+
@test size(d) == (6, 6)
83+
@test size(v) == (9, 6)
8584

86-
## a = parent(hermitianpart(randn(3, 3))) ⊗ δ(3, 3)
87-
## d, v = f(a; trunc=(; maxrank=7))
88-
## @test a * v ≈ v * d
89-
## @test arguments(d, 2) isa DeltaMatrix
90-
## @test arguments(v, 2) isa DeltaMatrix
91-
## @test size(d) == (6, 6)
92-
## @test size(v) == (9, 6)
85+
a = parent(hermitianpart(randn(3, 3))) δ(3, 3)
86+
d, v = f(a; trunc=(; maxrank=7))
87+
@test a * v v * d
88+
@test arguments(d, 2) isa DeltaMatrix
89+
@test arguments(v, 2) isa DeltaMatrix
90+
@test size(d) == (6, 6)
91+
@test size(v) == (9, 6)
9392

94-
## a = δ(3, 3) ⊗ δ(3, 3)
95-
## @test_throws ArgumentError f(a)
96-
## end
93+
a = δ(3, 3) δ(3, 3)
94+
@test_throws ArgumentError f(a)
95+
end
9796

9897
for f in (eig_vals, eigh_vals)
9998
a = δ(3, 3) parent(hermitianpart(randn(3, 3)))
@@ -183,46 +182,44 @@ herm(a) = parent(hermitianpart(a))
183182
end
184183
end
185184

186-
## TODO: Need to implement truncation.
187-
## # svd_trunc
188-
## for elt in (Float32, ComplexF32)
189-
## a = δ(elt, 3, 3) ⊗ randn(elt, 3, 3)
190-
## # TODO: Type inference is broken for `svd_trunc`,
191-
## # look into fixing it.
192-
## # u, s, v = @constinferred svd_trunc(a; trunc=(; maxrank=7))
193-
## u, s, v = svd_trunc(a; trunc=(; maxrank=7))
194-
## @test eltype(u) === elt
195-
## @test eltype(s) === real(elt)
196-
## @test eltype(v) === elt
197-
## u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=6))
198-
## @test Matrix(u * s * v) ≈ u′ * s′ * v′
199-
## @test arguments(u, 1) isa DeltaMatrix{elt}
200-
## @test arguments(s, 1) isa DeltaMatrix{real(elt)}
201-
## @test arguments(v, 1) isa DeltaMatrix{elt}
202-
## @test size(u) == (9, 6)
203-
## @test size(s) == (6, 6)
204-
## @test size(v) == (6, 9)
205-
## end
185+
# svd_trunc
186+
for elt in (Float32, ComplexF32)
187+
a = δ(elt, 3, 3) randn(elt, 3, 3)
188+
# TODO: Type inference is broken for `svd_trunc`,
189+
# look into fixing it.
190+
# u, s, v = @constinferred svd_trunc(a; trunc=(; maxrank=7))
191+
u, s, v = svd_trunc(a; trunc=(; maxrank=7))
192+
@test eltype(u) === elt
193+
@test eltype(s) === real(elt)
194+
@test eltype(v) === elt
195+
u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=6))
196+
@test Matrix(u * s * v) u′ * s′ * v′
197+
@test arguments(u, 1) isa DeltaMatrix{elt}
198+
@test arguments(s, 1) isa DeltaMatrix{real(elt)}
199+
@test arguments(v, 1) isa DeltaMatrix{elt}
200+
@test size(u) == (9, 6)
201+
@test size(s) == (6, 6)
202+
@test size(v) == (6, 9)
203+
end
206204

207-
## TODO: Need to implement truncation.
208-
## for elt in (Float32, ComplexF32)
209-
## a = randn(elt, 3, 3) ⊗ δ(elt, 3, 3)
210-
## # TODO: Type inference is broken for `svd_trunc`,
211-
## # look into fixing it.
212-
## # u, s, v = @constinferred svd_trunc(a; trunc=(; maxrank=7))
213-
## u, s, v = svd_trunc(a; trunc=(; maxrank=7))
214-
## @test eltype(u) === elt
215-
## @test eltype(s) === real(elt)
216-
## @test eltype(v) === elt
217-
## u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=6))
218-
## @test Matrix(u * s * v) ≈ u′ * s′ * v′
219-
## @test arguments(u, 2) isa DeltaMatrix{elt}
220-
## @test arguments(s, 2) isa DeltaMatrix{real(elt)}
221-
## @test arguments(v, 2) isa DeltaMatrix{elt}
222-
## @test size(u) == (9, 6)
223-
## @test size(s) == (6, 6)
224-
## @test size(v) == (6, 9)
225-
## end
205+
for elt in (Float32, ComplexF32)
206+
a = randn(elt, 3, 3) δ(elt, 3, 3)
207+
# TODO: Type inference is broken for `svd_trunc`,
208+
# look into fixing it.
209+
# u, s, v = @constinferred svd_trunc(a; trunc=(; maxrank=7))
210+
u, s, v = svd_trunc(a; trunc=(; maxrank=7))
211+
@test eltype(u) === elt
212+
@test eltype(s) === real(elt)
213+
@test eltype(v) === elt
214+
u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=6))
215+
@test Matrix(u * s * v) u′ * s′ * v′
216+
@test arguments(u, 2) isa DeltaMatrix{elt}
217+
@test arguments(s, 2) isa DeltaMatrix{real(elt)}
218+
@test arguments(v, 2) isa DeltaMatrix{elt}
219+
@test size(u) == (9, 6)
220+
@test size(s) == (6, 6)
221+
@test size(v) == (6, 9)
222+
end
226223

227224
a = δ(3, 3) δ(3, 3)
228225
@test_broken svd_trunc(a)

0 commit comments

Comments
 (0)