Skip to content

Commit 4e40fc3

Browse files
committed
refactor repartition to unroll loop
1 parent 59c8053 commit 4e40fc3

File tree

1 file changed

+35
-17
lines changed

1 file changed

+35
-17
lines changed

src/fusiontrees/fusiontreeblocks.jl

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -132,27 +132,45 @@ end
132132

133133
@inline function repartition(src::FusionTreeBlock, N::Int)
134134
@assert 0 <= N <= numind(src)
135-
return _recursive_repartition(src, Val(N))
135+
return repartition(src, Val(N))
136136
end
137137

138-
function _repartition_type(I, N, N₁, N₂)
139-
return Tuple{FusionTreeBlock{I,N,N₁ + N₂ - N},Matrix{sectorscalartype(I)}}
140-
end
141-
function _recursive_repartition(src::FusionTreeBlock{I,N₁,N₂},
142-
::Val{N})::_repartition_type(I, N, N₁, N₂) where {I,N₁,N₂,N}
143-
if N == N₁
144-
dst = src
145-
U = zeros(sectorscalartype(I), length(dst), length(src))
146-
copyto!(U, LinearAlgebra.I)
147-
return dst, U
148-
end
138+
#=
139+
Using a generated function here to ensure type stability by unrolling the loops:
140+
```julia
141+
dst, U = bendleft/right(src)
149142
150-
N == N₁ - 1 && return bendright(src)
151-
N == N₁ + 1 && return bendleft(src)
143+
# repeat the following 2 lines N - 1 times
144+
dst, Utmp = bendleft/right(dst)
145+
U = Utmp * U
152146
153-
tmp, U₁ = N < N₁ ? bendright(src) : bendleft(src)
154-
dst, U₂ = _recursive_repartition(tmp, Val(N))
155-
return dst, U₂ * U₁
147+
return dst, U
148+
```
149+
=#
150+
@generated function repartition(src::FusionTreeBlock, ::Val{N}) where {N}
151+
return _repartition_body(numin(src) - N)
152+
end
153+
function _repartition_body(N)
154+
if N == 0
155+
ex = quote
156+
T = sectorscalartype(sectortype(src))
157+
U = copyto!(zeros(T, length(src), length(src)), LinearAlgebra.I)
158+
return src, U
159+
end
160+
else
161+
f = N < 0 ? bendleft : bendright
162+
ex_rep = Expr(:block)
163+
for _ in 1:(abs(N) - 1)
164+
push!(ex_rep.args, :((dst, Utmp) = $f(dst)))
165+
push!(ex_rep.args, :(U = Utmp * U))
166+
end
167+
ex = quote
168+
dst, U = $f(src)
169+
$ex_rep
170+
return dst, U
171+
end
172+
end
173+
return ex
156174
end
157175

158176
function Base.transpose(src::FusionTreeBlock, p::Index2Tuple{N₁,N₂}) where {N₁,N₂}

0 commit comments

Comments
 (0)