Skip to content

Commit 0d5bb31

Browse files
authored
Refactor repacking of transformer structure (#254)
* refactor repacking of transformer structure * rename id -> ind
1 parent d68876b commit 0d5bb31

File tree

1 file changed

+17
-12
lines changed

1 file changed

+17
-12
lines changed

src/tensors/treetransformers.jl

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,9 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc)
5959
t₀ = Base.time()
6060
permute(Vsrc, p) == Vdst || throw(SpaceMismatch("Incompatible spaces for permuting."))
6161
structure_dst = fusionblockstructure(Vdst)
62+
fusionstructure_dst = structure_dst.fusiontreestructure
6263
structure_src = fusionblockstructure(Vsrc)
64+
fusionstructure_src = structure_src.fusiontreestructure
6365
I = sectortype(Vsrc)
6466

6567
uncoupleds_src = map(structure_src.fusiontreelist) do (f₁, f₂)
@@ -78,29 +80,26 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc)
7880

7981
# TODO: this can be multithreaded
8082
for (i, uncoupled) in enumerate(uncoupleds_src_unique)
81-
ids_src = findall(==(uncoupled), uncoupleds_src)
82-
fusiontrees_outer_src = structure_src.fusiontreelist[ids_src]
83+
inds_src = findall(==(uncoupled), uncoupleds_src)
84+
fusiontrees_outer_src = structure_src.fusiontreelist[inds_src]
8385

8486
uncoupled_dst = TupleTools.getindices(uncoupled, (p[1]..., p[2]...))
85-
ids_dst = findall(==(uncoupled_dst), uncoupleds_dst)
87+
inds_dst = findall(==(uncoupled_dst), uncoupleds_dst)
8688

87-
fusiontrees_outer_dst = structure_dst.fusiontreelist[ids_dst]
89+
fusiontrees_outer_dst = structure_dst.fusiontreelist[inds_dst]
8890

89-
matrix = zeros(sectorscalartype(I), length(ids_dst), length(ids_src))
91+
matrix = zeros(sectorscalartype(I), length(inds_dst), length(inds_src))
9092
for (row, (f₁, f₂)) in enumerate(fusiontrees_outer_src)
9193
for ((f₃, f₄), coeff) in transform(f₁, f₂)
9294
col = findfirst(==((f₃, f₄)), fusiontrees_outer_dst)::Int
9395
matrix[row, col] = coeff
9496
end
9597
end
9698

97-
structs_src = structure_src.fusiontreestructure[ids_src]
98-
sz_src = structs_src[1][1]
99-
newstructs_src = map(x -> (x[2], x[3]), structs_src)
100-
101-
structs_dst = structure_dst.fusiontreestructure[ids_dst]
102-
sz_dst = structs_dst[1][1]
103-
newstructs_dst = map(x -> (x[2], x[3]), structs_dst)
99+
# size is shared between blocks, so repack:
100+
# from [(sz, strides, offset), ...] to (sz, [(strides, offset), ...])
101+
sz_src, newstructs_src = repack_transformer_structure(fusionstructure_src, inds_src)
102+
sz_dst, newstructs_dst = repack_transformer_structure(fusionstructure_dst, inds_dst)
104103

105104
@debug("Created recoupling block for uncoupled: $uncoupled",
106105
sz = size(matrix), sparsity = count(!iszero, matrix) / length(matrix))
@@ -124,6 +123,12 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc)
124123
return transformer
125124
end
126125

126+
function repack_transformer_structure(structures, ids)
127+
sz = structures[first(ids)][1]
128+
strides_offsets = map(i -> (structures[i][2], structures[i][3]), ids)
129+
return sz, strides_offsets
130+
end
131+
127132
function buffersize(transformer::GenericTreeTransformer)
128133
return maximum(transformer.data; init=0) do (basistransform, structures_dst, _)
129134
return prod(structures_dst[1]) * size(basistransform, 1)

0 commit comments

Comments
 (0)