Skip to content

Commit 941b747

Browse files
authored
Improve support for mixed scalartypes (#259)
* Fix `fin_mpo * fin_mps` with mixed scalartypes * Add `complex(::FiniteMPS)` * Fix `fin_mpo + fin_mpo` with mixed scalartype * simplify implementation * add utility `left_virtualspace` and `right_virtualspace` * `scale(::MPO, ::Number)` with mixed scalartypes * small fixes * various changes and improvements * Expand testing with mixed scalartypes * convert to complex states in TDVP * Add `complex(::WindowMPS)` * Add `complex(::InfiniteMPS)` * Add `complex(::LazySum)` * fix multiplication of MPOHamiltonian
1 parent d91262b commit 941b747

File tree

12 files changed

+160
-63
lines changed

12 files changed

+160
-63
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1919
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
2020
TensorKit = "07d1fe3e-3e46-537d-9eac-e9e13d0d4cec"
2121
TensorKitManifolds = "11fa318c-39cb-4a83-b1ed-cdc7ba1e3684"
22+
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
2223
VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
2324

2425
[compat]
@@ -52,7 +53,6 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
5253
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
5354
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
5455
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
55-
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
5656
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5757
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
5858

src/MPSKit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ using Compat: @compat
6060
using TensorKit
6161
using TensorKit: BraidingTensor
6262
using BlockTensorKit
63+
using TensorOperations
6364
using KrylovKit
6465
using KrylovKit: KrylovAlgorithm
6566
using OptimKit

src/algorithms/timestep/tdvp.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,10 @@ $(TYPEDFIELDS)
2525
finalize::F = Defaults._finalize
2626
end
2727

28-
function timestep(ψ::InfiniteMPS, H, t::Number, dt::Number, alg::TDVP,
29-
envs::AbstractMPSEnvironments=environments(ψ, H);
28+
function timestep(ψ_::InfiniteMPS, H, t::Number, dt::Number, alg::TDVP,
29+
envs::AbstractMPSEnvironments=environments(ψ_, H);
3030
leftorthflag=true)
31+
ψ = complex(ψ_)
3132
temp_ACs = similar.AC)
3233
temp_Cs = similar.C)
3334

@@ -172,5 +173,5 @@ end
172173
function timestep::AbstractFiniteMPS, H, time::Number, timestep::Number,
173174
alg::Union{TDVP,TDVP2}, envs::AbstractMPSEnvironments=environments(ψ, H);
174175
kwargs...)
175-
return timestep!(copy(ψ), H, time, timestep, alg, envs; kwargs...)
176+
return timestep!(copy(complex(ψ)), H, time, timestep, alg, envs; kwargs...)
176177
end

src/operators/abstractmpo.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ end
2424
# Properties
2525
# ----------
2626
left_virtualspace(mpo::AbstractMPO, site::Int) = left_virtualspace(mpo[site])
27+
left_virtualspace(mpo::AbstractMPO) = map(left_virtualspace, parent(mpo))
2728
right_virtualspace(mpo::AbstractMPO, site::Int) = right_virtualspace(mpo[site])
29+
right_virtualspace(mpo::AbstractMPO) = map(right_virtualspace, parent(mpo))
2830
physicalspace(mpo::AbstractMPO, site::Int) = physicalspace(mpo[site])
2931
physicalspace(mpo::AbstractMPO) = map(physicalspace, mpo)
3032

@@ -170,7 +172,11 @@ Base.:*(mpo::AbstractMPO, α::Number) = scale(mpo, α)
170172
Base.:/(mpo::AbstractMPO, α::Number) = scale(mpo, inv(α))
171173
Base.:\::Number, mpo::AbstractMPO) = scale(mpo, inv(α))
172174

173-
VectorInterface.scale(mpo::AbstractMPO, α::Number) = scale!(copy(mpo), α)
175+
function VectorInterface.scale(mpo::AbstractMPO, α::Number)
176+
T = VectorInterface.promote_scale(scalartype(mpo), scalartype(α))
177+
dst = similar(mpo, T)
178+
return scale!(dst, mpo, α)
179+
end
174180

175181
LinearAlgebra.norm(mpo::AbstractMPO) = sqrt(abs(dot(mpo, mpo)))
176182

src/operators/lazysum.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ Base.length(x::LazySum) = prod(size(x))
2424
Base.similar(x::LazySum, ::Type{S}, dims::Dims) where {S} = LazySum(similar(x.ops, S, dims))
2525
Base.setindex!(A::LazySum, X, i::Int) = (setindex!(A.ops, X, i); A)
2626

