@@ -59,7 +59,9 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc)
5959 t₀ = Base. time ()
6060 permute (Vsrc, p) == Vdst || throw (SpaceMismatch (" Incompatible spaces for permuting." ))
6161 structure_dst = fusionblockstructure (Vdst)
62+ fusionstructure_dst = structure_dst. fusiontreestructure
6263 structure_src = fusionblockstructure (Vsrc)
64+ fusionstructure_src = structure_src. fusiontreestructure
6365 I = sectortype (Vsrc)
6466
6567 uncoupleds_src = map (structure_src. fusiontreelist) do (f₁, f₂)
@@ -78,29 +80,26 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc)
7880
7981 # TODO : this can be multithreaded
8082 for (i, uncoupled) in enumerate (uncoupleds_src_unique)
81- ids_src = findall (== (uncoupled), uncoupleds_src)
82- fusiontrees_outer_src = structure_src. fusiontreelist[ids_src ]
83+ inds_src = findall (== (uncoupled), uncoupleds_src)
84+ fusiontrees_outer_src = structure_src. fusiontreelist[inds_src ]
8385
8486 uncoupled_dst = TupleTools. getindices (uncoupled, (p[1 ]. .. , p[2 ]. .. ))
85- ids_dst = findall (== (uncoupled_dst), uncoupleds_dst)
87+ inds_dst = findall (== (uncoupled_dst), uncoupleds_dst)
8688
87- fusiontrees_outer_dst = structure_dst. fusiontreelist[ids_dst ]
89+ fusiontrees_outer_dst = structure_dst. fusiontreelist[inds_dst ]
8890
89- matrix = zeros (sectorscalartype (I), length (ids_dst ), length (ids_src ))
91+ matrix = zeros (sectorscalartype (I), length (inds_dst ), length (inds_src ))
9092 for (row, (f₁, f₂)) in enumerate (fusiontrees_outer_src)
9193 for ((f₃, f₄), coeff) in transform (f₁, f₂)
9294 col = findfirst (== ((f₃, f₄)), fusiontrees_outer_dst):: Int
9395 matrix[row, col] = coeff
9496 end
9597 end
9698
97- structs_src = structure_src. fusiontreestructure[ids_src]
98- sz_src = structs_src[1 ][1 ]
99- newstructs_src = map (x -> (x[2 ], x[3 ]), structs_src)
100-
101- structs_dst = structure_dst. fusiontreestructure[ids_dst]
102- sz_dst = structs_dst[1 ][1 ]
103- newstructs_dst = map (x -> (x[2 ], x[3 ]), structs_dst)
99+ # size is shared between blocks, so repack:
100+ # from [(sz, strides, offset), ...] to (sz, [(strides, offset), ...])
101+ sz_src, newstructs_src = repack_transformer_structure (fusionstructure_src, inds_src)
102+ sz_dst, newstructs_dst = repack_transformer_structure (fusionstructure_dst, inds_dst)
104103
105104 @debug (" Created recoupling block for uncoupled: $uncoupled " ,
106105 sz = size (matrix), sparsity = count (! iszero, matrix) / length (matrix))
@@ -124,6 +123,12 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc)
124123 return transformer
125124end
126125
"""
    repack_transformer_structure(structures, ids)

Repack the entries of `structures` selected by `ids` from the layout
`[(sz, strides, offset), ...]` into `(sz, [(strides, offset), ...])`.

The returned `sz` is read from the entry at `first(ids)` only; it is not checked
against the other selected entries, so callers are expected to pass entries that
share the same size. The second return value is built with `map` over `ids`, so it
follows the order of `ids`.
"""
function repack_transformer_structure(structures, ids)
    # The size is taken from the first selected entry and hoisted out of the list.
    shared_size = structures[first(ids)][1]
    # Keep only the per-entry parts (2nd and 3rd fields) for every selected index.
    packed = map(ids) do idx
        entry = structures[idx]
        return (entry[2], entry[3])
    end
    return shared_size, packed
end
131+
127132function buffersize (transformer:: GenericTreeTransformer )
128133 return maximum (transformer. data; init= 0 ) do (basistransform, structures_dst, _)
129134 return prod (structures_dst[1 ]) * size (basistransform, 1 )
0 commit comments