Skip to content

Commit 8c129a4

Browse files
committed
Refactor foldleft to avoid extra dictionaries
1 parent f8eb207 commit 8c129a4

File tree

1 file changed

+58
-3
lines changed

1 file changed

+58
-3
lines changed

src/fusiontrees/fusiontreeblocks.jl

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -244,15 +244,70 @@ function foldright(src::FusionTreeBlock)
244244
return dst, U
245245
end
246246

247-
# TODO: verify if this can be computed through an adjoint
247+
# !! note that this is more or less a copy of foldright through
248+
# (f1, f2) => conj(coeff) for ((f2, f1), coeff) in foldright(src)
248249
function foldleft(src::FusionTreeBlock)
249250
uncoupled_dst = ((dual(first(src.uncoupled[2])), src.uncoupled[1]...),
250251
Base.tail(src.uncoupled[2]))
251252
isdual_dst = ((!first(src.isdual[2]), src.isdual[1]...),
252253
Base.tail(src.isdual[2]))
253-
dst = FusionTreeBlock{sectortype(src)}(uncoupled_dst, isdual_dst)
254+
I = sectortype(src)
255+
N₁ = numin(src)
256+
N₂ = numout(src)
257+
@assert N₁ > 0
258+
259+
dst = FusionTreeBlock{I}(uncoupled_dst, isdual_dst)
260+
indexmap = fusiontreedict(I)(f => ind for (ind, f) in enumerate(fusiontrees(dst)))
261+
U = zeros(sectorscalartype(I), length(dst), length(src))
254262

255-
U = transformation_matrix(foldleft, dst, src)
263+
for (col, (f₂, f₁)) in enumerate(fusiontrees(src))
264+
# map first splitting vertex (a, b)<-c to fusion vertex b<-(dual(a), c)
265+
a = f₁.uncoupled[1]
266+
isduala = f₁.isdual[1]
267+
factor = sqrtdim(a)
268+
if !isduala
269+
factor *= conj(frobeniusschur(a))
270+
end
271+
c1 = dual(a)
272+
c2 = f₁.coupled
273+
uncoupled = Base.tail(f₁.uncoupled)
274+
isdual = Base.tail(f₁.isdual)
275+
if FusionStyle(I) isa UniqueFusion
276+
c = first(c1 c2)
277+
fl = FusionTree{I}(Base.tail(f₁.uncoupled), c, Base.tail(f₁.isdual))
278+
fr = FusionTree{I}((c1, f₂.uncoupled...), c, (!isduala, f₂.isdual...))
279+
row = indexmap[(fr, fl)]
280+
@inbounds U[row, col] = conj(factor)
281+
else
282+
if N₁ == 1
283+
cset = (leftone(c1),) # or rightone(a)
284+
elseif N₁ == 2
285+
cset = (f₁.uncoupled[2],)
286+
else
287+
cset = (Base.tail(f₁.uncoupled)...)
288+
end
289+
for c in c1 c2
290+
c cset || continue
291+
for μ in 1:Nsymbol(c1, c2, c)
292+
fc = FusionTree((c1, c2), c, (!isduala, false), (), (μ,))
293+
for (fl′, coeff1) in insertat(fc, 2, f₁)
294+
N₁ > 1 && !isone(fl′.innerlines[1]) && continue
295+
coupled = fl′.coupled
296+
uncoupled = Base.tail(Base.tail(fl′.uncoupled))
297+
isdual = Base.tail(Base.tail(fl′.isdual))
298+
inner = N₁ <= 3 ? () : Base.tail(Base.tail(fl′.innerlines))
299+
vertices = N₁ <= 2 ? () : Base.tail(Base.tail(fl′.vertices))
300+
fl = FusionTree{I}(uncoupled, coupled, isdual, inner, vertices)
301+
for (fr, coeff2) in insertat(fc, 2, f₂)
302+
coeff = factor * coeff1 * conj(coeff2)
303+
row = indexmap[(fr, fl)]
304+
@inbounds U[row, col] = conj(coeff)
305+
end
306+
end
307+
end
308+
end
309+
end
310+
end
256311
return dst, U
257312
end
258313

0 commit comments

Comments
 (0)