@@ -52,7 +52,14 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc)
5252 return TupleTools. vcat(f₁. uncoupled, f₂. uncoupled)
5353 end
5454
55- outer_data = map(uncoupleds_src_unique) do uncoupled
55+ T = sectorscalartype(I)
56+ N = numind(Vdst)
57+ L = length(uncoupleds_src_unique)
58+ TStrided = StridedStructure{N}
59+ data = Vector{Tuple{Matrix{T},Vector{TStrided},Vector{TStrided}}}(undef, L)
60+
61+ # TODO : this can be multithreaded
62+ for (i, uncoupled) in enumerate(uncoupleds_src_unique)
5663 ids_src = findall(== (uncoupled), uncoupleds_src)
5764 fusiontrees_outer_src = structure_src. fusiontreelist[ids_src]
5865
@@ -71,21 +78,21 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc)
7178 @debug(" Created recoupling block for uncoupled: $uncoupled " ,
7279 sz = size(matrix), sparsity = count(! iszero, matrix) / length(matrix))
7380
74- return (matrix,
75- structure_dst. fusiontreestructure[ids_dst],
76- structure_src. fusiontreestructure[ids_src])
81+ data[i] = (matrix,
82+ structure_dst. fusiontreestructure[ids_dst],
83+ structure_src. fusiontreestructure[ids_src])
7784 end
7885
7986 # sort by (approximate) weight to make the buffers happy
8087 # and use round-robin strategy for multi-threading
81- sort!(outer_data ; by= _transformer_weight, rev= true )
88+ sort!(data ; by= _transformer_weight, rev= true )
8289
8390 @debug(" TreeTransformer for $Vsrc to $Vdst via $p " ,
84- nblocks = length(outer_data ),
85- sz_median = size(outer_data [end ÷ 2 ][1 ], 1 ),
86- sz_max = size(outer_data [1 ][1 ], 1 ))
91+ nblocks = length(data ),
92+ sz_median = size(data [end ÷ 2 ][1 ], 1 ),
93+ sz_max = size(data [1 ][1 ], 1 ))
8794
88- return GenericTreeTransformer(outer_data )
95+ return GenericTreeTransformer{T,N}(data )
8996end
9097
9198function _transformer_weight((matrix, structures_dst, structures_src))
0 commit comments