@@ -59,7 +59,9 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc)
5959 t₀ = Base. time()
6060 permute(Vsrc, p) == Vdst || throw(SpaceMismatch(" Incompatible spaces for permuting." ))
6161 structure_dst = fusionblockstructure(Vdst)
62+ fusionstructure_dst = structure_dst. fusiontreestructure
6263 structure_src = fusionblockstructure(Vsrc)
64+ fusionstructure_src = structure_src. fusiontreestructure
6365 I = sectortype(Vsrc)
6466
6567 uncoupleds_src = map(structure_src. fusiontreelist) do (f₁, f₂)
@@ -94,13 +96,10 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc)
9496 end
9597 end
9698
97- structs_src = structure_src. fusiontreestructure[ids_src]
98- sz_src = structs_src[1 ][1 ]
99- newstructs_src = map(x -> (x[2 ], x[3 ]), structs_src)
100-
101- structs_dst = structure_dst. fusiontreestructure[ids_dst]
102- sz_dst = structs_dst[1 ][1 ]
103- newstructs_dst = map(x -> (x[2 ], x[3 ]), structs_dst)
99+ # size is shared between blocks, so repack:
100+ # from [(sz, strides, offset), ...] to (sz, [(strides, offset), ...])
101+ sz_src, newstructs_src = repack_transformer_structure(fusionstructure_src, ids_src)
102+ sz_dst, newstructs_dst = repack_transformer_structure(fusionstructure_dst, ids_dst)
104103
105104 @debug(" Created recoupling block for uncoupled: $uncoupled " ,
106105 sz = size(matrix), sparsity = count(! iszero, matrix) / length(matrix))
@@ -124,6 +123,12 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc)
124123 return transformer
125124end
126125
126+ function repack_transformer_structure(structures, ids)
127+ sz = structures[first(ids)][1 ]
128+ strides_offsets = map(i -> (structures[i][2 ], structures[i][3 ]), ids)
129+ return sz, strides_offsets
130+ end
131+
127132function buffersize(transformer:: GenericTreeTransformer )
128133 return maximum(transformer. data; init= 0 ) do (basistransform, structures_dst, _)
129134 return prod(structures_dst[1 ]) * size(basistransform, 1 )
0 commit comments