|
| 1 | +struct OuterTreeIterator{I<:Sector,N₁,N₂} |
| 2 | + uncoupled::Tuple{NTuple{N₁,I},NTuple{N₂,I}} |
| 3 | + isdual::Tuple{NTuple{N₁,Bool},NTuple{N₂,Bool}} |
| 4 | +end |
| 5 | + |
| 6 | +sectortype(::Type{<:OuterTreeIterator{I}}) where {I} = I |
| 7 | +numout(fs::OuterTreeIterator) = numout(typeof(fs)) |
| 8 | +numout(::Type{<:OuterTreeIterator{I,N₁}}) where {I,N₁} = N₁ |
| 9 | +numin(fs::OuterTreeIterator) = numin(typeof(fs)) |
| 10 | +numin(::Type{<:OuterTreeIterator{I,N₁,N₂}}) where {I,N₁,N₂} = N₂ |
| 11 | +numind(fs::OuterTreeIterator) = numind(typeof(fs)) |
| 12 | +numind(::Type{T}) where {T<:OuterTreeIterator} = numin(T) + numout(T) |
| 13 | + |
| 14 | +# TODO: should we make this an actual iterator? |
| 15 | +function fusiontrees(iter::OuterTreeIterator{I,N₁,N₂}) where {I,N₁,N₂} |
| 16 | + F₁ = fusiontreetype(I, N₁) |
| 17 | + F₂ = fusiontreetype(I, N₂) |
| 18 | + |
| 19 | + trees = Vector{Tuple{F₁,F₂}}(undef, 0) |
| 20 | + for c in blocksectors(iter), f₁ in fusiontrees(iter.uncoupled[1], c, iter.isdual[1]), |
| 21 | + f₂ in fusiontrees(iter.uncoupled[2], c, iter.isdual[2]) |
| 22 | + |
| 23 | + push!(trees, (f₁, f₂)) |
| 24 | + end |
| 25 | + return trees |
| 26 | +end |
| 27 | + |
| 28 | +# TODO: better implementation |
| 29 | +Base.length(iter::OuterTreeIterator) = length(fusiontrees(iter)) |
| 30 | + |
| 31 | +function blocksectors(iter::OuterTreeIterator{I,N₁,N₂}) where {I,N₁,N₂} |
| 32 | + I == Trivial && return (Trivial(),) |
| 33 | + |
| 34 | + bs_codomain = Vector{I}() |
| 35 | + if N₁ == 0 |
| 36 | + push!(bs_codomain, one(I)) |
| 37 | + elseif N₁ == 1 |
| 38 | + push!(bs_codomain, only(iter.uncoupled[1])) |
| 39 | + else |
| 40 | + for c in ⊗(iter.uncoupled[1]...) |
| 41 | + if !(c in bs_codomain) |
| 42 | + push!(bs_codomain, c) |
| 43 | + end |
| 44 | + end |
| 45 | + end |
| 46 | + |
| 47 | + bs_domain = Vector{I}() |
| 48 | + if N₂ == 0 |
| 49 | + push!(bs_domain, one(I)) |
| 50 | + elseif N₂ == 1 |
| 51 | + push!(bs_domain, only(iter.uncoupled[2])) |
| 52 | + else |
| 53 | + for c in ⊗(iter.uncoupled[2]...) |
| 54 | + if !(c in bs_domain) |
| 55 | + push!(bs_domain, c) |
| 56 | + end |
| 57 | + end |
| 58 | + end |
| 59 | + |
| 60 | + return sort!(collect(intersect(bs_codomain, bs_domain))) |
| 61 | +end |
| 62 | + |
| 63 | +# Manipulations |
| 64 | +# ------------- |
| 65 | + |
| 66 | +function bendright(fs_src::OuterTreeIterator{I,N₁,N₂}) where {I,N₁,N₂} |
| 67 | + uncoupled_dst = (TupleTools.front(fs_src.uncoupled[1]), |
| 68 | + (fs_src.uncoupled[2]..., dual(fs_src.uncoupled[1][end]))) |
| 69 | + isdual_dst = (TupleTools.front(fs_src.isdual[1]), |
| 70 | + (fs_src.isdual[2]..., !(fs_src.isdual[1][end]))) |
| 71 | + fs_dst = OuterTreeIterator(uncoupled_dst, isdual_dst) |
| 72 | + |
| 73 | + trees_src = fusiontrees(fs_src) |
| 74 | + trees_dst = fusiontrees(fs_dst) |
| 75 | + indexmap = Dict(f => ind for (ind, f) in enumerate(trees_dst)) |
| 76 | + U = zeros(sectorscalartype(I), length(trees_dst), length(trees_src)) |
| 77 | + |
| 78 | + for (col, f) in enumerate(trees_src) |
| 79 | + for (f′, c) in bendright(f) |
| 80 | + row = indexmap[f′] |
| 81 | + U[row, col] = c |
| 82 | + end |
| 83 | + end |
| 84 | + |
| 85 | + return fs_dst, U |
| 86 | +end |
| 87 | + |
| 88 | +# TODO: verify if this can be computed through an adjoint |
| 89 | +function bendleft(fs_src::OuterTreeIterator{I,N₁,N₂}) where {I,N₁,N₂} |
| 90 | + uncoupled_dst = ((fs_src.uncoupled[1]..., dual(fs_src.uncoupled[2][end])), |
| 91 | + TupleTools.front(fs_src.uncoupled[2])) |
| 92 | + isdual_dst = ((fs_src.isdual[1]..., !(fs_src.isdual[2][end])), |
| 93 | + TupleTools.front(fs_src.isdual[2])) |
| 94 | + fs_dst = OuterTreeIterator(uncoupled_dst, isdual_dst) |
| 95 | + |
| 96 | + trees_src = fusiontrees(fs_src) |
| 97 | + trees_dst = fusiontrees(fs_dst) |
| 98 | + indexmap = Dict(f => ind for (ind, f) in enumerate(trees_dst)) |
| 99 | + U = zeros(sectorscalartype(I), length(trees_dst), length(trees_src)) |
| 100 | + |
| 101 | + for (col, f) in enumerate(trees_src) |
| 102 | + for (f′, c) in bendleft(f) |
| 103 | + row = indexmap[f′] |
| 104 | + U[row, col] = c |
| 105 | + end |
| 106 | + end |
| 107 | + |
| 108 | + return fs_dst, U |
| 109 | +end |
| 110 | + |
| 111 | +function foldright(fs_src::OuterTreeIterator{I,N₁,N₂}) where {I,N₁,N₂} |
| 112 | + uncoupled_dst = (Base.tail(fs_src.uncoupled[1]), |
| 113 | + (dual(first(fs_src.uncoupled[1])), fs_src.uncoupled[2]...)) |
| 114 | + isdual_dst = (Base.tail(fs_src.isdual[1]), |
| 115 | + (!first(fs_src.isdual[1]), fs_src.isdual[2]...)) |
| 116 | + fs_dst = OuterTreeIterator(uncoupled_dst, isdual_dst) |
| 117 | + |
| 118 | + trees_src = fusiontrees(fs_src) |
| 119 | + trees_dst = fusiontrees(fs_dst) |
| 120 | + indexmap = Dict(f => ind for (ind, f) in enumerate(trees_dst)) |
| 121 | + U = zeros(sectorscalartype(I), length(trees_dst), length(trees_src)) |
| 122 | + |
| 123 | + for (col, f) in enumerate(trees_src) |
| 124 | + for (f′, c) in foldright(f) |
| 125 | + row = indexmap[f′] |
| 126 | + U[row, col] = c |
| 127 | + end |
| 128 | + end |
| 129 | + |
| 130 | + return fs_dst, U |
| 131 | +end |
| 132 | + |
| 133 | +# TODO: verify if this can be computed through an adjoint |
| 134 | +function foldleft(fs_src::OuterTreeIterator{I,N₁,N₂}) where {I,N₁,N₂} |
| 135 | + uncoupled_dst = ((dual(first(fs_src.uncoupled[2])), fs_src.uncoupled[1]...), |
| 136 | + Base.tail(fs_src.uncoupled[2])) |
| 137 | + isdual_dst = ((!first(fs_src.isdual[2]), fs_src.isdual[1]...), |
| 138 | + Base.tail(fs_src.isdual[2])) |
| 139 | + fs_dst = OuterTreeIterator(uncoupled_dst, isdual_dst) |
| 140 | + |
| 141 | + trees_src = fusiontrees(fs_src) |
| 142 | + trees_dst = fusiontrees(fs_dst) |
| 143 | + indexmap = Dict(f => ind for (ind, f) in enumerate(trees_dst)) |
| 144 | + U = zeros(sectorscalartype(I), length(trees_dst), length(trees_src)) |
| 145 | + |
| 146 | + for (col, f) in enumerate(trees_src) |
| 147 | + for (f′, c) in foldleft(f) |
| 148 | + row = indexmap[f′] |
| 149 | + U[row, col] = c |
| 150 | + end |
| 151 | + end |
| 152 | + |
| 153 | + return fs_dst, U |
| 154 | +end |
| 155 | + |
| 156 | +function cycleclockwise(fs_src::OuterTreeIterator{I,N₁,N₂}) where {I,N₁,N₂} |
| 157 | + if N₁ > 0 |
| 158 | + fs_tmp, U₁ = foldright(fs_src) |
| 159 | + fs_dst, U₂ = bendleft(fs_tmp) |
| 160 | + else |
| 161 | + fs_tmp, U₁ = bendleft(fs_src) |
| 162 | + fs_dst, U₂ = foldright(fs_tmp) |
| 163 | + end |
| 164 | + return fs_dst, U₂ * U₁ |
| 165 | +end |
| 166 | + |
| 167 | +function cycleanticlockwise(fs_src::OuterTreeIterator{I,N₁,N₂}) where {I,N₁,N₂} |
| 168 | + if N₂ > 0 |
| 169 | + fs_tmp, U₁ = foldleft(fs_src) |
| 170 | + fs_dst, U₂ = bendright(fs_tmp) |
| 171 | + else |
| 172 | + fs_tmp, U₁ = bendright(fs_src) |
| 173 | + fs_dst, U₂ = foldleft(fs_tmp) |
| 174 | + end |
| 175 | + return fs_dst, U₂ * U₁ |
| 176 | +end |
| 177 | + |
| 178 | +@inline function repartition(fs_src::OuterTreeIterator{I,N₁,N₂}, N::Int) where {I,N₁,N₂} |
| 179 | + @assert 0 <= N <= N₁ + N₂ |
| 180 | + return _recursive_repartition(fs_src, Val(N)) |
| 181 | +end |
| 182 | + |
| 183 | +function _repartition_type(I, N, N₁, N₂) |
| 184 | + return Tuple{OuterTreeIterator{I,N,N₁ + N₂ - N},Matrix{sectorscalartype(I)}} |
| 185 | +end |
| 186 | +function _recursive_repartition(fs_src::OuterTreeIterator{I,N₁,N₂}, |
| 187 | + ::Val{N})::_repartition_type(I, N, N₁, N₂) where {I,N₁,N₂,N} |
| 188 | + if N == N₁ |
| 189 | + fs_dst = fs_src |
| 190 | + U = zeros(sectorscalartype(I), length(fs_dst), length(fs_src)) |
| 191 | + copyto!(U, LinearAlgebra.I) |
| 192 | + return fs_dst, U |
| 193 | + end |
| 194 | + |
| 195 | + N == N₁ - 1 && return bendright(fs_src) |
| 196 | + N == N₁ + 1 && return bendleft(fs_src) |
| 197 | + |
| 198 | + fs_tmp, U₁ = N < N₁ ? bendright(fs_src) : bendleft(fs_src) |
| 199 | + fs_dst, U₂ = _recursive_repartition(fs_tmp, Val(N)) |
| 200 | + return fs_dst, U₂ * U₁ |
| 201 | +end |
| 202 | + |
| 203 | +function Base.transpose(fs_src::OuterTreeIterator{I}, p::Index2Tuple{N₁,N₂}) where {I,N₁,N₂} |
| 204 | + N = N₁ + N₂ |
| 205 | + @assert numind(fs_src) == N |
| 206 | + p′ = linearizepermutation(p..., numout(fs_src), numin(fs_src)) |
| 207 | + @assert iscyclicpermutation(p′) |
| 208 | + return _fstranspose((fs_src, p)) |
| 209 | +end |
| 210 | + |
| 211 | +const _FSTransposeKey{I,N₁,N₂} = Tuple{<:OuterTreeIterator{I},Index2Tuple{N₁,N₂}} |
| 212 | + |
| 213 | +@cached function _fstranspose(key::_FSTransposeKey{I,N₁,N₂})::Tuple{OuterTreeIterator{I,N₁, |
| 214 | + N₂}, |
| 215 | + Matrix{sectorscalartype(I)}} where {I, |
| 216 | + N₁, |
| 217 | + N₂} |
| 218 | + fs_src, (p1, p2) = key |
| 219 | + |
| 220 | + N = N₁ + N₂ |
| 221 | + p = linearizepermutation(p1, p2, numout(fs_src), numin(fs_src)) |
| 222 | + |
| 223 | + fs_dst, U = repartition(fs_src, N₁) |
| 224 | + length(p) == 0 && return fs_dst, U |
| 225 | + i1 = findfirst(==(1), p)::Int |
| 226 | + i1 == 1 && return fs_dst, U |
| 227 | + |
| 228 | + Nhalf = N >> 1 |
| 229 | + while 1 < i1 ≤ Nhalf |
| 230 | + fs_dst, U_tmp = cycleanticlockwise(fs_dst) |
| 231 | + U = U_tmp * U |
| 232 | + i1 -= 1 |
| 233 | + end |
| 234 | + while Nhalf < i1 |
| 235 | + fs_dst, U_tmp = cycleclockwise(fs_dst) |
| 236 | + U = U_tmp * U |
| 237 | + i1 = mod1(i1 + 1, N) |
| 238 | + end |
| 239 | + |
| 240 | + return fs_dst, U |
| 241 | +end |
| 242 | + |
| 243 | +function CacheStyle(::typeof(_fstranspose), k::_FSTransposeKey{I}) where {I} |
| 244 | + if FusionStyle(I) == UniqueFusion() |
| 245 | + return NoCache() |
| 246 | + else |
| 247 | + return GlobalLRUCache() |
| 248 | + end |
| 249 | +end |
| 250 | + |
| 251 | +function artin_braid(fs_src::OuterTreeIterator{I,N,0}, i; inv::Bool=false) where {I,N} |
| 252 | + 1 <= i < N || |
| 253 | + throw(ArgumentError("Cannot swap outputs i=$i and i+1 out of only $N outputs")) |
| 254 | + |
| 255 | + uncoupled = fs_src.uncoupled[1] |
| 256 | + uncoupled′ = TupleTools.setindex(uncoupled, uncoupled[i + 1], i) |
| 257 | + uncoupled′ = TupleTools.setindex(uncoupled′, uncoupled[i], i + 1) |
| 258 | + |
| 259 | + isdual = fs_src.isdual[1] |
| 260 | + isdual′ = TupleTools.setindex(isdual, isdual[i], i + 1) |
| 261 | + isdual′ = TupleTools.setindex(isdual′, isdual[i + 1], i) |
| 262 | + |
| 263 | + fs_dst = OuterTreeIterator((uncoupled′, ()), (isdual′, ())) |
| 264 | + |
| 265 | + trees_src = fusiontrees(fs_src) |
| 266 | + trees_dst = fusiontrees(fs_dst) |
| 267 | + indexmap = Dict(f => ind for (ind, f) in enumerate(trees_dst)) |
| 268 | + U = zeros(sectorscalartype(I), length(trees_dst), length(trees_src)) |
| 269 | + |
| 270 | + for (col, (f₁, f₂)) in enumerate(trees_src) |
| 271 | + for (f₁′, c) in artin_braid(f₁, i; inv) |
| 272 | + row = indexmap[(f₁′, f₂)] |
| 273 | + U[row, col] = c |
| 274 | + end |
| 275 | + end |
| 276 | + |
| 277 | + return fs_dst, U |
| 278 | +end |
| 279 | + |
| 280 | +function braid(fs_src::OuterTreeIterator{I,N,0}, levels::NTuple{N,Int}, |
| 281 | + p::NTuple{N,Int}) where {I,N} |
| 282 | + TupleTools.isperm(p) || throw(ArgumentError("not a valid permutation: $p")) |
| 283 | + |
| 284 | + if FusionStyle(I) isa UniqueFusion && BraidingStyle(I) isa SymmetricBraiding |
| 285 | + uncoupled′ = TupleTools._permute(fs_src.uncoupled[1], p) |
| 286 | + isdual′ = TupleTools._permute(fs_src.isdual[1], p) |
| 287 | + fs_dst = OuterTreeIterator(uncoupled′, isdual′) |
| 288 | + |
| 289 | + trees_src = fusiontrees(fs_src) |
| 290 | + trees_dst = fusiontrees(fs_dst) |
| 291 | + indexmap = Dict(f => ind for (ind, f) in enumerate(trees_dst)) |
| 292 | + U = zeros(sectorscalartype(I), length(trees_dst), length(trees_src)) |
| 293 | + |
| 294 | + for (col, (f₁, f₂)) in enumerate(trees_src) |
| 295 | + for (f₁′, c) in braid(f₁, levels, p) |
| 296 | + row = indexmap[(f₁′, f₂)] |
| 297 | + U[row, col] = c |
| 298 | + end |
| 299 | + end |
| 300 | + |
| 301 | + return fs_dst, U |
| 302 | + end |
| 303 | + |
| 304 | + fs_dst, U = repartition(fs_src, N) # TODO: can we avoid this? |
| 305 | + for s in permutation2swaps(p) |
| 306 | + inv = levels[s] > levels[s + 1] |
| 307 | + fs_dst, U_tmp = artin_braid(fs_dst, s; inv) |
| 308 | + U = U_tmp * U |
| 309 | + end |
| 310 | + return fs_dst, U |
| 311 | +end |
| 312 | + |
| 313 | +function braid(fs_src::OuterTreeIterator{I}, levels::Index2Tuple, |
| 314 | + p::Index2Tuple{N₁,N₂}) where {I,N₁,N₂} |
| 315 | + @assert numind(fs_src) == N₁ + N₂ |
| 316 | + @assert numout(fs_src) == length(levels[1]) && numin(fs_src) == length(levels[2]) |
| 317 | + @assert TupleTools.isperm((p[1]..., p[2]...)) |
| 318 | + return _fsbraid((fs_src, levels, p)) |
| 319 | +end |
| 320 | + |
| 321 | +const _FSBraidKey{I,N₁,N₂} = Tuple{<:OuterTreeIterator{I},Index2Tuple,Index2Tuple{N₁,N₂}} |
| 322 | + |
| 323 | +@cached function _fsbraid(key::_FSBraidKey{I,N₁,N₂})::Tuple{OuterTreeIterator{I,N₁,N₂}, |
| 324 | + Matrix{sectorscalartype(I)}} where {I, |
| 325 | + N₁, |
| 326 | + N₂} |
| 327 | + fs_src, (l1, l2), (p1, p2) = key |
| 328 | + |
| 329 | + p = linearizepermutation(p1, p2, numout(fs_src), numin(fs_src)) |
| 330 | + levels = (l1..., reverse(l2)...) |
| 331 | + |
| 332 | + fs_dst, U = repartition(fs_src, numind(fs_src)) |
| 333 | + fs_dst, U_tmp = braid(fs_dst, levels, p) |
| 334 | + U = U_tmp * U |
| 335 | + fs_dst, U_tmp = repartition(fs_dst, N₁) |
| 336 | + U = U_tmp * U |
| 337 | + return fs_dst, U |
| 338 | +end |
| 339 | + |
| 340 | +function CacheStyle(::typeof(_fsbraid), k::_FSBraidKey{I}) where {I} |
| 341 | + if FusionStyle(I) isa UniqueFusion |
| 342 | + return NoCache() |
| 343 | + else |
| 344 | + return GlobalLRUCache() |
| 345 | + end |
| 346 | +end |
| 347 | + |
| 348 | +function permute(fs_src::OuterTreeIterator{I}, p::Index2Tuple) where {I} |
| 349 | + @assert BraidingStyle(I) isa SymmetricBraiding |
| 350 | + levels1 = ntuple(identity, numout(fs_src)) |
| 351 | + levels2 = numout(fs_src) .+ ntuple(identity, numin(fs_src)) |
| 352 | + return braid(fs_src, (levels1, levels2), p) |
| 353 | +end |
0 commit comments