Skip to content

Commit 1fdc61f

Browse files
authored
Rewrite orthogonalization from recursive to iterative (#241)
* Rewrite orthogonalization from recursive to iterative * Add testcase
1 parent 0900ee7 commit 1fdc61f

File tree

3 files changed

+53
-15
lines changed

3 files changed

+53
-15
lines changed

src/states/orthoview.jl

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,42 @@ struct CView{S,E,N} <: AbstractArray{E,N}
4747
end
4848

4949
function Base.getindex(v::CView{<:FiniteMPS,E}, i::Int)::E where {E}
50-
if ismissing(v.parent.Cs[i + 1])
51-
if i == 0 || !ismissing(v.parent.ALs[i])
52-
(v.parent.Cs[i + 1], temp) = rightorth(_transpose_tail(v.parent.AC[i + 1]);
53-
alg=LQpos())
54-
v.parent.ARs[i + 1] = _transpose_front(temp)
55-
else
56-
(v.parent.ALs[i], v.parent.Cs[i + 1]) = leftorth(v.parent.AC[i]; alg=QRpos())
50+
ismissing(v.parent.Cs[i + 1]) || return v.parent.Cs[i + 1]
51+
52+
if i == 0 || !ismissing(v.parent.ALs[i]) # center is too far right
53+
center = findfirst(!ismissing, v.parent.ACs)
54+
if isnothing(center)
55+
center = findfirst(!ismissing, v.parent.Cs)
56+
@assert !isnothing(center) "Invalid state"
57+
center -= 1 # offset in Cs vs C
58+
@assert !ismissing(v.parent.ALs[center]) "Invalid state"
59+
v.parent.ACs[center] = _mul_tail(v.parent.ALs[center], v.parent.Cs[center + 1])
60+
end
61+
62+
for j in Iterators.reverse((i + 1):center)
63+
v.parent.Cs[j], tmp = rightorth(_transpose_tail(v.parent.ACs[j]); alg=LQpos())
64+
v.parent.ARs[j] = _transpose_front(tmp)
65+
if j != i + 1 # last AC not needed
66+
v.parent.ACs[j - 1] = _mul_tail(v.parent.ALs[j - 1], v.parent.Cs[j])
67+
end
68+
end
69+
else # center is too far left
70+
center = findlast(!ismissing, v.parent.ACs)
71+
if isnothing(center)
72+
center = findlast(!ismissing, v.parent.Cs)
73+
@assert !isnothing(center) "Invalid state"
74+
@assert !ismissing(v.parent.ARs[center]) "Invalid state"
75+
v.parent.ACs[center] = _mul_front(v.parent.Cs[center], v.parent.ARs[center])
76+
end
77+
78+
for j in center:i
79+
v.parent.ALs[j], v.parent.Cs[j + 1] = leftorth(v.parent.ACs[j]; alg=QRpos())
80+
if j != i # last AC not needed
81+
v.parent.ACs[j + 1] = _mul_front(v.parent.Cs[j + 1], v.parent.ARs[j + 1])
82+
end
5783
end
5884
end
85+
5986
return v.parent.Cs[i + 1]
6087
end
6188

@@ -93,15 +120,16 @@ struct ACView{S,E,N} <: AbstractArray{E,N}
93120
end
94121

95122
function Base.getindex(v::ACView{<:FiniteMPS,E}, i::Int)::E where {E}
96-
if ismissing(v.parent.ACs[i]) && !ismissing(v.parent.ARs[i])
97-
c = v.parent.C[i - 1]
98-
ar = v.parent.ARs[i]
99-
v.parent.ACs[i] = _transpose_front(c * _transpose_tail(ar))
100-
elseif ismissing(v.parent.ACs[i]) && !ismissing(v.parent.ALs[i])
101-
c = v.parent.C[i]
102-
al = v.parent.ALs[i]
103-
v.parent.ACs[i] = al * c
123+
ismissing(v.parent.ACs[i]) || return v.parent.ACs[i]
124+
125+
if !ismissing(v.parent.ARs[i]) # center is too far left
126+
v.parent.ACs[i] = _mul_front(v.parent.C[i - 1], v.parent.ARs[i])
127+
elseif !ismissing(v.parent.ALs[i])
128+
v.parent.ACs[i] = _mul_tail(v.parent.ALs[i], v.parent.C[i])
129+
else
130+
error("Invalid state")
104131
end
132+
105133
return v.parent.ACs[i]
106134
end
107135

src/utility/utility.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ function _transpose_as(t1::AbstractTensorMap, t2::AbstractTensorMap)
88
return repartition(t1, numout(t2), numin(t2))
99
end
1010

11+
_mul_front(C, A) = _transpose_front(C * _transpose_tail(A))
12+
_mul_tail(A, C) = A * C
13+
1114
function _similar_tail(A::AbstractTensorMap)
1215
cod = _firstspace(A)
1316
dom = (dual(_lastspace(A)), dual.(space.(Ref(A), reverse(2:(numind(A) - 1))))...)

test/other.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,13 @@ end
5454
@test ψ2 isa InfiniteMPS
5555
@test norm(ψ2) 1
5656
end
57+
58+
@testset "Stackoverflow with gauging" begin
59+
ψ = FiniteMPS(10_000, ℂ^2, ℂ^1)
60+
@test ψ.AR[1] isa MPSKit.MPSTensor
61+
ψ.AC[1] = -ψ.AR[1] # force invalidation of ALs
62+
@test ψ.AL[end] isa MPSKit.MPSTensor
63+
end
5764
end
5865

5966
end

0 commit comments

Comments
 (0)