Skip to content

Commit 83888c8

Browse files
lkdvosJutho
authored andcommitted
Add treetransformers
1 parent 700ce2d commit 83888c8

File tree

4 files changed

+165
-18
lines changed

4 files changed

+165
-18
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
name = "TensorKit"
22
uuid = "07d1fe3e-3e46-537d-9eac-e9e13d0d4cec"
33
authors = ["Jutho Haegeman"]
4-
version = "0.13"
4+
version = "0.13.0"
55

66
[deps]
77
LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637"
88
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
99
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
1010
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
11+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1112
Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67"
1213
TensorKitSectors = "13a9c161-d5da-41f0-bcbd-e1a08ae0647f"
1314
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
@@ -32,6 +33,7 @@ LRUCache = "1.0.2"
3233
LinearAlgebra = "1"
3334
PackageExtensionCompat = "1"
3435
Random = "1"
36+
SparseArrays = "1"
3537
Strided = "2"
3638
TensorKitSectors = "0.1"
3739
TensorOperations = "5.1"

src/TensorKit.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,9 @@ using LinearAlgebra: norm, dot, normalize, normalize!, tr,
119119
eigen, eigen!, svd, svd!,
120120
isposdef, isposdef!, ishermitian,
121121
Diagonal, Hermitian
122+
123+
using SparseArrays: SparseMatrixCSC, sparse, nzrange, rowvals, nonzeros
124+
122125
import Base.Meta
123126

124127
using Random: Random
@@ -185,6 +188,7 @@ include("tensors/adjoint.jl")
185188
include("tensors/linalg.jl")
186189
include("tensors/vectorinterface.jl")
187190
include("tensors/tensoroperations.jl")
191+
include("tensors/treetransformers.jl")
188192
include("tensors/indexmanipulations.jl")
189193
include("tensors/truncation.jl")
190194
include("tensors/factorizations.jl")

