Skip to content

Commit 7f541b5

Browse files
committed
refactor repacking of transformer structure
1 parent d68876b commit 7f541b5

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

src/tensors/treetransformers.jl

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,9 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc)
5959
t₀ = Base.time()
6060
permute(Vsrc, p) == Vdst || throw(SpaceMismatch("Incompatible spaces for permuting."))
6161
structure_dst = fusionblockstructure(Vdst)
62+
fusionstructure_dst = structure_dst.fusiontreestructure
6263
structure_src = fusionblockstructure(Vsrc)
64+
fusionstructure_src = structure_src.fusiontreestructure
6365
I = sectortype(Vsrc)
6466

6567
uncoupleds_src = map(structure_src.fusiontreelist) do (f₁, f₂)
@@ -94,13 +96,10 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc)
9496
end
9597
end
9698

97-
structs_src = structure_src.fusiontreestructure[ids_src]
98-
sz_src = structs_src[1][1]
99-
newstructs_src = map(x -> (x[2], x[3]), structs_src)
100-
101-
structs_dst = structure_dst.fusiontreestructure[ids_dst]
102-
sz_dst = structs_dst[1][1]
103-
newstructs_dst = map(x -> (x[2], x[3]), structs_dst)
99+
# size is shared between blocks, so repack:
100+
# from [(sz, strides, offset), ...] to (sz, [(strides, offset), ...])
101+
sz_src, newstructs_src = repack_transformer_structure(fusionstructure_src, ids_src)
102+
sz_dst, newstructs_dst = repack_transformer_structure(fusionstructure_dst, ids_dst)
104103

105104
@debug("Created recoupling block for uncoupled: $uncoupled",
106105
sz = size(matrix), sparsity = count(!iszero, matrix) / length(matrix))
@@ -124,6 +123,12 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc)
124123
return transformer
125124
end
126125

126+
function repack_transformer_structure(structures, ids)
127+
sz = structures[first(ids)][1]
128+
strides_offsets = map(i -> (structures[i][2], structures[i][3]), ids)
129+
return sz, strides_offsets
130+
end
131+
127132
function buffersize(transformer::GenericTreeTransformer)
128133
return maximum(transformer.data; init=0) do (basistransform, structures_dst, _)
129134
return prod(structures_dst[1]) * size(basistransform, 1)

0 commit comments

Comments
 (0)