Skip to content

Commit f5cd7bf

Browse files
committed
improve type stability
1 parent 3889d05 commit f5cd7bf

File tree

1 file changed

+16
-9
lines changed

1 file changed

+16
-9
lines changed

src/tensors/treetransformers.jl

Lines changed: 16 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -52,7 +52,14 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc)
5252
return TupleTools.vcat(f₁.uncoupled, f₂.uncoupled)
5353
end
5454

55-
outer_data = map(uncoupleds_src_unique) do uncoupled
55+
T = sectorscalartype(I)
56+
N = numind(Vdst)
57+
L = length(uncoupleds_src_unique)
58+
TStrided = StridedStructure{N}
59+
data = Vector{Tuple{Matrix{T},Vector{TStrided},Vector{TStrided}}}(undef, L)
60+
61+
# TODO: this can be multithreaded
62+
for (i, uncoupled) in enumerate(uncoupleds_src_unique)
5663
ids_src = findall(==(uncoupled), uncoupleds_src)
5764
fusiontrees_outer_src = structure_src.fusiontreelist[ids_src]
5865

@@ -71,21 +78,21 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc)
7178
@debug("Created recoupling block for uncoupled: $uncoupled",
7279
sz = size(matrix), sparsity = count(!iszero, matrix) / length(matrix))
7380

74-
return (matrix,
75-
structure_dst.fusiontreestructure[ids_dst],
76-
structure_src.fusiontreestructure[ids_src])
81+
data[i] = (matrix,
82+
structure_dst.fusiontreestructure[ids_dst],
83+
structure_src.fusiontreestructure[ids_src])
7784
end
7885

7986
# sort by (approximate) weight to make the buffers happy
8087
# and use round-robin strategy for multi-threading
81-
sort!(outer_data; by=_transformer_weight, rev=true)
88+
sort!(data; by=_transformer_weight, rev=true)
8289

8390
@debug("TreeTransformer for $Vsrc to $Vdst via $p",
84-
nblocks = length(outer_data),
85-
sz_median = size(outer_data[end ÷ 2][1], 1),
86-
sz_max = size(outer_data[1][1], 1))
91+
nblocks = length(data),
92+
sz_median = size(data[end ÷ 2][1], 1),
93+
sz_max = size(data[1][1], 1))
8794

88-
return GenericTreeTransformer(outer_data)
95+
return GenericTreeTransformer{T,N}(data)
8996
end
9097

9198
function _transformer_weight((matrix, structures_dst, structures_src))

0 commit comments

Comments (0)