src/tensors/indexmanipulations.jl

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -279,8 +279,8 @@ See also [`permute`](@ref), [`permute!`](@ref), [`add_braid!`](@ref), [`add_tran
279279
α::Number,
280280
β::Number,
281281
backend::AbstractBackend...) where {T,S,N₁,N₂}
282-
treepermuter(f₁, f₂) = permute(f₁, f₂, p[1], p[2])
283-
return add_transform!(tdst, tsrc, p, treepermuter, α, β, backend...)
282+
transformer = treepermuter(tdst, tsrc, p)
283+
return add_transform!(tdst, tsrc, p, transformer, α, β, backend...)
284284
end
285285

286286
"""
@@ -305,8 +305,8 @@ See also [`braid`](@ref), [`braid!`](@ref), [`add_permute!`](@ref), [`add_transp
305305
levels1 = TupleTools.getindices(levels, codomainind(tsrc))
306306
levels2 = TupleTools.getindices(levels, domainind(tsrc))
307307
# TODO: arg order for tensormaps is different than for fusiontrees
308-
treebraider(f₁, f₂) = braid(f₁, f₂, levels1, levels2, p...)
309-
return add_transform!(tdst, tsrc, p, treebraider, α, β, backend...)
308+
transformer = treebraider(tdst, tsrc, p, (levels1, levels2))
309+
return add_transform!(tdst, tsrc, p, transformer, α, β, backend...)
310310
end
311311

312312
"""
@@ -324,14 +324,14 @@ See also [`transpose`](@ref), [`transpose!`](@ref), [`add_permute!`](@ref), [`ad
324324
α::Number,
325325
β::Number,
326326
backend::AbstractBackend...) where {T,S,N₁,N₂}
327-
treetransposer(f₁, f₂) = transpose(f₁, f₂, p[1], p[2])
328-
return add_transform!(tdst, tsrc, p, treetransposer, α, β, backend...)
327+
transformer = treetransposer(tdst, tsrc, p)
328+
return add_transform!(tdst, tsrc, p, transformer, α, β, backend...)
329329
end
330330

331331
function add_transform!(tdst::AbstractTensorMap{T,S,N₁,N₂},
332332
tsrc::AbstractTensorMap,
333333
(p₁, p₂)::Index2Tuple{N₁,N₂},
334-
fusiontreetransform,
334+
transformer::GenericTreeTransformer,
335335
α::Number,
336336
β::Number,
337337
backend::AbstractBackend...) where {T,S,N₁,N₂}
@@ -341,16 +341,34 @@ function add_transform!(tdst::AbstractTensorMap{T,S,N₁,N₂},
341341
dest = $(codomain(tdst))$(domain(tdst)), p₁ = $(p₁), p₂ = $(p₂)"))
342342
end
343343

344-
I = sectortype(S)
345-
if p₁ == codomainind(tsrc) && p₂ == domainind(tsrc)
346-
add!(tdst, tsrc, α, β)
347-
elseif I === Trivial
348-
_add_trivial_kernel!(tdst, tsrc, (p₁, p₂), fusiontreetransform, α, β, backend...)
349-
elseif FusionStyle(I) isa UniqueFusion
350-
_add_abelian_kernel!(tdst, tsrc, (p₁, p₂), fusiontreetransform, α, β, backend...)
351-
else
352-
_add_general_kernel!(tdst, tsrc, (p₁, p₂), fusiontreetransform, α, β, backend...)
344+
structure_dst = transformer.structure_dst.fusiontreestructure
345+
structure_src = transformer.structure_src.fusiontreestructure
346+
347+
rows = rowvals(transformer.matrix)
348+
vals = nonzeros(transformer.matrix)
349+
350+
# TODO: this could be multithreaded
351+
for j in axes(transformer.matrix, 2)
352+
sz_dst, str_dst, offset_dst = structure_dst[j]
353+
subblock_dst = StridedView(tdst.data, sz_dst, str_dst, offset_dst)
354+
nzrows = nzrange(transformer.matrix, j)
355+
356+
# treat first entry
357+
sz_src, str_src, offset_src = structure_src[rows[first(nzrows)]]
358+
subblock_src = StridedView(tsrc.data, sz_src, str_src, offset_src)
359+
TO.tensoradd!(subblock_dst, subblock_src, (p₁, p₂), false, α * vals[first(nzrows)],
360+
β,
361+
backend...)
362+
363+
# treat remaining entries
364+
for i in @view(nzrows[2:end])
365+
sz_src, str_src, offset_src = structure_src[rows[i]]
366+
subblock_src = StridedView(tsrc.data, sz_src, str_src, offset_src)
367+
TO.tensoradd!(subblock_dst, subblock_src, (p₁, p₂), false, α * vals[i], One(),
368+
backend...)
369+
end
353370
end
371+
354372
return tdst
355373
end
356374

@@ -389,7 +407,7 @@ function _add_general_kernel!(tdst, tsrc, p, fusiontreetransform, α, β, backen
389407
tdst = scale!(tdst, β)
390408
end
391409
β′ = One()
392-
if Threads.nthreads() > 1
410+
if false # Threads.nthreads() > 1
393411
Threads.@sync for s₁ in sectors(codomain(tsrc)), s₂ in sectors(domain(tsrc))
394412
Threads.@spawn _add_nonabelian_sector!(tdst, tsrc, p, fusiontreetransform, s₁,
395413
s₂, α, β′, backend...)

src/tensors/treetransformers.jl

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
"""
2+
TreeTransformer
3+
4+
Supertype for structures containing the data for a tree transformation.
5+
"""
6+
abstract type TreeTransformer end
7+
8+
function treetransformertype(Vdst, Vsrc)
9+
I = sectortype(Vdst)
10+
N = numind(Vdst)
11+
F1 = fusiontreetype(I, numout(Vdst))
12+
F2 = fusiontreetype(I, numin(Vdst))
13+
F3 = fusiontreetype(I, numout(Vsrc))
14+
F4 = fusiontreetype(I, numin(Vsrc))
15+
return GenericTreeTransformer{sectorscalartype(I),I,N,F1,F2,F3,F4}
16+
end
17+
18+
struct GenericTreeTransformer{T,I,N,F1,F2,F3,F4} <: TreeTransformer
19+
matrix::SparseMatrixCSC{T,Int}
20+
structure_dst::FusionBlockStructure{I,N,F1,F2}
21+
structure_src::FusionBlockStructure{I,N,F3,F4}
22+
end
23+
24+
function GenericTreeTransformer(transform::Function, Vsrc::HomSpace, Vdst::HomSpace)
25+
structure_dst = fusionblockstructure(Vdst)
26+
structure_src = fusionblockstructure(Vsrc)
27+
28+
ldst = length(structure_dst.fusiontreelist)
29+
lsrc = length(structure_src.fusiontreelist)
30+
31+
rows = Int[]
32+
cols = Int[]
33+
vals = sectorscalartype(sectortype(Vdst))[]
34+
35+
for (f1, f2) in structure_src.fusiontreelist
36+
row = structure_src.fusiontreeindices[(f1, f2)]
37+
for ((f3, f4), coeff) in transform(f1, f2)
38+
col = structure_dst.fusiontreeindices[(f3, f4)]
39+
push!(rows, row)
40+
push!(cols, col)
41+
push!(vals, coeff)
42+
end
43+
end
44+
matrix = sparse(rows, cols, vals, ldst, lsrc)
45+
46+
return GenericTreeTransformer(matrix, structure_dst, structure_src)
47+
end
48+
49+
# Transpose
50+
# ---------
51+
const treetransposercache = LRU{Any,Any}(; maxsize=10^5)
52+
const usetreetransposercache = Ref{Bool}(true)
53+
54+
function treetransposer(tdst::AbstractTensorMap, tsrc::AbstractTensorMap, p::Index2Tuple)
55+
if usetreetransposercache[]
56+
key = (space(tdst), space(tsrc), p)
57+
A = treetransformertype(space(tdst), space(tsrc))
58+
return _get_treetransposer(A, key)
59+
else
60+
return _treetransposer((space(tdst), space(tsrc), p))
61+
end
62+
end
63+
@noinline function _get_treetransposer(A, key)
64+
d::A = get!(treetransposercache, key) do
65+
return _treetransposer(key)
66+
end
67+
return d
68+
end
69+
function _treetransposer((Vdst, Vsrc, p))
70+
fusiontreetransform(f1, f2) = transpose(f1, f2, p...)
71+
return GenericTreeTransformer(fusiontreetransform, Vsrc, Vdst)
72+
end
73+
74+
# Braid
75+
# -----
76+
const treebraidercache = LRU{Any,Any}(; maxsize=10^5)
77+
const usetreebraidercache = Ref{Bool}(true)
78+
79+
function treebraider(tdst::AbstractTensorMap, tsrc::AbstractTensorMap, p::Index2Tuple,
80+
l::Index2Tuple)
81+
if usetreebraidercache[]
82+
key = (space(tdst), space(tsrc), p, l)
83+
A = treetransformertype(space(tdst), space(tsrc))
84+
return _get_treebraider(A, key)
85+
else
86+
return _treebraider((space(tdst), space(tsrc), p, l))
87+
end
88+
end
89+
@noinline function _get_treebraider(A, key)
90+
d::A = get!(treebraidercache, key) do
91+
return _treebraider(key)
92+
end
93+
return d
94+
end
95+
function _treebraider((Vdst, Vsrc, p, l))
96+
fusiontreetransform(f1, f2) = braid(f1, f2, p..., l...)
97+
return GenericTreeTransformer(fusiontreetransform, Vsrc, Vdst)
98+
end
99+
100+
# Permute
101+
# -------
102+
const treepermutercache = LRU{Any,Any}(; maxsize=10^5)
103+
const usetreepermutercache = Ref{Bool}(true)
104+
105+
function treepermuter(tdst::AbstractTensorMap, tsrc::AbstractTensorMap, p::Index2Tuple)
106+
if usetreepermutercache[]
107+
key = (space(tdst), space(tsrc), p)
108+
A = treetransformertype(space(tdst), space(tsrc))
109+
return _get_treepermuter(A, key)
110+
else
111+
return _treepermuter((space(tdst), space(tsrc), p))
112+
end
113+
end
114+
@noinline function _get_treepermuter(A, key)
115+
d::A = get!(treepermutercache, key) do
116+
return _treepermuter(key)
117+
end
118+
return d
119+
end
120+
function _treepermuter((Vdst, Vsrc, p))
121+
fusiontreetransform(f1, f2) = permute(f1, f2, p...)
122+
return GenericTreeTransformer(fusiontreetransform, Vsrc, Vdst)
123+
end

0 commit comments

Comments
 (0)