@@ -66,32 +66,85 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc)
6666 I = sectortype (Vsrc)
6767 T = sectorscalartype (I)
6868 N = numind (Vdst)
69- data = Vector {_GenericTransformerData{T,N}} ()
7069
7170 isdual_src = (map (isdual, codomain (Vsrc). spaces), map (isdual, domain (Vsrc). spaces))
72- for cod_uncoupled_src in sectors (codomain (Vsrc)),
73- dom_uncoupled_src in sectors (domain (Vsrc))
7471
75- fs_src = FusionTreeBlock {I} ((cod_uncoupled_src, dom_uncoupled_src), isdual_src)
76- trees_src = fusiontrees (fs_src)
77- isempty (trees_src) && continue
72+ nthreads = get_num_transformer_threads ()
73+ if nthreads > 1
74+ fusiontreeblocks = Vector {FusionTreeBlock{I,N₁,N₂,fusiontreetype(I, N₁, N₂)}} ()
75+ for cod_uncoupled_src in sectors (codomain (Vsrc)),
76+ dom_uncoupled_src in sectors (domain (Vsrc))
77+
78+ fs_src = FusionTreeBlock {I} ((cod_uncoupled_src, dom_uncoupled_src), isdual_src)
79+ trees_src = fusiontrees (fs_src)
80+ if ! isempty (trees_src)
81+ push! (fusiontreeblocks, fs_src)
82+ end
83+ end
84+
85+ data = Vector {_GenericTransformerData{T,N}} (undef, length (fusiontreeblocks))
86+ counter = Threads. Atomic {Int} (1 )
87+ Threads. @sync for _ in 1 : min (nthreads, length (fusiontreeblocks))
88+ Threads. @spawn begin
89+ while true
90+ local_counter = Threads. atomic_add! (counter, 1 )
91+ local_counter > nblocks && break
92+ fs_src = fusiontreeblocks[local_counter]
93+ fs_dst, U = transform (fs_src)
94+ matrix = copy (transpose (U)) # TODO : should we avoid this
95+
96+ inds_src = map (Base. Fix1 (getindex, structure_src. fusiontreeindices),
97+ trees_src)
98+ trees_dst = fusiontrees (fs_dst)
99+ inds_dst = map (Base. Fix1 (getindex, structure_dst. fusiontreeindices),
100+ trees_dst)
101+
102+ # size is shared between blocks, so repack:
103+ # from [(sz, strides, offset), ...] to (sz, [(strides, offset), ...])
104+ sz_src, newstructs_src = repack_transformer_structure (fusionstructure_src,
105+ inds_src)
106+ sz_dst, newstructs_dst = repack_transformer_structure (fusionstructure_dst,
107+ inds_dst)
108+
109+ @debug (" Created recoupling block for uncoupled: $uncoupled " ,
110+ sz = size (matrix),
111+ sparsity = count (! iszero, matrix) / length (matrix))
112+
113+ data[local_counter] = (matrix, (sz_dst, newstructs_dst),
114+ (sz_src, newstructs_src))
115+ end
116+ end
117+ end
118+ else
119+ data = Vector {_GenericTransformerData{T,N}} ()
78120
79- fs_dst, U = transform (fs_src)
80- matrix = copy (transpose (U)) # TODO : should we avoid this
121+ isdual_src = (map (isdual, codomain (Vsrc). spaces), map (isdual, domain (Vsrc). spaces))
122+ for cod_uncoupled_src in sectors (codomain (Vsrc)),
123+ dom_uncoupled_src in sectors (domain (Vsrc))
81124
82- inds_src = map (Base . Fix1 (getindex, structure_src . fusiontreeindices ), trees_src )
83- trees_dst = fusiontrees (fs_dst )
84- inds_dst = map (Base . Fix1 (getindex, structure_dst . fusiontreeindices), trees_dst)
125+ fs_src = FusionTreeBlock {I} ((cod_uncoupled_src, dom_uncoupled_src ), isdual_src )
126+ trees_src = fusiontrees (fs_src )
127+ isempty (trees_src) && continue
85128
86- # size is shared between blocks, so repack:
87- # from [(sz, strides, offset), ...] to (sz, [(strides, offset), ...])
88- sz_src, newstructs_src = repack_transformer_structure (fusionstructure_src, inds_src)
89- sz_dst, newstructs_dst = repack_transformer_structure (fusionstructure_dst, inds_dst)
129+ fs_dst, U = transform (fs_src)
130+ matrix = copy (transpose (U)) # TODO : should we avoid this
90131
91- @debug (" Created recoupling block for uncoupled: $uncoupled " ,
92- sz = size (matrix), sparsity = count (! iszero, matrix) / length (matrix))
132+ inds_src = map (Base. Fix1 (getindex, structure_src. fusiontreeindices), trees_src)
133+ trees_dst = fusiontrees (fs_dst)
134+ inds_dst = map (Base. Fix1 (getindex, structure_dst. fusiontreeindices), trees_dst)
93135
94- push! (data, (matrix, (sz_dst, newstructs_dst), (sz_src, newstructs_src)))
136+ # size is shared between blocks, so repack:
137+ # from [(sz, strides, offset), ...] to (sz, [(strides, offset), ...])
138+ sz_src, newstructs_src = repack_transformer_structure (fusionstructure_src,
139+ inds_src)
140+ sz_dst, newstructs_dst = repack_transformer_structure (fusionstructure_dst,
141+ inds_dst)
142+
143+ @debug (" Created recoupling block for uncoupled: $uncoupled " ,
144+ sz = size (matrix), sparsity = count (! iszero, matrix) / length (matrix))
145+
146+ push! (data, (matrix, (sz_dst, newstructs_dst), (sz_src, newstructs_src)))
147+ end
95148 end
96149
97150 transformer = GenericTreeTransformer {T,N} (data)
0 commit comments