Skip to content

Commit fed4cf1

Browse files
committed
update states
rework gauging of InfiniteMPS
1 parent 79618ac commit fed4cf1

File tree

6 files changed

+88
-51
lines changed

6 files changed

+88
-51
lines changed

src/states/abstractmps.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,20 +109,20 @@ function isfullrank(A::GenericMPSTensor; side = :both)
109109
end
110110

111111
"""
112-
makefullrank!(A::PeriodicVector{<:GenericMPSTensor}; alg=QRpos())
112+
makefullrank!(A::PeriodicVector{<:GenericMPSTensor}; alg=Defalts.alg_qr())
113113
114114
Make the set of MPS tensors full rank by performing a series of orthogonalizations.
115115
"""
116-
function makefullrank!(A::PeriodicVector{<:GenericMPSTensor}; alg = QRpos())
116+
function makefullrank!(A::PeriodicVector{<:GenericMPSTensor}; alg = Defaults.alg_qr())
117117
while true
118118
i = findfirst(!isfullrank, A)
119119
isnothing(i) && break
120120
if !isfullrank(A[i]; side = :left)
121-
L, Q = rightorth!(_transpose_tail(A[i]); alg = alg')
121+
L, Q = _right_orth!(_transpose_tail(A[i]); alg)
122122
A[i] = _transpose_front(Q)
123123
A[i - 1] = A[i - 1] * L
124124
else
125-
A[i], R = leftorth!(A[i]; alg)
125+
A[i], R = _left_orth!(A[i]; alg)
126126
A[i + 1] = _transpose_front(R * _transpose_tail(A[i + 1]))
127127
end
128128
end

src/states/finitemps.jl

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -211,15 +211,14 @@ function FiniteMPS(As::Vector{<:GenericMPSTensor}; normalize = false, overwrite
211211
# vectors anyways, maybe deprecate `overwrite`.
212212
As = overwrite ? As : copy(As)
213213
N = length(As)
214-
for i in 1:(N - 1)
215-
As[i], C = leftorth(As[i]; alg = QRpos())
214+
As[1] = MatrixAlgebraKit.copy_input(qr_compact, As[1])
215+
local C
216+
for i in eachindex(As)
217+
As[i], C = qr_compact!(As[i]; positive = true)
216218
normalize && normalize!(C)
217-
As[i + 1] = _transpose_front(C * _transpose_tail(As[i + 1]))
219+
i == N || (As[i + 1] = _transpose_front(C * _transpose_tail(As[i + 1])))
218220
end
219221

220-
As[end], C = leftorth(As[end]; alg = QRpos())
221-
normalize && normalize!(C)
222-
223222
A = eltype(As)
224223
B = typeof(C)
225224

@@ -533,11 +532,11 @@ function Base.:+(ψ₁::MPS, ψ₂::MPS) where {MPS <: FiniteMPS}
533532
F₁ = isometry(
534533
storagetype(ψ), (_lastspace(ψ₁.AL[1]) _lastspace(ψ₂.AL[1]))', _lastspace(ψ₁.AL[1])'
535534
)
536-
F₂ = leftnull(F₁)
535+
F₂ = left_null(F₁)
537536
@assert _lastspace(F₂) == _lastspace(ψ₂.AL[1])
538537

539538
AL = ψ₁.AL[1] * F₁' + ψ₂.AL[1] * F₂'
540-
ψ.ALs[1], R = leftorth!(AL)
539+
ψ.ALs[1], R = left_orth!(AL)
541540

542541
for i in 2:halfN
543542
AL₁ = _transpose_front(F₁ * _transpose_tail(ψ₁.AL[i]))
@@ -546,11 +545,11 @@ function Base.:+(ψ₁::MPS, ψ₂::MPS) where {MPS <: FiniteMPS}
546545
F₁ = isometry(
547546
storagetype(ψ), (_lastspace(AL₁) _lastspace(ψ₂.AL[i]))', _lastspace(AL₁)'
548547
)
549-
F₂ = leftnull(F₁)
548+
F₂ = left_null(F₁)
550549
@assert _lastspace(F₂) == _lastspace(ψ₂.AL[i])
551550

552551
AL = _transpose_front(R * _transpose_tail(AL₁ * F₁' + AL₂ * F₂'))
553-
ψ.ALs[i], R = leftorth!(AL)
552+
ψ.ALs[i], R = left_orth!(AL)
554553
end
555554

556555
C₁ = F₁ * ψ₁.C[halfN]
@@ -560,11 +559,11 @@ function Base.:+(ψ₁::MPS, ψ₂::MPS) where {MPS <: FiniteMPS}
560559
F₁ = isometry(
561560
storagetype(ψ), _firstspace(ψ₁.AR[end]) _firstspace(ψ₂.AR[end]), _firstspace(ψ₁.AR[end])
562561
)
563-
F₂ = leftnull(F₁)
562+
F₂ = left_null(F₁)
564563
@assert _lastspace(F₂) == _firstspace(ψ₂.AR[end])'
565564

566565
AR = F₁ * _transpose_tail(ψ₁.AR[end]) + F₂ * _transpose_tail(ψ₂.AR[end])
567-
L, AR′ = rightorth!(AR)
566+
L, AR′ = right_orth!(AR)
568567
ψ.ARs[end] = _transpose_front(AR′)
569568

570569
for i in Iterators.reverse((halfN + 1):(length(ψ) - 1))
@@ -574,11 +573,11 @@ function Base.:+(ψ₁::MPS, ψ₂::MPS) where {MPS <: FiniteMPS}
574573
F₁ = isometry(
575574
storagetype(ψ), _firstspace(ψ₁.AR[i]) _firstspace(AR₂), _firstspace(ψ₁.AR[i])
576575
)
577-
F₂ = leftnull(F₁)
576+
F₂ = left_null(F₁)
578577
@assert _lastspace(F₂) == _firstspace(AR₂)'
579578

580579
AR = _transpose_tail(_transpose_front(F₁ * AR₁ + F₂ * AR₂) * L)
581-
L, AR′ = rightorth!(AR)
580+
L, AR′ = right_orth!(AR)
582581
ψ.ARs[i] = _transpose_front(AR′)
583582
end
584583

src/states/ortho.jl

Lines changed: 38 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ $(TYPEDFIELDS)
2121
verbosity::Int = VERBOSE_WARN
2222

2323
"algorithm used for orthogonalization of the tensors"
24-
alg_orth = QRpos()
24+
alg_orth = LAPACK_HouseholderQR(; positive = true)
2525
"algorithm used for the eigensolver"
2626
alg_eigsolve = _GAUGE_ALG_EIGSOLVE
2727
"minimal amount of iterations before using the eigensolver steps"
@@ -46,7 +46,7 @@ $(TYPEDFIELDS)
4646
verbosity::Int = VERBOSE_WARN
4747

4848
"algorithm used for orthogonalization of the tensors"
49-
alg_orth = LQpos()
49+
alg_orth = LAPACK_HouseholderLQ(; positive = true)
5050
"algorithm used for the eigensolver"
5151
alg_eigsolve = _GAUGE_ALG_EIGSOLVE
5252
"minimal amount of iterations before using the eigensolver steps"
@@ -73,18 +73,18 @@ end
7373

7474
function MixedCanonical(;
7575
tol::Real = Defaults.tolgauge, maxiter::Int = Defaults.maxiter,
76-
verbosity::Int = VERBOSE_WARN, alg_orth = QRpos(),
76+
verbosity::Int = VERBOSE_WARN, alg_orth = LAPACK_HouseholderQR(; positive = true),
7777
alg_eigsolve = _GAUGE_ALG_EIGSOLVE,
7878
eig_miniter::Int = 10, order::Symbol = :LR
7979
)
80-
if alg_orth isa QR || alg_orth isa QRpos
80+
if alg_orth isa LAPACK_HouseholderQR
8181
alg_leftorth = alg_orth
82-
alg_rightorth = alg_orth'
83-
elseif alg_orth isa LQ || alg_orth isa LQpos
84-
alg_leftorth = alg_orth'
82+
alg_rightorth = LAPACK_HouseholderLQ(; alg_orth.kwargs...)
83+
elseif alg_orth isa LAPACK_HouseholderLQ
84+
alg_leftorth = LAPACK_HouseholderQR(; alg_orth.kwargs...)
8585
alg_rightorth = alg_orth
8686
else
87-
throw(ArgumentError("Invalid orthogonalization algorithm: $(typeof(alg_orth))"))
87+
alg_leftorth = alg_rightorth = alg_orth
8888
end
8989

9090
left = LeftCanonical(;
@@ -145,45 +145,54 @@ function gaugefix!(ψ::InfiniteMPS, A, C₀, alg::RightCanonical)
145145
end
146146

147147
@doc """
148-
regauge!(AC::GenericMPSTensor, C::MPSBondTensor; alg=QRpos()) -> AL
149-
regauge!(CL::MPSBondTensor, AC::GenericMPSTensor; alg=LQpos()) -> AR
148+
regauge!(AC::GenericMPSTensor, C::MPSBondTensor; alg) -> AL
149+
regauge!(CL::MPSBondTensor, AC::GenericMPSTensor; alg) -> AR
150150
151151
Bring updated `AC` and `C` tensors back into a consistent set of left or right canonical
152-
tensors. This minimizes `∥AC_i - AL_i * C_i∥` or `∥AC_i - C_{i-1} * AR_i∥`. The optimal algorithm uses
153-
`Polar()` decompositions, but `QR`-based algorithms are typically more performant.
152+
tensors. This minimizes `∥AC_i - AL_i * C_i∥` or `∥AC_i - C_{i-1} * AR_i∥`.
153+
154+
The `alg` is passed on to [`left_orth!`](@extref MatrixAlgebraKit) and
155+
[`right_orth!`](@extref MatrixAlgebraKit), and can be used to control the kind of
156+
factorization used. By default, this is set to a (positive) QR/LQ, even though the
157+
optimal algorithm would use a polar decompositions instead, sacrificing a bit of
158+
performance for accuracy.
154159
155160
!!! note
156161
Computing `AL` is slightly faster than `AR`, as it avoids an intermediate transposition.
157162
"""
158163
regauge!
159164

160-
function regauge!(AC::GenericMPSTensor, C::MPSBondTensor; alg = QRpos())
161-
Q_AC, _ = leftorth!(AC; alg)
162-
Q_C, _ = leftorth!(C; alg)
165+
function regauge!(
166+
AC::GenericMPSTensor, C::MPSBondTensor; alg = Defaults.alg_qr()
167+
)
168+
Q_AC, _ = _left_orth!(AC; alg)
169+
Q_C, _ = _left_orth!(C; alg)
163170
return mul!(AC, Q_AC, Q_C')
164171
end
165-
function regauge!(AC::Vector{<:GenericMPSTensor}, C::Vector{<:MPSBondTensor}; alg = QRpos())
172+
function regauge!(AC::Vector{<:GenericMPSTensor}, C::Vector{<:MPSBondTensor}; kwargs...)
166173
for i in 1:length(AC)
167-
regauge!(AC[i], C[i]; alg)
174+
regauge!(AC[i], C[i]; kwargs...)
168175
end
169176
return AC
170177
end
171-
function regauge!(CL::MPSBondTensor, AC::GenericMPSTensor; alg = LQpos())
178+
function regauge!(
179+
CL::MPSBondTensor, AC::GenericMPSTensor; alg = Defaults.alg_lq()
180+
)
172181
AC_tail = _transpose_tail(AC)
173-
_, Q_AC = rightorth!(AC_tail; alg)
174-
_, Q_C = rightorth!(CL; alg)
182+
_, Q_AC = _right_orth!(AC_tail; alg)
183+
_, Q_C = _right_orth!(CL; alg)
175184
AR_tail = mul!(AC_tail, Q_C', Q_AC)
176185
return repartition!(AC, AR_tail)
177186
end
178-
function regauge!(CL::Vector{<:MPSBondTensor}, AC::Vector{<:GenericMPSTensor}; alg = LQpos())
187+
function regauge!(CL::Vector{<:MPSBondTensor}, AC::Vector{<:GenericMPSTensor}; kwargs...)
179188
for i in length(CL):-1:1
180-
regauge!(CL[i], AC[i]; alg)
189+
regauge!(CL[i], AC[i]; kwargs...)
181190
end
182191
return CL
183192
end
184193
# fix ambiguity + error
185-
regauge!(::MPSBondTensor, ::MPSBondTensor; alg = QRpos()) = error("method ambiguity")
186-
function regauge!(::Vector{<:MPSBondTensor}, ::Vector{<:MPSBondTensor}; alg = QRpos())
194+
regauge!(::MPSBondTensor, ::MPSBondTensor; kwargs...) = error("method ambiguity")
195+
function regauge!(::Vector{<:MPSBondTensor}, ::Vector{<:MPSBondTensor}; kwargs...)
187196
return error("method ambiguity")
188197
end
189198

@@ -232,17 +241,18 @@ function gauge_eigsolve_step!(it::IterativeSolver{LeftCanonical}, state)
232241
if iter it.eig_miniter
233242
alg_eigsolve = updatetol(it.alg_eigsolve, 1, ϵ^2)
234243
_, vec = fixedpoint(flip(TransferMatrix(A, AL)), C[end], :LM, alg_eigsolve)
235-
_, C[end] = leftorth!(vec; alg = it.alg_orth)
244+
_, C[end] = _left_orth!(vec; alg = it.alg_orth)
236245
end
237246
return C[end]
238247
end
239248

240249
function gauge_orth_step!(it::IterativeSolver{LeftCanonical}, state)
241250
(; AL, C, A_tail, CA_tail) = state
242251
for i in 1:length(AL)
252+
# repartition!(A_tail[i], AL[i])
243253
mul!(CA_tail[i], C[i - 1], A_tail[i])
244254
repartition!(AL[i], CA_tail[i])
245-
AL[i], C[i] = leftorth!(AL[i]; alg = it.alg_orth)
255+
AL[i], C[i] = _left_orth!(AL[i]; alg = it.alg_orth)
246256
end
247257
normalize!(C[end])
248258
return C[end]
@@ -289,7 +299,7 @@ function gauge_eigsolve_step!(it::IterativeSolver{RightCanonical}, state)
289299
if iter it.eig_miniter
290300
alg_eigsolve = updatetol(it.alg_eigsolve, 1, ϵ^2)
291301
_, vec = fixedpoint(TransferMatrix(A, AR), C[end], :LM, alg_eigsolve)
292-
C[end], _ = rightorth!(vec; alg = it.alg_orth)
302+
C[end], _ = _right_orth!(vec; alg = it.alg_orth)
293303
end
294304
return C[end]
295305
end
@@ -299,7 +309,7 @@ function gauge_orth_step!(it::IterativeSolver{RightCanonical}, state)
299309
for i in length(AR):-1:1
300310
AC = mul!(AR[i], A[i], C[i]) # use AR as temporary storage for A * C
301311
tmp = repartition!(AC_tail[i], AC)
302-
C[i - 1], tmp = rightorth!(tmp; alg = it.alg_orth)
312+
C[i - 1], tmp = _right_orth!(tmp; alg = it.alg_orth)
303313
repartition!(AR[i], tmp) # TODO: avoid doing this every iteration
304314
end
305315
normalize!(C[end])

src/states/orthoview.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ function Base.getindex(v::CView{<:FiniteMPS, E}, i::Int)::E where {E}
6060
end
6161

6262
for j in Iterators.reverse((i + 1):center)
63-
v.parent.Cs[j], tmp = rightorth(_transpose_tail(v.parent.ACs[j]); alg = LQpos())
63+
v.parent.Cs[j], tmp = lq_compact!(_transpose_tail(v.parent.ACs[j]); positive = true)
6464
v.parent.ARs[j] = _transpose_front(tmp)
6565
if j != i + 1 # last AC not needed
6666
v.parent.ACs[j - 1] = _mul_tail(v.parent.ALs[j - 1], v.parent.Cs[j])
@@ -76,7 +76,7 @@ function Base.getindex(v::CView{<:FiniteMPS, E}, i::Int)::E where {E}
7676
end
7777

7878
for j in center:i
79-
v.parent.ALs[j], v.parent.Cs[j + 1] = leftorth(v.parent.ACs[j]; alg = QRpos())
79+
v.parent.ALs[j], v.parent.Cs[j + 1] = qr_compact(v.parent.ACs[j]; positive = true)
8080
if j != i # last AC not needed
8181
v.parent.ACs[j + 1] = _mul_front(v.parent.Cs[j + 1], v.parent.ARs[j + 1])
8282
end
@@ -89,10 +89,10 @@ end
8989
function Base.setindex!(v::CView{<:FiniteMPS}, vec, i::Int)
9090
if ismissing(v.parent.Cs[i + 1])
9191
if !ismissing(v.parent.ALs[i])
92-
v.parent.Cs[i + 1], temp = rightorth(_transpose_tail(v.parent.AC[i + 1]); alg = LQpos())
92+
v.parent.Cs[i + 1], temp = lq_compact!(_transpose_tail(v.parent.AC[i + 1]); positive = true)
9393
v.parent.ARs[i + 1] = _transpose_front(temp)
9494
else
95-
v.parent.ALs[i], v.parent.Cs[i + 1] = leftorth(v.parent.AC[i]; alg = QRpos())
95+
v.parent.ALs[i], v.parent.Cs[i + 1] = qr_compact(v.parent.AC[i]; positive = true)
9696
end
9797
end
9898

src/utility/utility.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,3 +149,31 @@ function check_unambiguous_braiding(V::VectorSpace)
149149
return check_unambiguous_braiding(Bool, V) ||
150150
throw(ArgumentError("cannot unambiguously braid $V"))
151151
end
152+
153+
# temporary workaround for the fact that left_orth and right_orth are poorly designed:
154+
function _left_orth!(t; alg::MatrixAlgebraKit.AbstractAlgorithm)
155+
if alg isa LAPACK_HouseholderQR
156+
return left_orth!(t; kind = :qr, alg_qr = alg)
157+
elseif alg isa LAPACK_HouseholderLQ
158+
return left_orth!(t; kind = :qr, alg_qr = LAPACK_HouseholderQR(; alg.kwargs...))
159+
elseif alg isa PolarViaSVD
160+
return left_orth!(t; kind = :polar, alg_polar = alg)
161+
elseif alg isa LAPACK_SVDAlgorithm
162+
return left_orth!(t; kind = :svd, alg_svd = alg)
163+
else
164+
error(lazy"unkown algorithm $alg")
165+
end
166+
end
167+
function _right_orth!(t; alg::MatrixAlgebraKit.AbstractAlgorithm)
168+
if alg isa LAPACK_HouseholderLQ
169+
return right_orth!(t; kind = :lq, alg_lq = alg)
170+
elseif alg isa LAPACK_HouseholderQr
171+
return right_orth!(t; kind = :lq, alg_lq = LAPACK_HouseholderLQ(; alg.kwargs...))
172+
elseif alg isa PolarViaSVD
173+
return right_orth!(t; kind = :polar, alg_polar = alg)
174+
elseif alg isa LAPACK_SVDAlgorithm
175+
return right_orth!(t; kind = :svd, alg_svd = alg)
176+
else
177+
error(lazy"unkown algorithm $alg")
178+
end
179+
end

test/states.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ module TestStates
207207

208208
@test real(e2) real(e1)
209209

210-
window, envs = timestep(window, ham, 0.1, 0.0, TDVP2(; trscheme = truncdim(20)), envs)
210+
window, envs = timestep(window, ham, 0.1, 0.0, TDVP2(; trscheme = truncrank(20)), envs)
211211
window, envs = timestep(window, ham, 0.1, 0.0, TDVP(), envs)
212212

213213
e3 = expectation_value(window, (2, 3) => O)

0 commit comments

Comments
 (0)