Skip to content

Commit f471f48

Browse files
committed
Rewrite orthogonalization from recursive to iterative
1 parent ab8e4ea commit f471f48

File tree

2 files changed

+47
-15
lines changed

2 files changed

+47
-15
lines changed

src/states/orthoview.jl

Lines changed: 44 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,43 @@ struct CView{S,E,N} <: AbstractArray{E,N}
4747
end
4848

4949
function Base.getindex(v::CView{<:FiniteMPS,E}, i::Int)::E where {E}
50-
if ismissing(v.parent.Cs[i + 1])
51-
if i == 0 || !ismissing(v.parent.ALs[i])
52-
(v.parent.Cs[i + 1], temp) = rightorth(_transpose_tail(v.parent.AC[i + 1]);
53-
alg=LQpos())
54-
v.parent.ARs[i + 1] = _transpose_front(temp)
55-
else
56-
(v.parent.ALs[i], v.parent.Cs[i + 1]) = leftorth(v.parent.AC[i]; alg=QRpos())
50+
ismissing(v.parent.Cs[i + 1]) || return v.parent.Cs[i + 1]
51+
52+
if i == 0 || !ismissing(v.parent.ALs[i]) # center is too far right
53+
center = findfirst(!ismissing, v.parent.ACs)
54+
if isnothing(center)
55+
center = findfirst(!ismissing, v.parent.Cs)
56+
@assert !isnothing(center) "Invalid state"
57+
center -= 1 # offset in Cs vs C
58+
@assert !ismissing(v.parent.ALs[center]) "Invalid state"
59+
v.parent.ACs[center] = _mul_tail(v.parent.ALs[center], v.parent.Cs[center + 1])
60+
end
61+
62+
for j in Iterators.reverse((i + 1):center)
63+
v.parent.Cs[j], tmp = rightorth!(_transpose_tail(v.parent.ACs[j]); alg=LQpos())
64+
v.parent.ARs[j] = _transpose_front(tmp)
65+
if j != i + 1 # last AC not needed
66+
v.parent.ACs[j - 1] = _mul_tail(v.parent.ALs[j - 1], v.parent.Cs[j])
67+
end
68+
end
69+
else # center is too far left
70+
center = findlast(!ismissing, v.parent.ACs)
71+
if isnothing(center)
72+
center = findlast(!ismissing, v.parent.Cs)
73+
@assert !isnothing(center) "Invalid state"
74+
center -= 1 # offset in Cs vs C
75+
@assert !ismissing(v.parent.ARs[center]) "Invalid state"
76+
v.parent.ACs[center] = _mul_front(v.parent.Cs[center + 1], v.parent.ARs[center])
77+
end
78+
79+
for j in center:i
80+
v.parent.ALs[j], v.parent.Cs[j + 1] = leftorth(v.parent.ACs[j]; alg=QRpos())
81+
if j != i # last AC not needed
82+
v.parent.ACs[j + 1] = _mul_front(v.parent.Cs[j + 1], v.parent.ARs[j + 1])
83+
end
5784
end
5885
end
86+
5987
return v.parent.Cs[i + 1]
6088
end
6189

@@ -93,15 +121,16 @@ struct ACView{S,E,N} <: AbstractArray{E,N}
93121
end
94122

95123
function Base.getindex(v::ACView{<:FiniteMPS,E}, i::Int)::E where {E}
96-
if ismissing(v.parent.ACs[i]) && !ismissing(v.parent.ARs[i])
97-
c = v.parent.C[i - 1]
98-
ar = v.parent.ARs[i]
99-
v.parent.ACs[i] = _transpose_front(c * _transpose_tail(ar))
100-
elseif ismissing(v.parent.ACs[i]) && !ismissing(v.parent.ALs[i])
101-
c = v.parent.C[i]
102-
al = v.parent.ALs[i]
103-
v.parent.ACs[i] = al * c
124+
ismissing(v.parent.ACs[i]) || return v.parent.ACs[i]
125+
126+
if !ismissing(v.parent.ARs[i]) # center is too far left
127+
v.parent.ACs[i] = _mul_front(v.parent.C[i - 1], v.parent.ARs[i])
128+
elseif !ismissing(v.parent.ALs[i])
129+
v.parent.ACs[i] = _mul_tail(v.parent.ALs[i], v.parent.C[i])
130+
else
131+
error("Invalid state")
104132
end
133+
105134
return v.parent.ACs[i]
106135
end
107136

src/utility/utility.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ function _transpose_as(t1::AbstractTensorMap, t2::AbstractTensorMap)
88
return repartition(t1, numout(t2), numin(t2))
99
end
1010

11+
_mul_front(C::MPSBondTensor, A::GenericMPSTensor) = _transpose_front(C * _transpose_tail(A))
12+
_mul_tail(A::GenericMPSTensor, C::MPSBondTensor) = A * C
13+
1114
function _similar_tail(A::AbstractTensorMap)
1215
cod = _firstspace(A)
1316
dom = (dual(_lastspace(A)), dual.(space.(Ref(A), reverse(2:(numind(A) - 1))))...)

0 commit comments

Comments
 (0)