Skip to content

Commit 14b3630

Browse files
authored
[Perf] Another round of performance and type-stability improvements (#229)
* Optimize contraction: `renormalize_west_edge` * Improve type stability of CTMRG * Fix type stability for fixedspacetruncation * Avoid boxed variables in closure * patch svd algorithm when krylov dimension is too low
1 parent 47e6675 commit 14b3630

File tree

11 files changed

+120
-54
lines changed

11 files changed

+120
-54
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ Random = "1"
4343
Statistics = "1"
4444
TensorKit = "0.14.6"
4545
TensorOperations = "5"
46+
TestExtras = "0.3"
4647
VectorInterface = "0.4, 0.5"
4748
Zygote = "0.6, 0.7"
4849
julia = "1.10"
@@ -53,6 +54,7 @@ FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
5354
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
5455
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
5556
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
57+
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
5658

5759
[targets]
58-
test = ["Test", "SafeTestsets", "ChainRulesTestUtils", "FiniteDifferences", "QuadGK"]
60+
test = ["Test", "TestExtras", "SafeTestsets", "ChainRulesTestUtils", "FiniteDifferences", "QuadGK"]

src/PEPSKit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import MPSKit: leading_boundary, loginit!, logiter!, logfinish!, logcancel!, phy
1515

1616
using MPSKitModels
1717
using FiniteDifferences
18-
using OhMyThreads: tmap
18+
using OhMyThreads: tmap, tmap!
1919
using DocStringExtensions
2020

2121
include("Defaults.jl") # Include first to allow for docstring interpolation with Defaults values

src/algorithms/contractions/ctmrg_contractions.jl

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1310,12 +1310,19 @@ end
13101310
function renormalize_west_edge(
13111311
E_west::CTMRG_PEPS_EdgeTensor, P_bottom, P_top, A::PEPSSandwich
13121312
)
1313-
return @autoopt @tensor edge[χ_S D_Eab D_Ebe; χ_N] :=
1314-
E_west[χ1 D1 D2; χ2] *
1315-
ket(A)[d; D3 D_Eab D5 D1] *
1316-
conj(bra(A)[d; D4 D_Ebe D6 D2]) *
1317-
P_bottom[χ2 D3 D4; χ_N] *
1318-
P_top[χ_S; χ1 D5 D6]
1313+
# starting with P_top to save one permute in the end
1314+
return @tensor begin
1315+
# already putting χE in front here to make next permute cheaper
1316+
PE[χS χNW DSb DWb; DSt DWt] := P_top[χS; χSW DSt DSb] * E_west[χSW DWt DWb; χNW]
1317+
1318+
PEket[χS χNW DNt DEt; DSb DWb d] :=
1319+
PE[χS χNW DSb DWb; DSt DWt] * ket(A)[d; DNt DEt DSt DWt]
1320+
1321+
corner[χS DEt DEb; χNW DNt DNb] :=
1322+
PEket[χS χNW DNt DEt; DSb DWb d] * conj(bra(A)[d; DNb DEb DSb DWb])
1323+
1324+
edge[χS DEt DEb; χN] := corner[χS DEt DEb; χNW DNt DNb] * P_bottom[χNW DNt DNb; χN]
1325+
end
13191326
end
13201327
function renormalize_west_edge(E_west::CTMRG_PF_EdgeTensor, P_bottom, P_top, A::PFTensor)
13211328
return @autoopt @tensor edge[χ_S D_E; χ_N] :=

src/algorithms/correlators.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ function end_correlator_numerator(
120120
C_southeast = env.corners[SOUTHEAST, _next(r, end), _next(c, end)]
121121
sandwich = (above[mod1(r, end), mod1(c, end)], below[mod1(r, end), mod1(c, end)])
122122

123-
return @autoopt @tensor contractcheck = true V[χSW DWt dstring DWb; χNW] *
123+
return @autoopt @tensor V[χSW DWt dstring DWb; χNW] *
124124
E_south[χSSE DSt DSb; χSW] *
125125
E_east[χNEE DEt DEb; χSEE] *
126126
E_north[χNW DNt DNb; χNNE] *

src/algorithms/ctmrg/ctmrg.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -110,14 +110,12 @@ end
110110
function leading_boundary(
111111
env₀::CTMRGEnv, network::InfiniteSquareNetwork, alg::CTMRGAlgorithm
112112
)
113-
CS = map(x -> tsvd(x)[2], env₀.corners)
114-
TS = map(x -> tsvd(x)[2], env₀.edges)
115-
116-
η = one(real(scalartype(network)))
117-
env = deepcopy(env₀)
118113
log = ignore_derivatives(() -> MPSKit.IterLog("CTMRG"))
119-
120114
return LoggingExtras.withlevel(; alg.verbosity) do
115+
env = deepcopy(env₀)
116+
CS = map(x -> tsvd(x)[2], env₀.corners)
117+
TS = map(x -> tsvd(x)[2], env₀.edges)
118+
η = one(real(scalartype(network)))
121119
ctmrg_loginit!(log, η, network, env₀)
122120
local info
123121
for iter in 1:(alg.maxiter)

src/algorithms/ctmrg/sequential.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@ For a full description, see [`leading_boundary`](@ref). The supported keywords a
2424
* `svd_alg::Union{<:SVDAdjoint,NamedTuple}`
2525
* `projector_alg::Symbol=:$(Defaults.projector_alg)`
2626
"""
27-
struct SequentialCTMRG <: CTMRGAlgorithm
27+
struct SequentialCTMRG{P<:ProjectorAlgorithm} <: CTMRGAlgorithm
2828
tol::Float64
2929
maxiter::Int
3030
miniter::Int
3131
verbosity::Int
32-
projector_alg::ProjectorAlgorithm
32+
projector_alg::P
3333
end
3434
function SequentialCTMRG(; kwargs...)
3535
return CTMRGAlgorithm(; alg=:sequential, kwargs...)
@@ -81,14 +81,18 @@ for a specific `coordinate` (where `dir=WEST` is already implied in the `:sequen
8181
"""
8282
function sequential_projectors(col::Int, network, env::CTMRGEnv, alg::ProjectorAlgorithm)
8383
coordinates = eachcoordinate(env)[:, col]
84-
proj_and_info = dtmap(coordinates) do (r, c)
84+
T_dst = Base.promote_op(
85+
sequential_projectors, NTuple{3,Int}, typeof(network), typeof(env), typeof(alg)
86+
)
87+
proj_and_info = similar(coordinates, T_dst)
88+
proj_and_info′::typeof(proj_and_info) = dtmap!!(proj_and_info, coordinates) do (r, c)
8589
trscheme = truncation_scheme(alg, env.edges[WEST, _prev(r, size(env, 2)), c])
8690
proj, info = sequential_projectors(
8791
(WEST, r, c), network, env, @set(alg.trscheme = trscheme)
8892
)
8993
return proj, info
9094
end
91-
return _split_proj_and_info(proj_and_info)
95+
return _split_proj_and_info(proj_and_info)
9296
end
9397
function sequential_projectors(
9498
coordinate::NTuple{3,Int}, network, env::CTMRGEnv, alg::HalfInfiniteProjector

src/algorithms/ctmrg/simultaneous.jl

Lines changed: 60 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@ For a full description, see [`leading_boundary`](@ref). The supported keywords a
2323
* `svd_alg::Union{<:SVDAdjoint,NamedTuple}`
2424
* `projector_alg::Symbol=:$(Defaults.projector_alg)`
2525
"""
26-
struct SimultaneousCTMRG <: CTMRGAlgorithm
26+
struct SimultaneousCTMRG{P<:ProjectorAlgorithm} <: CTMRGAlgorithm
2727
tol::Float64
2828
maxiter::Int
2929
miniter::Int
3030
verbosity::Int
31-
projector_alg::ProjectorAlgorithm
31+
projector_alg::P
3232
end
3333
function SimultaneousCTMRG(; kwargs...)
3434
return CTMRGAlgorithm(; alg=:simultaneous, kwargs...)
@@ -37,9 +37,15 @@ end
3737
CTMRG_SYMBOLS[:simultaneous] = SimultaneousCTMRG
3838

3939
function ctmrg_iteration(network, env::CTMRGEnv, alg::SimultaneousCTMRG)
40-
enlarged_corners = dtmap(eachcoordinate(network, 1:4)) do idx
41-
return TensorMap(EnlargedCorner(network, env, idx))
42-
end # expand environment
40+
coordinates = eachcoordinate(network, 1:4)
41+
T_corners = Base.promote_op(
42+
TensorMap EnlargedCorner, typeof(network), typeof(env), eltype(coordinates)
43+
)
44+
enlarged_corners′ = similar(coordinates, T_corners)
45+
enlarged_corners::typeof(enlarged_corners′) =
46+
dtmap!!(enlarged_corners′, eachcoordinate(network, 1:4)) do idx
47+
return TensorMap(EnlargedCorner(network, env, idx))
48+
end # expand environment
4349
projectors, info = simultaneous_projectors(enlarged_corners, env, alg.projector_alg) # compute projectors on all coordinates
4450
env′ = renormalize_simultaneously(enlarged_corners, projectors, network, env) # renormalize enlarged corners
4551
return env′, info
@@ -72,25 +78,36 @@ enlarged corners or on a specific `coordinate`.
7278
function simultaneous_projectors(
7379
enlarged_corners::Array{E,3}, env::CTMRGEnv, alg::ProjectorAlgorithm
7480
) where {E}
75-
proj_and_info = dtmap(eachcoordinate(env, 1:4)) do coordinate
76-
coordinate′ = _next_coordinate(coordinate, size(env)[2:3]...)
77-
trscheme = truncation_scheme(alg, env.edges[coordinate[1], coordinate′[2:3]...])
78-
return simultaneous_projectors(
79-
coordinate, enlarged_corners, @set(alg.trscheme = trscheme)
80-
)
81-
end
81+
coordinates = eachcoordinate(env, 1:4)
82+
T_dst = Base.promote_op(
83+
simultaneous_projectors,
84+
NTuple{3,Int},
85+
typeof(enlarged_corners),
86+
typeof(env),
87+
typeof(alg),
88+
)
89+
proj_and_info′ = similar(coordinates, T_dst)
90+
proj_and_info::typeof(proj_and_info′) =
91+
dtmap!!(proj_and_info′, coordinates) do coordinate
92+
return simultaneous_projectors(coordinate, enlarged_corners, env, alg)
93+
end
8294
return _split_proj_and_info(proj_and_info)
8395
end
8496
function simultaneous_projectors(
85-
coordinate, enlarged_corners::Array{E,3}, alg::HalfInfiniteProjector
97+
coordinate, enlarged_corners::Array{E,3}, env, alg::HalfInfiniteProjector
8698
) where {E}
87-
coordinate′ = _next_coordinate(coordinate, size(enlarged_corners)[2:3]...)
99+
coordinate′ = _next_coordinate(coordinate, size(env)[2:3]...)
100+
trscheme = truncation_scheme(alg, env.edges[coordinate[1], coordinate′[2:3]...])
101+
alg′ = @set alg.trscheme = trscheme
88102
ec = (enlarged_corners[coordinate...], enlarged_corners[coordinate′...])
89-
return compute_projector(ec, coordinate, alg)
103+
return compute_projector(ec, coordinate, alg)
90104
end
91105
function simultaneous_projectors(
92-
coordinate, enlarged_corners::Array{E,3}, alg::FullInfiniteProjector
106+
coordinate, enlarged_corners::Array{E,3}, env, alg::FullInfiniteProjector
93107
) where {E}
108+
coordinate′ = _next_coordinate(coordinate, size(env)[2:3]...)
109+
trscheme = truncation_scheme(alg, env.edges[coordinate[1], coordinate′[2:3]...])
110+
alg′ = @set alg.trscheme = trscheme
94111
rowsize, colsize = size(enlarged_corners)[2:3]
95112
coordinate2 = _next_coordinate(coordinate, rowsize, colsize)
96113
coordinate3 = _next_coordinate(coordinate2, rowsize, colsize)
@@ -101,7 +118,7 @@ function simultaneous_projectors(
101118
enlarged_corners[coordinate2...],
102119
enlarged_corners[coordinate3...],
103120
)
104-
return compute_projector(ec, coordinate, alg)
121+
return compute_projector(ec, coordinate, alg)
105122
end
106123

107124
"""
@@ -112,22 +129,33 @@ Renormalize all enlarged corners and edges simultaneously.
112129
function renormalize_simultaneously(enlarged_corners, projectors, network, env)
113130
P_left, P_right = projectors
114131
coordinates = eachcoordinate(env, 1:4)
115-
corners_edges = dtmap(coordinates) do (dir, r, c)
116-
if dir == NORTH
117-
corner = renormalize_northwest_corner((r, c), enlarged_corners, P_left, P_right)
118-
edge = renormalize_north_edge((r, c), env, P_left, P_right, network)
119-
elseif dir == EAST
120-
corner = renormalize_northeast_corner((r, c), enlarged_corners, P_left, P_right)
121-
edge = renormalize_east_edge((r, c), env, P_left, P_right, network)
122-
elseif dir == SOUTH
123-
corner = renormalize_southeast_corner((r, c), enlarged_corners, P_left, P_right)
124-
edge = renormalize_south_edge((r, c), env, P_left, P_right, network)
125-
elseif dir == WEST
126-
corner = renormalize_southwest_corner((r, c), enlarged_corners, P_left, P_right)
127-
edge = renormalize_west_edge((r, c), env, P_left, P_right, network)
132+
T_CE = Tuple{cornertype(env),edgetype(env)}
133+
corners_edges′ = similar(coordinates, T_CE)
134+
corners_edges::typeof(corners_edges′) =
135+
dtmap!!(corners_edges′, coordinates) do (dir, r, c)
136+
if dir == NORTH
137+
corner = renormalize_northwest_corner(
138+
(r, c), enlarged_corners, P_left, P_right
139+
)
140+
edge = renormalize_north_edge((r, c), env, P_left, P_right, network)
141+
elseif dir == EAST
142+
corner = renormalize_northeast_corner(
143+
(r, c), enlarged_corners, P_left, P_right
144+
)
145+
edge = renormalize_east_edge((r, c), env, P_left, P_right, network)
146+
elseif dir == SOUTH
147+
corner = renormalize_southeast_corner(
148+
(r, c), enlarged_corners, P_left, P_right
149+
)
150+
edge = renormalize_south_edge((r, c), env, P_left, P_right, network)
151+
elseif dir == WEST
152+
corner = renormalize_southwest_corner(
153+
(r, c), enlarged_corners, P_left, P_right
154+
)
155+
edge = renormalize_west_edge((r, c), env, P_left, P_right, network)
156+
end
157+
return corner / norm(corner), edge / norm(edge)
128158
end
129-
return corner / norm(corner), edge / norm(edge)
130-
end
131159

132160
return CTMRGEnv(map(first, corners_edges), map(last, corners_edges))
133161
end

src/algorithms/ctmrg/sparse_environments.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ function EnlargedCorner(network::InfiniteSquareNetwork, env, coordinates)
5656
network[r, c],
5757
dir,
5858
)
59+
else
60+
throw(ArgumentError(lazy"Invalid direction $dir"))
5961
end
6062
end
6163

@@ -73,6 +75,8 @@ function TensorKit.TensorMap(Q::EnlargedCorner)
7375
return enlarge_southeast_corner(Q.E_1, Q.C, Q.E_2, Q.A)
7476
elseif Q.dir == SOUTHWEST
7577
return enlarge_southwest_corner(Q.E_1, Q.C, Q.E_2, Q.A)
78+
else
79+
throw(ArgumentError(lazy"Invalid direction $dir"))
7680
end
7781
end
7882

src/utility/diffable_threads.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ All calls of `dtmap` inside of PEPSKit use the threading scheduler stored inside
88
"""
99
dtmap(args...; scheduler=Defaults.scheduler[]) = tmap(args...; scheduler)
1010

11+
dtmap!!(args...; scheduler=Defaults.scheduler[]) = tmap!(args...; scheduler)
12+
1113
# Follows the `map` rrule from ChainRules.jl but specified for the case of one AbstractArray that is being mapped
1214
# https://github.com/JuliaDiff/ChainRules.jl/blob/e245d50a1ae56ce46fc8c1f0fe9b925964f1146e/src/rulesets/Base/base.jl#L243
1315
function ChainRulesCore.rrule(
@@ -33,6 +35,22 @@ function ChainRulesCore.rrule(
3335
return y, dtmap_pullback
3436
end
3537

38+
function ChainRulesCore.rrule(
39+
config::RuleConfig{>:HasReverseMode},
40+
::typeof(dtmap!!),
41+
f,
42+
C′::AbstractArray,
43+
A::AbstractArray;
44+
kwargs...,
45+
)
46+
C, dtmap_pullback = rrule(config, dtmap, f, A; kwargs...)
47+
function dtmap!!_pullback(dy)
48+
dtmap, df, dA = dtmap_pullback(dy)
49+
return dtmap, df, NoTangent, dA
50+
end
51+
return C, dtmap!!_pullback
52+
end
53+
3654
"""
3755
@fwdthreads(ex)
3856

src/utility/svd.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,11 @@ function TensorKit._compute_svddata!(
290290
V = V[1:howmany, :]
291291
else
292292
x₀ = alg.start_vector(b)
293-
S, lvecs, rvecs, info = KrylovKit.svdsolve(b, x₀, howmany, :LR, alg.alg)
293+
svd_alg = alg.alg
294+
if howmany > alg.alg.krylovdim
295+
svd_alg = @set svd_alg.krylovdim = round(Int, howmany * 1.2)
296+
end
297+
S, lvecs, rvecs, info = KrylovKit.svdsolve(b, x₀, howmany, :LR, svd_alg)
294298
if info.converged < howmany # Fall back to dense SVD if not properly converged
295299
@warn "Iterative SVD did not converge for block $c, falling back to dense SVD"
296300
U, S, V = TensorKit.MatrixAlgebra.svd!(b, TensorKit.SDD())

0 commit comments

Comments
 (0)