Skip to content

Commit 09a4075

Browse files
committed
refactor in terms of FusionTreeBlock
1 parent 7b257f7 commit 09a4075

File tree

4 files changed

+307
-354
lines changed

4 files changed

+307
-354
lines changed
Lines changed: 305 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,305 @@
1+
struct FusionTreeBlock{I,N₁,N₂,F<:FusionTreePair{I,N₁,N₂}}
2+
trees::Vector{F}
3+
end
4+
5+
function FusionTreeBlock(uncoupled::Tuple{NTuple{N₁,I},NTuple{N₂,I}},
6+
isdual::Tuple{NTuple{N₁,I},NTuple{N₂,I}}) where {I<:Sector,N₁,N₂}
7+
F₁ = fusiontreetype(I, N₁)
8+
F₂ = fusiontreetype(I, N₂)
9+
trees = Vector{Tuple{F₁,F₂}}(undef, 0)
10+
11+
cs = sort!(intersect((uncoupled[1]...), (uncoupled[2]...)))
12+
for c in cs
13+
for f₁ in fusiontrees(uncoupled[1], c, isdual[1]),
14+
f₂ in fusiontrees(uncoupled[2], c, isdual[2])
15+
16+
push!(trees, (f₁, f₂))
17+
end
18+
end
19+
return FusionTreeBlock(trees)
20+
end
21+
22+
Base.@constprop :aggressive function Base.getproperty(block::FusionTreeBlock, prop::Symbol)
23+
if prop === :uncoupled
24+
f₁, f₂ = first(block.trees)
25+
return f₁.uncoupled, f₂.uncoupled
26+
elseif prop === :isdual
27+
f₁, f₂ = first(block.trees)
28+
return f₁.isdual, f₂.isdual
29+
else
30+
return getfield(block, prop)
31+
end
32+
end
33+
34+
Base.propertynames(::FusionTreeBlock, private::Bool=false) = (:trees, :uncoupled, :isdual)
35+
36+
sectortype(::Type{<:FusionTreeBlock{I}}) where {I} = I
37+
numout(fs::FusionTreeBlock) = numout(typeof(fs))
38+
numout(::Type{<:FusionTreeBlock{I,N₁}}) where {I,N₁} = N₁
39+
numin(fs::FusionTreeBlock) = numin(typeof(fs))
40+
numin(::Type{<:FusionTreeBlock{I,N₁,N₂}}) where {I,N₁,N₂} = N₂
41+
numind(fs::FusionTreeBlock) = numind(typeof(fs))
42+
numind(::Type{T}) where {T<:FusionTreeBlock} = numin(T) + numout(T)
43+
44+
fusiontrees(block::FusionTreeBlock) = block.trees
45+
Base.length(block::FusionTreeBlock) = length(fusiontrees(block))
46+
47+
# Manipulations
48+
# -------------
49+
function transformation_matrix(f, dst::FusionTreeBlock{I},
50+
src::FusionTreeBlock{I}) where {I}
51+
U = zeros(sectorscalartype(I), length(dst), length(src))
52+
indexmap = Dict(f => ind for (ind, f) in enumerate(trees_dst))
53+
for (col, f) in enumerate(fusiontrees(src))
54+
for (f′, c) in transform(f)
55+
row = indexmap[f′]
56+
U[row, col] = c
57+
end
58+
end
59+
return U
60+
end
61+
62+
function bendright(src::FusionTreeBlock)
63+
uncoupled_dst = (TupleTools.front(src.uncoupled[1]),
64+
(src.uncoupled[2]..., dual(src.uncoupled[1][end])))
65+
isdual_dst = (TupleTools.front(src.isdual[1]),
66+
(src.isdual[2]..., !(src.isdual[1][end])))
67+
dst = FusionTreeBlock(uncoupled_dst, isdual_dst)
68+
69+
U = transformation_matrix(bendright, dst, src)
70+
return dst, U
71+
end
72+
73+
# TODO: verify if this can be computed through an adjoint
74+
function bendleft(src::FusionTreeBlock)
75+
uncoupled_dst = ((src.uncoupled[1]..., dual(src.uncoupled[2][end])),
76+
TupleTools.front(src.uncoupled[2]))
77+
isdual_dst = ((src.isdual[1]..., !(src.isdual[2][end])),
78+
TupleTools.front(src.isdual[2]))
79+
dst = FusionTreeBlock(uncoupled_dst, isdual_dst)
80+
81+
U = transformation_matrix(bendleft, dst, src)
82+
return dst, U
83+
end
84+
85+
function foldright(src::FusionTreeBlock)
86+
uncoupled_dst = (Base.tail(src.uncoupled[1]),
87+
(dual(first(src.uncoupled[1])), src.uncoupled[2]...))
88+
isdual_dst = (Base.tail(src.isdual[1]),
89+
(!first(src.isdual[1]), src.isdual[2]...))
90+
dst = FusionTreeBlock(uncoupled_dst, isdual_dst)
91+
92+
U = transformation_matrix(foldright, dst, src)
93+
return dst, U
94+
end
95+
96+
# TODO: verify if this can be computed through an adjoint
97+
function foldleft(src::FusionTreeBlock)
98+
uncoupled_dst = ((dual(first(src.uncoupled[2])), src.uncoupled[1]...),
99+
Base.tail(src.uncoupled[2]))
100+
isdual_dst = ((!first(src.isdual[2]), src.isdual[1]...),
101+
Base.tail(src.isdual[2]))
102+
dst = FusionTreeBlock(uncoupled_dst, isdual_dst)
103+
104+
U = transformation_matrix(foldright, dst, src)
105+
return dst, U
106+
end
107+
108+
function cycleclockwise(src::FusionTreeBlock)
109+
if N₁ > 0
110+
tmp, U₁ = foldright(src)
111+
dst, U₂ = bendleft(tmp)
112+
else
113+
tmp, U₁ = bendleft(src)
114+
dst, U₂ = foldright(tmp)
115+
end
116+
return dst, U₂ * U₁
117+
end
118+
119+
function cycleanticlockwise(src::FusionTreeBlock)
120+
if N₂ > 0
121+
tmp, U₁ = foldleft(src)
122+
dst, U₂ = bendright(tmp)
123+
else
124+
tmp, U₁ = bendright(src)
125+
dst, U₂ = foldleft(tmp)
126+
end
127+
return dst, U₂ * U₁
128+
end
129+
130+
@inline function repartition(src::FusionTreeBlock{I,N₁,N₂}, N::Int) where {I,N₁,N₂}
131+
@assert 0 <= N <= N₁ + N₂
132+
return _recursive_repartition(src, Val(N))
133+
end
134+
135+
function _repartition_type(I, N, N₁, N₂)
136+
return Tuple{FusionTreeBlock{I,N,N₁ + N₂ - N},Matrix{sectorscalartype(I)}}
137+
end
138+
function _recursive_repartition(src::FusionTreeBlock{I,N₁,N₂},
139+
::Val{N})::_repartition_type(I, N, N₁, N₂) where {I,N₁,N₂,N}
140+
if N == N₁
141+
dst = src
142+
U = zeros(sectorscalartype(I), length(dst), length(src))
143+
copyto!(U, LinearAlgebra.I)
144+
return dst, U
145+
end
146+
147+
N == N₁ - 1 && return bendright(src)
148+
N == N₁ + 1 && return bendleft(src)
149+
150+
tmp, U₁ = N < N₁ ? bendright(src) : bendleft(src)
151+
dst, U₂ = _recursive_repartition(tmp, Val(N))
152+
return dst, U₂ * U₁
153+
end
154+
155+
function Base.transpose(src::FusionTreeBlock{I}, p::Index2Tuple{N₁,N₂}) where {I,N₁,N₂}
156+
N = N₁ + N₂
157+
@assert numind(src) == N
158+
p′ = linearizepermutation(p..., numout(src), numin(src))
159+
@assert iscyclicpermutation(p′)
160+
return _fstranspose((src, p))
161+
end
162+
163+
const _FSTransposeKey{I,N₁,N₂} = Tuple{<:FusionTreeBlock{I},Index2Tuple{N₁,N₂}}
164+
165+
@cached function _fstranspose(key::_FSTransposeKey{I,N₁,N₂})::Tuple{FusionTreeBlock{I,N₁,
166+
N₂},
167+
Matrix{sectorscalartype(I)}} where {I,
168+
N₁,
169+
N₂}
170+
src, (p1, p2) = key
171+
172+
N = N₁ + N₂
173+
p = linearizepermutation(p1, p2, numout(src), numin(src))
174+
175+
dst, U = repartition(src, N₁)
176+
length(p) == 0 && return dst, U
177+
i1 = findfirst(==(1), p)::Int
178+
i1 == 1 && return dst, U
179+
180+
Nhalf = N >> 1
181+
while 1 < i1 Nhalf
182+
dst, U_tmp = cycleanticlockwise(dst)
183+
U = U_tmp * U
184+
i1 -= 1
185+
end
186+
while Nhalf < i1
187+
dst, U_tmp = cycleclockwise(dst)
188+
U = U_tmp * U
189+
i1 = mod1(i1 + 1, N)
190+
end
191+
192+
return dst, U
193+
end
194+
195+
function CacheStyle(::typeof(_fstranspose), k::_FSTransposeKey{I}) where {I}
196+
if FusionStyle(I) == UniqueFusion()
197+
return NoCache()
198+
else
199+
return GlobalLRUCache()
200+
end
201+
end
202+
203+
function artin_braid(src::FusionTreeBlock{I,N,0}, i; inv::Bool=false) where {I,N}
204+
1 <= i < N ||
205+
throw(ArgumentError("Cannot swap outputs i=$i and i+1 out of only $N outputs"))
206+
207+
uncoupled = src.uncoupled[1]
208+
uncoupled′ = TupleTools.setindex(uncoupled, uncoupled[i + 1], i)
209+
uncoupled′ = TupleTools.setindex(uncoupled′, uncoupled[i], i + 1)
210+
isdual = src.isdual[1]
211+
isdual′ = TupleTools.setindex(isdual, isdual[i], i + 1)
212+
isdual′ = TupleTools.setindex(isdual′, isdual[i + 1], i)
213+
dst = FusionTreeBlock((uncoupled′, ()), (isdual′, ()))
214+
215+
# TODO: do we want to rewrite `artin_braid` to take double trees instead?
216+
U = transformation_matrix(dst, src) do (f₁, f₂)
217+
return ((f₁′, f₂) => c for (f₁, c) in artin_braid(f₁, i; inv))
218+
end
219+
return dst, U
220+
end
221+
222+
function braid(src::FusionTreeBlock{I,N,0}, p::NTuple{N,Int},
223+
levels::NTuple{N,Int}) where {I,N}
224+
TupleTools.isperm(p) || throw(ArgumentError("not a valid permutation: $p"))
225+
226+
if FusionStyle(I) isa UniqueFusion && BraidingStyle(I) isa SymmetricBraiding
227+
uncoupled′ = TupleTools._permute(src.uncoupled[1], p)
228+
isdual′ = TupleTools._permute(src.isdual[1], p)
229+
dst = FusionTreeBlock(uncoupled′, isdual′)
230+
U = transformation_matrix(dst, src) do (f₁, f₂)
231+
return ((f₁′, f₂) => c for (f₁, c) in braid(f₁, p, levels))
232+
end
233+
else
234+
dst, U = repartition(src, N) # TODO: can we avoid this?
235+
for s in permutation2swaps(p)
236+
inv = levels[s] > levels[s + 1]
237+
dst, U_tmp = artin_braid(dst, s; inv)
238+
U = U_tmp * U
239+
end
240+
end
241+
return dst, U
242+
end
243+
244+
function braid(src::FusionTreeBlock{I}, p::Index2Tuple{N₁,N₂},
245+
levels::Index2Tuple) where {I,N₁,N₂}
246+
@assert numind(src) == N₁ + N₂
247+
@assert numout(src) == length(levels[1]) && numin(src) == length(levels[2])
248+
@assert TupleTools.isperm((p[1]..., p[2]...))
249+
return _fsbraid((src, p, levels))
250+
end
251+
252+
const _FSBraidKey{I,N₁,N₂} = Tuple{<:FusionTreeBlock{I},Index2Tuple{N₁,N₂},Index2Tuple}
253+
254+
@cached function _fsbraid(key::_FSBraidKey{I,N₁,N₂})::Tuple{FusionTreeBlock{I,N₁,N₂},
255+
Matrix{sectorscalartype(I)}} where {I,
256+
N₁,
257+
N₂}
258+
src, (p1, p2), (l1, l2) = key
259+
260+
p = linearizepermutation(p1, p2, numout(src), numin(src))
261+
levels = (l1..., reverse(l2)...)
262+
263+
dst, U = repartition(src, numind(src))
264+
265+
if FusionStyle(I) isa UniqueFusion && BraidingStyle(I) isa SymmetricBraiding
266+
uncoupled′ = TupleTools._permute(dst.uncoupled[1], p)
267+
isdual′ = TupleTools._permute(dst.isdual[1], p)
268+
269+
dst′ = FusionTreeBlock(uncoupled′, isdual′)
270+
U_tmp = transformation_matrix(dst′, dst) do (f₁, f₂)
271+
return ((f₁′, f₂) => c for (f₁, c) in braid(f₁, p, levels))
272+
end
273+
dst = dst′
274+
U = U_tmp * U
275+
else
276+
for s in permutation2swaps(p)
277+
inv = levels[s] > levels[s + 1]
278+
dst, U_tmp = artin_braid(dst, s; inv)
279+
U = U_tmp * U
280+
end
281+
end
282+
283+
if N₁ == 0
284+
return dst, U
285+
else
286+
dst, U_tmp = repartition(dst, N₁)
287+
U = U_tmp * U
288+
return dst, U
289+
end
290+
end
291+
292+
function CacheStyle(::typeof(_fsbraid), k::_FSBraidKey{I}) where {I}
293+
if FusionStyle(I) isa UniqueFusion
294+
return NoCache()
295+
else
296+
return GlobalLRUCache()
297+
end
298+
end
299+
300+
function permute(src::FusionTreeBlock{I}, p::Index2Tuple) where {I}
301+
@assert BraidingStyle(I) isa SymmetricBraiding
302+
levels1 = ntuple(identity, numout(src))
303+
levels2 = numout(src) .+ ntuple(identity, numin(src))
304+
return braid(src, p, (levels1, levels2))
305+
end

src/fusiontrees/fusiontrees.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ end
227227

228228
# Fusion tree iterators
229229
include("iterator.jl")
230-
include("uncouplediterator.jl")
230+
include("fusiontreeblocks.jl")
231231

232232
# Manipulate fusion trees
233233
include("manipulations.jl")

0 commit comments

Comments
 (0)