27+
Base.complex(x::LazySum) = LazySum(complex.(x.ops))
28+
2729
# Holy traits
2830
TimeDependence(x::LazySum) = istimed(x) ? TimeDependent() : NotTimeDependent()
2931
istimed(x::LazySum) = any(istimed, x)

src/operators/mpo.jl

Lines changed: 43 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,12 @@ DenseMPO(mpo::MPO) = mpo isa DenseMPO ? copy(mpo) : MPO(map(TensorMap, parent(mp
5555
Base.parent(mpo::MPO) = mpo.O
5656
Base.copy(mpo::MPO) = MPO(map(copy, mpo))
5757

58-
function Base.similar(mpo::MPO, ::Type{O}, L::Int) where {O}
58+
function Base.similar(mpo::MPO{<:MPOTensor}, ::Type{O}, L::Int) where {O<:MPOTensor}
5959
return MPO(similar(parent(mpo), O, L))
6060
end
61+
function Base.similar(mpo::MPO, ::Type{T}) where {T<:Number}
62+
return MPO(similar.(parent(mpo), T))
63+
end
6164

6265
Base.repeat(mpo::MPO, n::Int) = MPO(repeat(parent(mpo), n))
6366
Base.repeat(mpo::MPO, rows::Int, cols::Int) = MultilineMPO(fill(repeat(mpo, cols), rows))
@@ -102,19 +105,20 @@ function Base.convert(::Type{TensorMap}, mpo::FiniteMPO{<:MPOTensor})
102105
return convert(TensorMap, _instantiate_finitempo(L, M, R))
103106
end
104107

108+
Base.complex(mpo::MPO) = MPO(map(complex, parent(mpo)))
109+
105110
# Linear Algebra
106111
# --------------
107112
# VectorInterface.scalartype(::Type{FiniteMPO{O}}) where {O} = scalartype(O)
108113

109114
Base.:+(mpo::MPO) = MPO(map(+, parent(mpo)))
110-
function Base.:+(mpo1::FiniteMPO{TO}, mpo2::FiniteMPO{TO}) where {TO<:MPOTensor}
111-
(N = length(mpo1)) == length(mpo2) || throw(ArgumentError("dimension mismatch"))
115+
function Base.:+(mpo1::FiniteMPO{<:MPOTensor}, mpo2::FiniteMPO{<:MPOTensor})
116+
N = check_length(mpo1, mpo2)
112117
@assert left_virtualspace(mpo1, 1) == left_virtualspace(mpo2, 1) &&
113118
right_virtualspace(mpo1, N) == right_virtualspace(mpo2, N)
114119

115-
mpo = similar(parent(mpo1))
116120
halfN = N ÷ 2
117-
A = storagetype(TO)
121+
A = storagetype(eltype(mpo1))
118122

119123
# left half
120124
F₁ = isometry(A, (right_virtualspace(mpo1, 1) right_virtualspace(mpo2, 1)),
@@ -127,7 +131,9 @@ function Base.:+(mpo1::FiniteMPO{TO}, mpo2::FiniteMPO{TO}) where {TO<:MPOTensor}
127131

128132
# making sure that the new operator is "full rank"
129133
O, R = leftorth!(O)
130-
mpo[1] = transpose(O, ((2, 3), (1, 4)))
134+
O′ = transpose(O, ((2, 3), (1, 4)))
135+
mpo = similar(mpo1, typeof(O′))
136+
mpo[1] = O′
131137

132138
for i in 2:halfN
133139
# incorporate fusers from left side
@@ -193,11 +199,18 @@ function VectorInterface.scale!(mpo::MPO, α::Number)
193199
scale!(first(mpo), α)
194200
return mpo
195201
end
202+
function VectorInterface.scale!(dst::MPO, src::MPO, α::Number)
203+
N = check_length(dst, src)
204+
for i in 1:N
205+
scale!(dst[i], src[i], i == 1 ? α : One())
206+
end
207+
return dst
208+
end
209+
210+
function Base.:*(mpo1::FiniteMPO{<:MPOTensor}, mpo2::FiniteMPO{<:MPOTensor})
211+
N = check_length(mpo1, mpo2)
212+
(S = spacetype(mpo1)) == spacetype(mpo2) || throw(SectorMismatch())
196213

197-
# TODO: merge implementation with that of InfiniteMPO
198-
function Base.:*(mpo1::FiniteMPO{TO}, mpo2::FiniteMPO{TO}) where {TO<:MPOTensor}
199-
(N = length(mpo1)) == length(mpo2) || throw(ArgumentError("dimension mismatch"))
200-
S = spacetype(TO)
201214
if (left_virtualspace(mpo1, 1) != oneunit(S) ||
202215
left_virtualspace(mpo2, 1) != oneunit(S)) ||
203216
(right_virtualspace(mpo1, N) != oneunit(S) ||
@@ -207,44 +220,34 @@ function Base.:*(mpo1::FiniteMPO{TO}, mpo2::FiniteMPO{TO}) where {TO<:MPOTensor}
207220
# would work and for now I dont feel like figuring out if this is important
208221
end
209222

210-
O = similar(parent(mpo1))
211-
A = storagetype(TO)
212-
213-
# note order of mpos: mpo1 * mpo2 * state -> mpo2 on top of mpo1
214-
local Fᵣ # trick to make Fᵣ defined in the loop
215-
for i in 1:N
216-
Fₗ = i != 1 ? Fᵣ : fuser(A, left_virtualspace(mpo2, i), left_virtualspace(mpo1, i))
217-
Fᵣ = fuser(A, right_virtualspace(mpo2, i), right_virtualspace(mpo1, i))
218-
@plansor O[i][-1 -2; -3 -4] := Fₗ[-1; 1 4] * mpo2[i][1 2; -3 3] *
219-
mpo1[i][4 -2; 2 5] *
220-
conj(Fᵣ[-4; 3 5])
221-
end
222-
223+
O = map(fuse_mul_mpo, parent(mpo1), parent(mpo2))
223224
return changebonds!(FiniteMPO(O), SvdCut(; trscheme=notrunc()))
224225
end
226+
function Base.:*(mpo1::InfiniteMPO, mpo2::InfiniteMPO)
227+
check_length(mpo1, mpo2)
228+
Os = map(fuse_mul_mpo, parent(mpo1), parent(mpo2))
229+
return InfiniteMPO(Os)
230+
end
225231

226232
function Base.:*(mpo::FiniteMPO, mps::FiniteMPS)
227-
length(mpo) == length(mps) || throw(ArgumentError("dimension mismatch"))
228-
229-
A = [mps.AC[1]; mps.AR[2:end]]
230-
TT = storagetype(eltype(A))
231-
232-
local Fᵣ # trick to make Fᵣ defined in the loop
233-
for i in 1:length(mps)
234-
Fₗ = i != 1 ? Fᵣ : fuser(TT, left_virtualspace(mps, i), left_virtualspace(mpo, i))
235-
Fᵣ = fuser(TT, right_virtualspace(mps, i), right_virtualspace(mpo, i))
236-
A[i] = _fuse_mpo_mps(mpo[i], A[i], Fₗ, Fᵣ)
233+
N = check_length(mpo, mps)
234+
T = TensorOperations.promote_contract(scalartype(mpo), scalartype(mps))
235+
A = TensorKit.similarstoragetype(eltype(mps), T)
236+
Fᵣ = fuser(A, left_virtualspace(mps, 1), left_virtualspace(mpo, 1))
237+
A2 = map(1:N) do i
238+
A1 = i == 1 ? mps.AC[1] : mps.AR[i]
239+
Fₗ = Fᵣ
240+
Fᵣ = fuser(A, right_virtualspace(mps, i), right_virtualspace(mpo, i))
241+
return _fuse_mpo_mps(mpo[i], A1, Fₗ, Fᵣ)
237242
end
238-
239-
return changebonds!(FiniteMPS(A),
240-
SvdCut(; trscheme=truncbelow(eps(real(scalartype(TT)))));
241-
normalize=false)
243+
trscheme = truncbelow(eps(real(T)))
244+
return changebonds!(FiniteMPS(A2), SvdCut(; trscheme); normalize=false)
242245
end
243-
244246
function Base.:*(mpo::InfiniteMPO, mps::InfiniteMPS)
245247
L = check_length(mpo, mps)
246248
T = promote_type(scalartype(mpo), scalartype(mps))
247-
fusers = PeriodicArray(fuser.(T, left_virtualspace.(Ref(mps), 1:L),
249+
A = TensorKit.similarstoragetype(eltype(mps), T)
250+
fusers = PeriodicArray(fuser.(A, left_virtualspace.(Ref(mps), 1:L),
248251
left_virtualspace.(Ref(mpo), 1:L)))
249252
As = map(1:L) do i
250253
return _fuse_mpo_mps(mpo[i], mps.AL[i], fusers[i], fusers[i + 1])
@@ -260,12 +263,6 @@ function _fuse_mpo_mps(O::MPOTensor, A::MPSTensor, Fₗ, Fᵣ)
260263
return A′ isa AbstractBlockTensorMap ? TensorMap(A′) : A′
261264
end
262265

263-
function Base.:*(mpo1::InfiniteMPO, mpo2::InfiniteMPO)
264-
check_length(mpo1, mpo2)
265-
Os = map(fuse_mul_mpo, parent(mpo1), parent(mpo2))
266-
return InfiniteMPO(Os)
267-
end
268-
269266
function Base.:*(mpo::FiniteMPO{<:MPOTensor}, x::AbstractTensorMap)
270267
@assert length(mpo) > 1
271268
@assert numout(x) == length(mpo)
@@ -281,8 +278,7 @@ end
281278
# in the middle
282279
function TensorKit.dot(bra::FiniteMPS{T}, mpo::FiniteMPO{<:MPOTensor},
283280
ket::FiniteMPS{T}) where {T}
284-
(N = length(bra)) == length(mpo) == length(ket) ||
285-
throw(ArgumentError("dimension mismatch"))
281+
N = check_length(bra, mpo, ket)
286282
Nhalf = N ÷ 2
287283
# left half
288284
ρ_left = isomorphism(storagetype(T),

src/operators/mpohamiltonian.jl

Lines changed: 56 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,32 @@ function add_physical_charge(H::MPOHamiltonian, charges::AbstractVector{<:Sector
374374
end
375375
end
376376

377+
# TODO: remove once complex(::BraidingTensor) isa BraidingTensor
378+
# Base.complex(H::MPOHamiltonian) = MPOHamiltonian(map(complex, parent(H)))
379+
function Base.complex(H::MPOHamiltonian)
380+
scalartype(H) <: Complex && return H
381+
Ws = map(parent(H)) do W
382+
W′ = jordanmpotensortype(spacetype(W), complex(scalartype(W)))
383+
W′[1] = W[1]
384+
W′[end] = W[end]
385+
for (I, v) in nonzero_pairs(W)
386+
if v isa BraidingTensor
387+
W′[I] = BraidingTensor{scalartype(W′)}(space(v), v.adjoint)
388+
else
389+
W′[I] = complex(v)
390+
end
391+
end
392+
end
393+
return MPOHamiltonian(H)
394+
end
395+
396+
function Base.similar(H::MPOHamiltonian, ::Type{O}, L::Int) where {O<:MPOTensor}
397+
return MPOHamiltonian(similar(parent(H), O, L))
398+
end
399+
function Base.similar(H::MPOHamiltonian, ::Type{T}) where {T<:Number}
400+
return MPOHamiltonian(similar.(parent(H), T))
401+
end
402+
377403
# Linear Algebra
378404
# --------------
379405

@@ -496,15 +522,36 @@ function VectorInterface.scale!(H::FiniteMPOHamiltonian, λ::Number)
496522
return H
497523
end
498524

525+
function VectorInterface.scale!(dst::MPOHamiltonian, src::MPOHamiltonian,
526+
λ::Number)
527+
N = check_length(dst, src)
528+
for i in 1:N
529+
space(dst[i]) == space(src[i]) || throw(SpaceMismatch())
530+
zerovector!(dst[i])
531+
for (I, v) in nonzero_pairs(src[i])
532+
# only scale "starting" terms
533+
isstarting = I[1] == 1 &&
534+
((isfinite(dst) && i == N && I[4] == size(src[i], 4)) ||
535+
((!isfinite(dst) || i != N) && I[4] > 1))
536+
if v isa BraidingTensor && !isstarting
537+
dst[i][I] = v
538+
else
539+
dst[i][I] = scale!(dst[i][I], v, isstarting ? λ : One())
540+
end
541+
end
542+
end
543+
return dst
544+
end
545+
499546
function Base.:*(H1::MPOHamiltonian, H2::MPOHamiltonian)
500547
check_length(H1, H2)
501548
Ws = fuse_mul_mpo.(parent(H1), parent(H2))
502549
return MPOHamiltonian(Ws)
503550
end
504551

505552
function Base.:*(H::FiniteMPOHamiltonian, mps::FiniteMPS)
506-
check_length(H, mps)
507-
@assert length(mps) > 2 "MPS should have at least three sites, to be implemented otherwise"
553+
N = check_length(H, mps)
554+
@assert N > 2 "MPS should have at least three sites, to be implemented otherwise"
508555
A = convert.(BlockTensorMap, [mps.AC[1]; mps.AR[2:end]])
509556
A′ = similar(A,
510557
tensormaptype(spacetype(mps), numout(eltype(mps)), numin(eltype(mps)),
@@ -515,30 +562,30 @@ function Base.:*(H::FiniteMPOHamiltonian, mps::FiniteMPS)
515562
Q, R = leftorth!(a; alg=QR())
516563
A′[1] = convert(TensorMap, Q)
517564

518-
for i in 2:(length(mps) ÷ 2)
565+
for i in 2:(N ÷ 2)
519566
@plansor a[-1 -2; -3 -4] := R[-1; 1 2] * A[i][1 3; -3] * H[i][2 -2; 3 -4]
520567
Q, R = leftorth!(a; alg=QR())
521568
A′[i] = convert(TensorMap, Q)
522569
end
523570

524571
# right to middle
525-
U = ones(scalartype(H), right_virtualspace(H, length(H)))
572+
U = ones(scalartype(H), right_virtualspace(H, N))
526573
@plansor a[-1 -2; -3 -4] := A[end][-1 2; -3] * H[end][-2 -4; 2 1] * U[1]
527574
L, Q = rightorth!(a; alg=LQ())
528575
A′[end] = transpose(convert(TensorMap, Q), ((1, 3), (2,)))
529576

530-
for i in (length(mps) - 1):-1:(length(mps) ÷ 2 + 2)
577+
for i in (N - 1):-1:(N ÷ 2 + 2)
531578
@plansor a[-1 -2; -3 -4] := A[i][-1 3; 1] * H[i][-2 -4; 3 2] * L[1 2; -3]
532579
L, Q = rightorth!(a; alg=LQ())
533580
A′[i] = transpose(convert(TensorMap, Q), ((1, 3), (2,)))
534581
end
535582

536583
# connect pieces
537584
@plansor a[-1 -2; -3] := R[-1; 1 2] *
538-
A[length(mps) ÷ 2 + 1][1 3; 4] *
539-
H[length(mps) ÷ 2 + 1][2 -2; 3 5] *
585+
A[N ÷ 2 + 1][1 3; 4] *
586+
H[N ÷ 2 + 1][2 -2; 3 5] *
540587
L[4 5; -3]
541-
A′[length(mps) ÷ 2 + 1] = convert(TensorMap, a)
588+
A′[N ÷ 2 + 1] = convert(TensorMap, a)
542589

543590
return FiniteMPS(A′)
544591
end
@@ -553,9 +600,7 @@ function Base.:*(H::FiniteMPOHamiltonian{<:MPOTensor}, x::AbstractTensorMap)
553600
end
554601

555602
function TensorKit.dot(H₁::FiniteMPOHamiltonian, H₂::FiniteMPOHamiltonian)
556-
check_length(H₁, H₂)
557-
558-
N = length(H₁)
603+
N = check_length(H₁, H₂)
559604
Nhalf = N ÷ 2
560605
# left half
561606
@plansor ρ_left[-1; -2] := conj(H₁[1][1 2; 3 -1]) * H₂[1][1 2; 3 -2]

src/states/finitemps.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,19 @@ Base.@propagate_inbounds function Base.getindex(ψ::FiniteMPS, i::Int)
316316
end
317317
end
318318

319+
_complex_if_not_missing(x) = ismissing(x) ? x : complex(x)
320+
function Base.complex(mps::FiniteMPS)
321+
scalartype(mps) <: Complex && return mps
322+
ALs = _complex_if_not_missing.(mps.ALs)
323+
ARs = _complex_if_not_missing.(mps.ARs)
324+
Cs = _complex_if_not_missing.(mps.Cs)
325+
ACs = _complex_if_not_missing.(mps.ACs)
326+
return FiniteMPS(collect(Union{Missing,eltype(ALs)}, ALs),
327+
collect(Union{Missing,eltype(ARs)}, ARs),
328+
collect(Union{Missing,eltype(ACs)}, ACs),
329+
collect(Union{Missing,eltype(Cs)}, Cs))
330+
end
331+
319332
@inline function Base.getindex::FiniteMPS, I::AbstractUnitRange)
320333
return Base.getindex.(Ref(ψ), I)
321334
end

src/states/infinitemps.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,11 @@ function Base.copy!(ψ::InfiniteMPS, ϕ::InfiniteMPS)
224224
return ψ
225225
end
226226

227+
function Base.complex::InfiniteMPS)
228+
scalartype(ψ) <: Complex && return ψ
229+
return InfiniteMPS(complex.(ψ.AL), complex.(ψ.AR), complex.(ψ.C), complex.(ψ.AC))
230+
end
231+
227232
function Base.repeat::InfiniteMPS, i::Int)
228233
return InfiniteMPS(repeat.AL, i), repeat.AR, i), repeat.C, i), repeat.AC, i))
229234
end

0 commit comments

Comments
 (0)