Skip to content

Commit 44c848f

Browse files
committed
refactor in terms of FusionTreeBlock
1 parent 4ebd7f5 commit 44c848f

File tree

4 files changed

+309
-354
lines changed

4 files changed

+309
-354
lines changed
Lines changed: 307 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,307 @@
1+
struct FusionTreeBlock{I,N₁,N₂,F<:FusionTreePair{I,N₁,N₂}}
2+
trees::Vector{F}
3+
end
4+
5+
function FusionTreeBlock(uncoupled::Tuple{NTuple{N₁,I},NTuple{N₂,I}},
6+
isdual::Tuple{NTuple{N₁,Bool},NTuple{N₂,Bool}}) where {I<:Sector,N₁,N₂}
7+
F₁ = fusiontreetype(I, N₁)
8+
F₂ = fusiontreetype(I, N₂)
9+
trees = Vector{Tuple{F₁,F₂}}(undef, 0)
10+
11+
cleft = N₁ == 0 ? (one(I),) : (uncoupled[1]...)
12+
cright = N₂ == 0 ? (one(I),) : (uncoupled[2]...)
13+
cs = sort!(collect(intersect(cleft, cright)))
14+
for c in cs
15+
for f₁ in fusiontrees(uncoupled[1], c, isdual[1]),
16+
f₂ in fusiontrees(uncoupled[2], c, isdual[2])
17+
18+
push!(trees, (f₁, f₂))
19+
end
20+
end
21+
return FusionTreeBlock(trees)
22+
end
23+
24+
Base.@constprop :aggressive function Base.getproperty(block::FusionTreeBlock, prop::Symbol)
25+
if prop === :uncoupled
26+
f₁, f₂ = first(block.trees)
27+
return f₁.uncoupled, f₂.uncoupled
28+
elseif prop === :isdual
29+
f₁, f₂ = first(block.trees)
30+
return f₁.isdual, f₂.isdual
31+
else
32+
return getfield(block, prop)
33+
end
34+
end
35+
36+
Base.propertynames(::FusionTreeBlock, private::Bool=false) = (:trees, :uncoupled, :isdual)
37+
38+
sectortype(::Type{<:FusionTreeBlock{I}}) where {I} = I
39+
numout(fs::FusionTreeBlock) = numout(typeof(fs))
40+
numout(::Type{<:FusionTreeBlock{I,N₁}}) where {I,N₁} = N₁
41+
numin(fs::FusionTreeBlock) = numin(typeof(fs))
42+
numin(::Type{<:FusionTreeBlock{I,N₁,N₂}}) where {I,N₁,N₂} = N₂
43+
numind(fs::FusionTreeBlock) = numind(typeof(fs))
44+
numind(::Type{T}) where {T<:FusionTreeBlock} = numin(T) + numout(T)
45+
46+
fusiontrees(block::FusionTreeBlock) = block.trees
47+
Base.length(block::FusionTreeBlock) = length(fusiontrees(block))
48+
49+
# Manipulations
50+
# -------------
51+
function transformation_matrix(transform, dst::FusionTreeBlock{I},
52+
src::FusionTreeBlock{I}) where {I}
53+
U = zeros(sectorscalartype(I), length(dst), length(src))
54+
indexmap = Dict(f => ind for (ind, f) in enumerate(fusiontrees(dst)))
55+
for (col, f) in enumerate(fusiontrees(src))
56+
for (f′, c) in transform(f)
57+
row = indexmap[f′]
58+
U[row, col] = c
59+
end
60+
end
61+
return U
62+
end
63+
64+
function bendright(src::FusionTreeBlock)
65+
uncoupled_dst = (TupleTools.front(src.uncoupled[1]),
66+
(src.uncoupled[2]..., dual(src.uncoupled[1][end])))
67+
isdual_dst = (TupleTools.front(src.isdual[1]),
68+
(src.isdual[2]..., !(src.isdual[1][end])))
69+
dst = FusionTreeBlock(uncoupled_dst, isdual_dst)
70+
71+
U = transformation_matrix(bendright, dst, src)
72+
return dst, U
73+
end
74+
75+
# TODO: verify if this can be computed through an adjoint
76+
function bendleft(src::FusionTreeBlock)
77+
uncoupled_dst = ((src.uncoupled[1]..., dual(src.uncoupled[2][end])),
78+
TupleTools.front(src.uncoupled[2]))
79+
isdual_dst = ((src.isdual[1]..., !(src.isdual[2][end])),
80+
TupleTools.front(src.isdual[2]))
81+
dst = FusionTreeBlock(uncoupled_dst, isdual_dst)
82+
83+
U = transformation_matrix(bendleft, dst, src)
84+
return dst, U
85+
end
86+
87+
function foldright(src::FusionTreeBlock)
88+
uncoupled_dst = (Base.tail(src.uncoupled[1]),
89+
(dual(first(src.uncoupled[1])), src.uncoupled[2]...))
90+
isdual_dst = (Base.tail(src.isdual[1]),
91+
(!first(src.isdual[1]), src.isdual[2]...))
92+
dst = FusionTreeBlock(uncoupled_dst, isdual_dst)
93+
94+
U = transformation_matrix(foldright, dst, src)
95+
return dst, U
96+
end
97+
98+
# TODO: verify if this can be computed through an adjoint
99+
function foldleft(src::FusionTreeBlock)
100+
uncoupled_dst = ((dual(first(src.uncoupled[2])), src.uncoupled[1]...),
101+
Base.tail(src.uncoupled[2]))
102+
isdual_dst = ((!first(src.isdual[2]), src.isdual[1]...),
103+
Base.tail(src.isdual[2]))
104+
dst = FusionTreeBlock(uncoupled_dst, isdual_dst)
105+
106+
U = transformation_matrix(foldleft, dst, src)
107+
return dst, U
108+
end
109+
110+
function cycleclockwise(src::FusionTreeBlock)
111+
if numout(src) > 0
112+
tmp, U₁ = foldright(src)
113+
dst, U₂ = bendleft(tmp)
114+
else
115+
tmp, U₁ = bendleft(src)
116+
dst, U₂ = foldright(tmp)
117+
end
118+
return dst, U₂ * U₁
119+
end
120+
121+
function cycleanticlockwise(src::FusionTreeBlock)
122+
if numin(src) > 0
123+
tmp, U₁ = foldleft(src)
124+
dst, U₂ = bendright(tmp)
125+
else
126+
tmp, U₁ = bendright(src)
127+
dst, U₂ = foldleft(tmp)
128+
end
129+
return dst, U₂ * U₁
130+
end
131+
132+
@inline function repartition(src::FusionTreeBlock, N::Int)
133+
@assert 0 <= N <= numind(src)
134+
return _recursive_repartition(src, Val(N))
135+
end
136+
137+
function _repartition_type(I, N, N₁, N₂)
138+
return Tuple{FusionTreeBlock{I,N,N₁ + N₂ - N},Matrix{sectorscalartype(I)}}
139+
end
140+
function _recursive_repartition(src::FusionTreeBlock{I,N₁,N₂},
141+
::Val{N})::_repartition_type(I, N, N₁, N₂) where {I,N₁,N₂,N}
142+
if N == N₁
143+
dst = src
144+
U = zeros(sectorscalartype(I), length(dst), length(src))
145+
copyto!(U, LinearAlgebra.I)
146+
return dst, U
147+
end
148+
149+
N == N₁ - 1 && return bendright(src)
150+
N == N₁ + 1 && return bendleft(src)
151+
152+
tmp, U₁ = N < N₁ ? bendright(src) : bendleft(src)
153+
dst, U₂ = _recursive_repartition(tmp, Val(N))
154+
return dst, U₂ * U₁
155+
end
156+
157+
function Base.transpose(src::FusionTreeBlock, p::Index2Tuple{N₁,N₂}) where {N₁,N₂}
158+
N = N₁ + N₂
159+
@assert numind(src) == N
160+
p′ = linearizepermutation(p..., numout(src), numin(src))
161+
@assert iscyclicpermutation(p′)
162+
return _fstranspose((src, p))
163+
end
164+
165+
const _FSTransposeKey{I,N₁,N₂} = Tuple{<:FusionTreeBlock{I},Index2Tuple{N₁,N₂}}
166+
167+
@cached function _fstranspose(key::_FSTransposeKey{I,N₁,N₂})::Tuple{FusionTreeBlock{I,N₁,
168+
N₂},
169+
Matrix{sectorscalartype(I)}} where {I,
170+
N₁,
171+
N₂}
172+
src, (p1, p2) = key
173+
174+
N = N₁ + N₂
175+
p = linearizepermutation(p1, p2, numout(src), numin(src))
176+
177+
dst, U = repartition(src, N₁)
178+
length(p) == 0 && return dst, U
179+
i1 = findfirst(==(1), p)::Int
180+
i1 == 1 && return dst, U
181+
182+
Nhalf = N >> 1
183+
while 1 < i1 Nhalf
184+
dst, U_tmp = cycleanticlockwise(dst)
185+
U = U_tmp * U
186+
i1 -= 1
187+
end
188+
while Nhalf < i1
189+
dst, U_tmp = cycleclockwise(dst)
190+
U = U_tmp * U
191+
i1 = mod1(i1 + 1, N)
192+
end
193+
194+
return dst, U
195+
end
196+
197+
function CacheStyle(::typeof(_fstranspose), k::_FSTransposeKey{I}) where {I}
198+
if FusionStyle(I) == UniqueFusion()
199+
return NoCache()
200+
else
201+
return GlobalLRUCache()
202+
end
203+
end
204+
205+
function artin_braid(src::FusionTreeBlock{I,N,0}, i; inv::Bool=false) where {I,N}
206+
1 <= i < N ||
207+
throw(ArgumentError("Cannot swap outputs i=$i and i+1 out of only $N outputs"))
208+
209+
uncoupled = src.uncoupled[1]
210+
uncoupled′ = TupleTools.setindex(uncoupled, uncoupled[i + 1], i)
211+
uncoupled′ = TupleTools.setindex(uncoupled′, uncoupled[i], i + 1)
212+
isdual = src.isdual[1]
213+
isdual′ = TupleTools.setindex(isdual, isdual[i], i + 1)
214+
isdual′ = TupleTools.setindex(isdual′, isdual[i + 1], i)
215+
dst = FusionTreeBlock((uncoupled′, ()), (isdual′, ()))
216+
217+
# TODO: do we want to rewrite `artin_braid` to take double trees instead?
218+
U = transformation_matrix(dst, src) do (f₁, f₂)
219+
return ((f₁′, f₂) => c for (f₁′, c) in artin_braid(f₁, i; inv))
220+
end
221+
return dst, U
222+
end
223+
224+
function braid(src::FusionTreeBlock{I,N,0}, p::NTuple{N,Int},
225+
levels::NTuple{N,Int}) where {I,N}
226+
TupleTools.isperm(p) || throw(ArgumentError("not a valid permutation: $p"))
227+
228+
if FusionStyle(I) isa UniqueFusion && BraidingStyle(I) isa SymmetricBraiding
229+
uncoupled′ = TupleTools._permute(src.uncoupled[1], p)
230+
isdual′ = TupleTools._permute(src.isdual[1], p)
231+
dst = FusionTreeBlock(uncoupled′, isdual′)
232+
U = transformation_matrix(dst, src) do (f₁, f₂)
233+
return ((f₁′, f₂) => c for (f₁, c) in braid(f₁, p, levels))
234+
end
235+
else
236+
dst, U = repartition(src, N) # TODO: can we avoid this?
237+
for s in permutation2swaps(p)
238+
inv = levels[s] > levels[s + 1]
239+
dst, U_tmp = artin_braid(dst, s; inv)
240+
U = U_tmp * U
241+
end
242+
end
243+
return dst, U
244+
end
245+
246+
function braid(src::FusionTreeBlock{I}, p::Index2Tuple{N₁,N₂},
247+
levels::Index2Tuple) where {I,N₁,N₂}
248+
@assert numind(src) == N₁ + N₂
249+
@assert numout(src) == length(levels[1]) && numin(src) == length(levels[2])
250+
@assert TupleTools.isperm((p[1]..., p[2]...))
251+
return _fsbraid((src, p, levels))
252+
end
253+
254+
const _FSBraidKey{I,N₁,N₂} = Tuple{<:FusionTreeBlock{I},Index2Tuple{N₁,N₂},Index2Tuple}
255+
256+
@cached function _fsbraid(key::_FSBraidKey{I,N₁,N₂})::Tuple{FusionTreeBlock{I,N₁,N₂},
257+
Matrix{sectorscalartype(I)}} where {I,
258+
N₁,
259+
N₂}
260+
src, (p1, p2), (l1, l2) = key
261+
262+
p = linearizepermutation(p1, p2, numout(src), numin(src))
263+
levels = (l1..., reverse(l2)...)
264+
265+
dst, U = repartition(src, numind(src))
266+
267+
if FusionStyle(I) isa UniqueFusion && BraidingStyle(I) isa SymmetricBraiding
268+
uncoupled′ = TupleTools._permute(dst.uncoupled[1], p)
269+
isdual′ = TupleTools._permute(dst.isdual[1], p)
270+
271+
dst′ = FusionTreeBlock(uncoupled′, isdual′)
272+
U_tmp = transformation_matrix(dst′, dst) do (f₁, f₂)
273+
return ((f₁′, f₂) => c for (f₁, c) in braid(f₁, p, levels))
274+
end
275+
dst = dst′
276+
U = U_tmp * U
277+
else
278+
for s in permutation2swaps(p)
279+
inv = levels[s] > levels[s + 1]
280+
dst, U_tmp = artin_braid(dst, s; inv)
281+
U = U_tmp * U
282+
end
283+
end
284+
285+
if N₂ == 0
286+
return dst, U
287+
else
288+
dst, U_tmp = repartition(dst, N₁)
289+
U = U_tmp * U
290+
return dst, U
291+
end
292+
end
293+
294+
function CacheStyle(::typeof(_fsbraid), k::_FSBraidKey{I}) where {I}
295+
if FusionStyle(I) isa UniqueFusion
296+
return NoCache()
297+
else
298+
return GlobalLRUCache()
299+
end
300+
end
301+
302+
function permute(src::FusionTreeBlock{I}, p::Index2Tuple) where {I}
303+
@assert BraidingStyle(I) isa SymmetricBraiding
304+
levels1 = ntuple(identity, numout(src))
305+
levels2 = numout(src) .+ ntuple(identity, numin(src))
306+
return braid(src, p, (levels1, levels2))
307+
end

src/fusiontrees/fusiontrees.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ end
227227

228228
# Fusion tree iterators
229229
include("iterator.jl")
230-
include("uncouplediterator.jl")
230+
include("fusiontreeblocks.jl")
231231

232232
# Manipulate fusion trees
233233
include("manipulations.jl")

0 commit comments

Comments
 (0)