Skip to content

Commit d68876b

Browse files
lkdvos and ogauthe authored
[Performance] TreeTransformer refactor + multithreading (#251)
This PR changes the implementation of the tree transformers in the generic case of non-unique fusion. The core idea is to avoid non-contiguous data access as much as possible, while making use of BLAS to perform the recoupling. To achieve this, we now loop over each set of uncoupled charges and simultaneously process all fusion trees attached to it. The recoupling is done through BLAS, the copying into a temporary buffer with an optimized copy, and the final permutation with Strided. Additionally, I added and centralized the implementation of multithreading for the tree transformers. This is the result of many discussions with, and inspiration from, @ogauthe. Co-authored-by: Olivier Gauthe <[email protected]>
1 parent 8e9384b commit d68876b

File tree

10 files changed

+405
-176
lines changed

10 files changed

+405
-176
lines changed

Project.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
name = "TensorKit"
22
uuid = "07d1fe3e-3e46-537d-9eac-e9e13d0d4cec"
33
authors = ["Jutho Haegeman"]
4-
version = "0.14.6"
4+
version = "0.14.7"
55

66
[deps]
77
LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637"
88
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
99
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
1010
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
11-
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1211
Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67"
1312
TensorKitSectors = "13a9c161-d5da-41f0-bcbd-e1a08ae0647f"
1413
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
@@ -33,7 +32,6 @@ LRUCache = "1.0.2"
3332
LinearAlgebra = "1"
3433
PackageExtensionCompat = "1"
3534
Random = "1"
36-
SparseArrays = "1"
3735
Strided = "2"
3836
TensorKitSectors = "0.1"
3937
TensorOperations = "5.1"

benchmark/TensorKitBenchmarks/TensorKitBenchmarks.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using BenchmarkTools
44
using TensorKit
55
using TOML
66

7-
BenchmarkTools.DEFAULT_PARAMETERS.seconds = 1.0
7+
BenchmarkTools.DEFAULT_PARAMETERS.seconds = 20.0
88
BenchmarkTools.DEFAULT_PARAMETERS.samples = 10000
99
BenchmarkTools.DEFAULT_PARAMETERS.time_tolerance = 0.15
1010
BenchmarkTools.DEFAULT_PARAMETERS.memory_tolerance = 0.01

benchmark/TensorKitBenchmarks/tensornetworks/TensorNetworkBenchmarks.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,8 @@ function benchmark_mpo!(bench; sigmas=nothing, T="Float64", I="Trivial", dims)
4545
end
4646

4747
if haskey(all_parameters, "mpo")
48-
g = addgroup!(SUITE, "mpo")
4948
for params in all_parameters["mpo"]
50-
benchmark_mpo!(g, params)
49+
benchmark_mpo!(SUITE, params)
5150
end
5251
end
5352

@@ -90,9 +89,8 @@ function benchmark_pepo!(bench; sigmas=nothing, T="Float64", I="Trivial", dims)
9089
end
9190

9291
if haskey(all_parameters, "pepo")
93-
g = addgroup!(SUITE, "pepo")
9492
for params in all_parameters["pepo"]
95-
benchmark_pepo!(g, params)
93+
benchmark_pepo!(SUITE, params)
9694
end
9795
end
9896

@@ -136,9 +134,8 @@ function benchmark_mera!(bench; sigmas=nothing, T="Float64", I="Trivial", dims)
136134
end
137135

138136
if haskey(all_parameters, "mera")
139-
g = addgroup!(SUITE, "mera")
140137
for params in all_parameters["mera"]
141-
benchmark_mera!(g, params)
138+
benchmark_mera!(SUITE, params)
142139
end
143140
end
144141

benchmark/TensorKitBenchmarks/tensornetworks/benchparams.toml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,12 @@ I = "U1Irrep"
1919
dims = [[40, 5, 3], [160, 5, 3], [640, 5, 3], [2560, 5, 3], [6120, 5, 3], [200, 20, 20], [400, 20, 20], [400, 40, 40]]
2020
sigmas = [0.5, 0.5, 0.5]
2121

22+
[[mpo]]
23+
T = ["Float64"]
24+
I = "SU2Irrep"
25+
dims = [[40, 5, 3], [160, 5, 3], [640, 5, 3], [2560, 5, 3], [6120, 5, 3], [200, 20, 20], [400, 20, 20], [400, 40, 40]]
26+
sigmas = [2, 2, 2]
27+
2228
# PEPO
2329
# ----
2430
# dims = [peps, pepo, phys, env]
@@ -40,6 +46,12 @@ I = "U1Irrep"
4046
dims = [[4, 2, 2, 100], [4, 4, 4, 200], [6, 2, 2, 100], [6, 3, 4, 200], [8, 2, 2, 100], [8, 2, 4, 200], [10, 2, 2, 50], [10, 3, 2, 100]]
4147
sigmas = [0.5, 0.5, 0.5, 0.5]
4248

49+
[[pepo]]
50+
T = ["Float64"]
51+
I = "SU2Irrep"
52+
dims = [[4, 2, 2, 100], [4, 4, 4, 200], [6, 2, 2, 100], [6, 3, 4, 200], [8, 2, 2, 100], [8, 2, 4, 200], [10, 2, 2, 50], [10, 3, 2, 100]]
53+
sigmas = [2.0, 2.0, 2.0, 2.0]
54+
4355
# MERA
4456
# ----
4557
# dims = mera
@@ -60,3 +72,9 @@ T = ["Float64"]
6072
I = "U1Irrep"
6173
dims = [4, 8, 12, 16, 22, 28]
6274
sigmas = [0.5]
75+
76+
[[mera]]
77+
T = ["Float64"]
78+
I = "SU2Irrep"
79+
dims = [4, 8, 12, 16, 22, 28]
80+
sigmas = [2.0]

benchmark/TensorKitBenchmarks/utils/BenchUtils.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,15 @@ function generate_space(::Type{U1Irrep}, D::Int, sigma::Real=0.5)
5555
return U1Space((s => d for (s, d) in zip(sectors, dims))...)
5656
end
5757
function generate_space(::Type{SU2Irrep}, D::Int, sigma::Real=0.5)
58-
poisson_pdf(x) = ceil(Int, D * exp(-sigma) * sigma^x / factorial(x + 1))
58+
normal_pdf = let D = D
59+
x -> D * exp(-0.5 * (x / sigma)^2) / (sigma * sqrt(2π))
60+
end
5961

6062
sectors = SU2Irrep[]
6163
dims = Int[]
6264

6365
for sector in values(SU2Irrep)
64-
d = poisson_pdf(Int(sector.j * 2))
66+
d = ceil(Int, normal_pdf(sector.j) / dim(sector))
6567
push!(sectors, sector)
6668
push!(dims, d)
6769
D -= d * dim(sector)

src/TensorKit.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,6 @@ using LinearAlgebra: norm, dot, normalize, normalize!, tr,
126126
isposdef, isposdef!, ishermitian, rank, cond,
127127
Diagonal, Hermitian
128128

129-
using SparseArrays: SparseMatrixCSC, sparse, nzrange, rowvals, nonzeros
130-
131129
import Base.Meta
132130

133131
using Random: Random, rand!, randn!
@@ -185,6 +183,21 @@ include("fusiontrees/fusiontrees.jl")
185183
#-------------------------------------------
186184
include("spaces/vectorspaces.jl")
187185

186+
# Multithreading settings
187+
#-------------------------
188+
const TRANSFORMER_THREADS = Ref(1)
189+
190+
get_num_transformer_threads() = TRANSFORMER_THREADS[]
191+
192+
function set_num_transformer_threads(n::Int)
193+
N = Base.Threads.nthreads()
194+
if n > N
195+
n = N
196+
Strided._set_num_threads_warn(n)
197+
end
198+
return TRANSFORMER_THREADS[] = n
199+
end
200+
188201
# Definitions and methods for tensors
189202
#-------------------------------------
190203
# general definitions

src/auxiliary/auxiliary.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,27 @@ _interleave(::Tuple{}, ::Tuple{}) = ()
5555
function _interleave(a::NTuple{N}, b::NTuple{N}) where {N}
5656
return (a[1], b[1], _interleave(tail(a), tail(b))...)
5757
end
58+
59+
# Low-overhead implementation of `copyto!` for specific case of `stride(B, 1) < stride(B, 2)`
60+
# used in indexmanipulations: avoids the overhead of Strided.jl
61+
function _copyto!(A::StridedView{<:Any,1}, B::StridedView{<:Any,2})
62+
length(A) == length(B) || throw(DimensionMismatch())
63+
64+
Adata = parent(A)
65+
Astr = stride(A, 1)
66+
IA = A.offset
67+
68+
Bdata = parent(B)
69+
Bstr = strides(B)
70+
71+
IB_1 = B.offset
72+
@inbounds for _ in axes(B, 2)
73+
IB = IB_1
74+
for _ in axes(B, 1)
75+
Adata[IA += Astr] = Bdata[IB += Bstr[1]]
76+
end
77+
IB_1 += Bstr[2]
78+
end
79+
80+
return A
81+
end

src/spaces/homspace.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,11 +225,15 @@ end
225225

226226
# Block and fusion tree ranges: structure information for building tensors
227227
#--------------------------------------------------------------------------
228+
229+
# sizes, strides, offset
230+
const StridedStructure{N} = Tuple{NTuple{N,Int},NTuple{N,Int},Int}
231+
228232
struct FusionBlockStructure{I,N,F₁,F₂}
229233
totaldim::Int
230234
blockstructure::SectorDict{I,Tuple{Tuple{Int,Int},UnitRange{Int}}}
231235
fusiontreelist::Vector{Tuple{F₁,F₂}}
232-
fusiontreestructure::Vector{Tuple{NTuple{N,Int},NTuple{N,Int},Int}}
236+
fusiontreestructure::Vector{StridedStructure{N}}
233237
fusiontreeindices::FusionTreeDict{Tuple{F₁,F₂},Int}
234238
end
235239

0 commit comments

Comments
 (0)