Skip to content

Commit 9ccc644

Browse files
committed
implement "vectorized" fusiontree manipulations
1 parent 0c88157 commit 9ccc644

File tree

2 files changed

+357
-3
lines changed

2 files changed

+357
-3
lines changed

src/fusiontrees/fusiontrees.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -225,11 +225,12 @@ function Base.show(io::IO, t::FusionTree{I}) where {I<:Sector}
225225
end
226226
end
227227

228-
# Manipulate fusion trees
229-
include("manipulations.jl")
230-
231228
# Fusion tree iterators
232229
include("iterator.jl")
230+
include("uncouplediterator.jl")
231+
232+
# Manipulate fusion trees
233+
include("manipulations.jl")
233234

234235
# auxiliary routines
235236
# _abelianinner: generate the inner indices for given outer indices in the abelian case
Lines changed: 353 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,353 @@
1+
struct OuterTreeIterator{I<:Sector,N₁,N₂}
2+
uncoupled::Tuple{NTuple{N₁,I},NTuple{N₂,I}}
3+
isdual::Tuple{NTuple{N₁,Bool},NTuple{N₂,Bool}}
4+
end
5+
6+
sectortype(::Type{<:OuterTreeIterator{I}}) where {I} = I
7+
numout(fs::OuterTreeIterator) = numout(typeof(fs))
8+
numout(::Type{<:OuterTreeIterator{I,N₁}}) where {I,N₁} = N₁
9+
numin(fs::OuterTreeIterator) = numin(typeof(fs))
10+
numin(::Type{<:OuterTreeIterator{I,N₁,N₂}}) where {I,N₁,N₂} = N₂
11+
numind(fs::OuterTreeIterator) = numind(typeof(fs))
12+
numind(::Type{T}) where {T<:OuterTreeIterator} = numin(T) + numout(T)
13+
14+
# TODO: should we make this an actual iterator?
15+
function fusiontrees(iter::OuterTreeIterator{I,N₁,N₂}) where {I,N₁,N₂}
16+
F₁ = fusiontreetype(I, N₁)
17+
F₂ = fusiontreetype(I, N₂)
18+
19+
trees = Vector{Tuple{F₁,F₂}}(undef, 0)
20+
for c in blocksectors(iter), f₁ in fusiontrees(iter.uncoupled[1], c, iter.isdual[1]),
21+
f₂ in fusiontrees(iter.uncoupled[2], c, iter.isdual[2])
22+
23+
push!(trees, (f₁, f₂))
24+
end
25+
return trees
26+
end
27+
28+
# TODO: better implementation
29+
Base.length(iter::OuterTreeIterator) = length(fusiontrees(iter))
30+
31+
function blocksectors(iter::OuterTreeIterator{I,N₁,N₂}) where {I,N₁,N₂}
32+
I == Trivial && return (Trivial(),)
33+
34+
bs_codomain = Vector{I}()
35+
if N₁ == 0
36+
push!(bs_codomain, one(I))
37+
elseif N₁ == 1
38+
push!(bs_codomain, only(iter.uncoupled[1]))
39+
else
40+
for c in (iter.uncoupled[1]...)
41+
if !(c in bs_codomain)
42+
push!(bs_codomain, c)
43+
end
44+
end
45+
end
46+
47+
bs_domain = Vector{I}()
48+
if N₂ == 0
49+
push!(bs_domain, one(I))
50+
elseif N₂ == 1
51+
push!(bs_domain, only(iter.uncoupled[2]))
52+
else
53+
for c in (iter.uncoupled[2]...)
54+
if !(c in bs_domain)
55+
push!(bs_domain, c)
56+
end
57+
end
58+
end
59+
60+
return sort!(collect(intersect(bs_codomain, bs_domain)))
61+
end
62+
63+
# Manipulations
64+
# -------------
65+
66+
function bendright(fs_src::OuterTreeIterator{I,N₁,N₂}) where {I,N₁,N₂}
67+
uncoupled_dst = (TupleTools.front(fs_src.uncoupled[1]),
68+
(fs_src.uncoupled[2]..., dual(fs_src.uncoupled[1][end])))
69+
isdual_dst = (TupleTools.front(fs_src.isdual[1]),
70+
(fs_src.isdual[2]..., !(fs_src.isdual[1][end])))
71+
fs_dst = OuterTreeIterator(uncoupled_dst, isdual_dst)
72+
73+
trees_src = fusiontrees(fs_src)
74+
trees_dst = fusiontrees(fs_dst)
75+
indexmap = Dict(f => ind for (ind, f) in enumerate(trees_dst))
76+
U = zeros(sectorscalartype(I), length(trees_dst), length(trees_src))
77+
78+
for (col, f) in enumerate(trees_src)
79+
for (f′, c) in bendright(f)
80+
row = indexmap[f′]
81+
U[row, col] = c
82+
end
83+
end
84+
85+
return fs_dst, U
86+
end
87+
88+
# TODO: verify if this can be computed through an adjoint
89+
function bendleft(fs_src::OuterTreeIterator{I,N₁,N₂}) where {I,N₁,N₂}
90+
uncoupled_dst = ((fs_src.uncoupled[1]..., dual(fs_src.uncoupled[2][end])),
91+
TupleTools.front(fs_src.uncoupled[2]))
92+
isdual_dst = ((fs_src.isdual[1]..., !(fs_src.isdual[2][end])),
93+
TupleTools.front(fs_src.isdual[2]))
94+
fs_dst = OuterTreeIterator(uncoupled_dst, isdual_dst)
95+
96+
trees_src = fusiontrees(fs_src)
97+
trees_dst = fusiontrees(fs_dst)
98+
indexmap = Dict(f => ind for (ind, f) in enumerate(trees_dst))
99+
U = zeros(sectorscalartype(I), length(trees_dst), length(trees_src))
100+
101+
for (col, f) in enumerate(trees_src)
102+
for (f′, c) in bendleft(f)
103+
row = indexmap[f′]
104+
U[row, col] = c
105+
end
106+
end
107+
108+
return fs_dst, U
109+
end
110+
111+
function foldright(fs_src::OuterTreeIterator{I,N₁,N₂}) where {I,N₁,N₂}
112+
uncoupled_dst = (Base.tail(fs_src.uncoupled[1]),
113+
(dual(first(fs_src.uncoupled[1])), fs_src.uncoupled[2]...))
114+
isdual_dst = (Base.tail(fs_src.isdual[1]),
115+
(!first(fs_src.isdual[1]), fs_src.isdual[2]...))
116+
fs_dst = OuterTreeIterator(uncoupled_dst, isdual_dst)
117+
118+
trees_src = fusiontrees(fs_src)
119+
trees_dst = fusiontrees(fs_dst)
120+
indexmap = Dict(f => ind for (ind, f) in enumerate(trees_dst))
121+
U = zeros(sectorscalartype(I), length(trees_dst), length(trees_src))
122+
123+
for (col, f) in enumerate(trees_src)
124+
for (f′, c) in foldright(f)
125+
row = indexmap[f′]
126+
U[row, col] = c
127+
end
128+
end
129+
130+
return fs_dst, U
131+
end
132+
133+
# TODO: verify if this can be computed through an adjoint
134+
function foldleft(fs_src::OuterTreeIterator{I,N₁,N₂}) where {I,N₁,N₂}
135+
uncoupled_dst = ((dual(first(fs_src.uncoupled[2])), fs_src.uncoupled[1]...),
136+
Base.tail(fs_src.uncoupled[2]))
137+
isdual_dst = ((!first(fs_src.isdual[2]), fs_src.isdual[1]...),
138+
Base.tail(fs_src.isdual[2]))
139+
fs_dst = OuterTreeIterator(uncoupled_dst, isdual_dst)
140+
141+
trees_src = fusiontrees(fs_src)
142+
trees_dst = fusiontrees(fs_dst)
143+
indexmap = Dict(f => ind for (ind, f) in enumerate(trees_dst))
144+
U = zeros(sectorscalartype(I), length(trees_dst), length(trees_src))
145+
146+
for (col, f) in enumerate(trees_src)
147+
for (f′, c) in foldleft(f)
148+
row = indexmap[f′]
149+
U[row, col] = c
150+
end
151+
end
152+
153+
return fs_dst, U
154+
end
155+
156+
function cycleclockwise(fs_src::OuterTreeIterator{I,N₁,N₂}) where {I,N₁,N₂}
157+
if N₁ > 0
158+
fs_tmp, U₁ = foldright(fs_src)
159+
fs_dst, U₂ = bendleft(fs_tmp)
160+
else
161+
fs_tmp, U₁ = bendleft(fs_src)
162+
fs_dst, U₂ = foldright(fs_tmp)
163+
end
164+
return fs_dst, U₂ * U₁
165+
end
166+
167+
function cycleanticlockwise(fs_src::OuterTreeIterator{I,N₁,N₂}) where {I,N₁,N₂}
168+
if N₂ > 0
169+
fs_tmp, U₁ = foldleft(fs_src)
170+
fs_dst, U₂ = bendright(fs_tmp)
171+
else
172+
fs_tmp, U₁ = bendright(fs_src)
173+
fs_dst, U₂ = foldleft(fs_tmp)
174+
end
175+
return fs_dst, U₂ * U₁
176+
end
177+
178+
@inline function repartition(fs_src::OuterTreeIterator{I,N₁,N₂}, N::Int) where {I,N₁,N₂}
179+
@assert 0 <= N <= N₁ + N₂
180+
return _recursive_repartition(fs_src, Val(N))
181+
end
182+
183+
function _repartition_type(I, N, N₁, N₂)
184+
return Tuple{OuterTreeIterator{I,N,N₁ + N₂ - N},Matrix{sectorscalartype(I)}}
185+
end
186+
function _recursive_repartition(fs_src::OuterTreeIterator{I,N₁,N₂},
187+
::Val{N})::_repartition_type(I, N, N₁, N₂) where {I,N₁,N₂,N}
188+
if N == N₁
189+
fs_dst = fs_src
190+
U = zeros(sectorscalartype(I), length(fs_dst), length(fs_src))
191+
copyto!(U, LinearAlgebra.I)
192+
return fs_dst, U
193+
end
194+
195+
N == N₁ - 1 && return bendright(fs_src)
196+
N == N₁ + 1 && return bendleft(fs_src)
197+
198+
fs_tmp, U₁ = N < N₁ ? bendright(fs_src) : bendleft(fs_src)
199+
fs_dst, U₂ = _recursive_repartition(fs_tmp, Val(N))
200+
return fs_dst, U₂ * U₁
201+
end
202+
203+
function Base.transpose(fs_src::OuterTreeIterator{I}, p::Index2Tuple{N₁,N₂}) where {I,N₁,N₂}
204+
N = N₁ + N₂
205+
@assert numind(fs_src) == N
206+
p′ = linearizepermutation(p..., numout(fs_src), numin(fs_src))
207+
@assert iscyclicpermutation(p′)
208+
return _fstranspose((fs_src, p))
209+
end
210+
211+
const _FSTransposeKey{I,N₁,N₂} = Tuple{<:OuterTreeIterator{I},Index2Tuple{N₁,N₂}}
212+
213+
@cached function _fstranspose(key::_FSTransposeKey{I,N₁,N₂})::Tuple{OuterTreeIterator{I,N₁,
214+
N₂},
215+
Matrix{sectorscalartype(I)}} where {I,
216+
N₁,
217+
N₂}
218+
fs_src, (p1, p2) = key
219+
220+
N = N₁ + N₂
221+
p = linearizepermutation(p1, p2, numout(fs_src), numin(fs_src))
222+
223+
fs_dst, U = repartition(fs_src, N₁)
224+
length(p) == 0 && return fs_dst, U
225+
i1 = findfirst(==(1), p)::Int
226+
i1 == 1 && return fs_dst, U
227+
228+
Nhalf = N >> 1
229+
while 1 < i1 Nhalf
230+
fs_dst, U_tmp = cycleanticlockwise(fs_dst)
231+
U = U_tmp * U
232+
i1 -= 1
233+
end
234+
while Nhalf < i1
235+
fs_dst, U_tmp = cycleclockwise(fs_dst)
236+
U = U_tmp * U
237+
i1 = mod1(i1 + 1, N)
238+
end
239+
240+
return fs_dst, U
241+
end
242+
243+
function CacheStyle(::typeof(_fstranspose), k::_FSTransposeKey{I}) where {I}
244+
if FusionStyle(I) == UniqueFusion()
245+
return NoCache()
246+
else
247+
return GlobalLRUCache()
248+
end
249+
end
250+
251+
function artin_braid(fs_src::OuterTreeIterator{I,N,0}, i; inv::Bool=false) where {I,N}
252+
1 <= i < N ||
253+
throw(ArgumentError("Cannot swap outputs i=$i and i+1 out of only $N outputs"))
254+
255+
uncoupled = fs_src.uncoupled[1]
256+
uncoupled′ = TupleTools.setindex(uncoupled, uncoupled[i + 1], i)
257+
uncoupled′ = TupleTools.setindex(uncoupled′, uncoupled[i], i + 1)
258+
259+
isdual = fs_src.isdual[1]
260+
isdual′ = TupleTools.setindex(isdual, isdual[i], i + 1)
261+
isdual′ = TupleTools.setindex(isdual′, isdual[i + 1], i)
262+
263+
fs_dst = OuterTreeIterator((uncoupled′, ()), (isdual′, ()))
264+
265+
trees_src = fusiontrees(fs_src)
266+
trees_dst = fusiontrees(fs_dst)
267+
indexmap = Dict(f => ind for (ind, f) in enumerate(trees_dst))
268+
U = zeros(sectorscalartype(I), length(trees_dst), length(trees_src))
269+
270+
for (col, (f₁, f₂)) in enumerate(trees_src)
271+
for (f₁′, c) in artin_braid(f₁, i; inv)
272+
row = indexmap[(f₁′, f₂)]
273+
U[row, col] = c
274+
end
275+
end
276+
277+
return fs_dst, U
278+
end
279+
280+
function braid(fs_src::OuterTreeIterator{I,N,0}, levels::NTuple{N,Int},
281+
p::NTuple{N,Int}) where {I,N}
282+
TupleTools.isperm(p) || throw(ArgumentError("not a valid permutation: $p"))
283+
284+
if FusionStyle(I) isa UniqueFusion && BraidingStyle(I) isa SymmetricBraiding
285+
uncoupled′ = TupleTools._permute(fs_src.uncoupled[1], p)
286+
isdual′ = TupleTools._permute(fs_src.isdual[1], p)
287+
fs_dst = OuterTreeIterator(uncoupled′, isdual′)
288+
289+
trees_src = fusiontrees(fs_src)
290+
trees_dst = fusiontrees(fs_dst)
291+
indexmap = Dict(f => ind for (ind, f) in enumerate(trees_dst))
292+
U = zeros(sectorscalartype(I), length(trees_dst), length(trees_src))
293+
294+
for (col, (f₁, f₂)) in enumerate(trees_src)
295+
for (f₁′, c) in braid(f₁, levels, p)
296+
row = indexmap[(f₁′, f₂)]
297+
U[row, col] = c
298+
end
299+
end
300+
301+
return fs_dst, U
302+
end
303+
304+
fs_dst, U = repartition(fs_src, N) # TODO: can we avoid this?
305+
for s in permutation2swaps(p)
306+
inv = levels[s] > levels[s + 1]
307+
fs_dst, U_tmp = artin_braid(fs_dst, s; inv)
308+
U = U_tmp * U
309+
end
310+
return fs_dst, U
311+
end
312+
313+
function braid(fs_src::OuterTreeIterator{I}, levels::Index2Tuple,
314+
p::Index2Tuple{N₁,N₂}) where {I,N₁,N₂}
315+
@assert numind(fs_src) == N₁ + N₂
316+
@assert numout(fs_src) == length(levels[1]) && numin(fs_src) == length(levels[2])
317+
@assert TupleTools.isperm((p[1]..., p[2]...))
318+
return _fsbraid((fs_src, levels, p))
319+
end
320+
321+
const _FSBraidKey{I,N₁,N₂} = Tuple{<:OuterTreeIterator{I},Index2Tuple,Index2Tuple{N₁,N₂}}
322+
323+
@cached function _fsbraid(key::_FSBraidKey{I,N₁,N₂})::Tuple{OuterTreeIterator{I,N₁,N₂},
324+
Matrix{sectorscalartype(I)}} where {I,
325+
N₁,
326+
N₂}
327+
fs_src, (l1, l2), (p1, p2) = key
328+
329+
p = linearizepermutation(p1, p2, numout(fs_src), numin(fs_src))
330+
levels = (l1..., reverse(l2)...)
331+
332+
fs_dst, U = repartition(fs_src, numind(fs_src))
333+
fs_dst, U_tmp = braid(fs_dst, levels, p)
334+
U = U_tmp * U
335+
fs_dst, U_tmp = repartition(fs_dst, N₁)
336+
U = U_tmp * U
337+
return fs_dst, U
338+
end
339+
340+
function CacheStyle(::typeof(_fsbraid), k::_FSBraidKey{I}) where {I}
341+
if FusionStyle(I) isa UniqueFusion
342+
return NoCache()
343+
else
344+
return GlobalLRUCache()
345+
end
346+
end
347+
348+
function permute(fs_src::OuterTreeIterator{I}, p::Index2Tuple) where {I}
349+
@assert BraidingStyle(I) isa SymmetricBraiding
350+
levels1 = ntuple(identity, numout(fs_src))
351+
levels2 = numout(fs_src) .+ ntuple(identity, numin(fs_src))
352+
return braid(fs_src, (levels1, levels2), p)
353+
end

0 commit comments

Comments
 (0)