@@ -132,27 +132,45 @@ end
132132
133133@inline function repartition (src:: FusionTreeBlock , N:: Int )
134134 @assert 0 <= N <= numind (src)
135- return _recursive_repartition (src, Val (N))
135+ return repartition (src, Val (N))
136136end
137137
138- function _repartition_type (I, N, N₁, N₂)
139- return Tuple{FusionTreeBlock{I,N,N₁ + N₂ - N},Matrix{sectorscalartype (I)}}
140- end
141- function _recursive_repartition (src:: FusionTreeBlock{I,N₁,N₂} ,
142- :: Val{N} ):: _repartition_type (I, N, N₁, N₂) where {I,N₁,N₂,N}
143- if N == N₁
144- dst = src
145- U = zeros (sectorscalartype (I), length (dst), length (src))
146- copyto! (U, LinearAlgebra. I)
147- return dst, U
148- end
138+ #=
139+ Using a generated function here to ensure type stability by unrolling the loops:
140+ ```julia
141+ dst, U = bendleft/right(src)
149142
150- N == N₁ - 1 && return bendright (src)
151- N == N₁ + 1 && return bendleft (src)
143+ # repeat the following 2 lines N - 1 times
144+ dst, Utmp = bendleft/right(dst)
145+ U = Utmp * U
152146
153- tmp, U₁ = N < N₁ ? bendright (src) : bendleft (src)
154- dst, U₂ = _recursive_repartition (tmp, Val (N))
155- return dst, U₂ * U₁
147+ return dst, U
148+ ```
149+ =#
150+ @generated function repartition (src:: FusionTreeBlock , :: Val{N} ) where {N}
151+ return _repartition_body (numout (src) - N)
152+ end
153+ function _repartition_body (N)
154+ if N == 0
155+ ex = quote
156+ T = sectorscalartype (sectortype (src))
157+ U = copyto! (zeros (T, length (src), length (src)), LinearAlgebra. I)
158+ return src, U
159+ end
160+ else
161+ f = N < 0 ? bendleft : bendright
162+ ex_rep = Expr (:block )
163+ for _ in 1 : (abs (N) - 1 )
164+ push! (ex_rep. args, :((dst, Utmp) = $ f (dst)))
165+ push! (ex_rep. args, :(U = Utmp * U))
166+ end
167+ ex = quote
168+ dst, U = $ f (src)
169+ $ ex_rep
170+ return dst, U
171+ end
172+ end
173+ return ex
156174end
157175
158176function Base. transpose (src:: FusionTreeBlock , p:: Index2Tuple{N₁,N₂} ) where {N₁,N₂}
0 commit comments