Skip to content

Commit 84f04ab

Browse files
committed
centralize sorting
1 parent d07ca89 commit 84f04ab

File tree

1 file changed

+32
-14
lines changed

1 file changed

+32
-14
lines changed

src/tensors/treetransformers.jl

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,12 @@ function AbelianTreeTransformer(transform, p, Vdst, Vsrc)
3232
data[i] = (coeff, stridestructure_dst, stridestructure_src)
3333
end
3434

35-
return AbelianTreeTransformer(data)
35+
transformer = AbelianTreeTransformer(data)
36+
37+
# sort by (approximate) weight to facilitate multi-threading strategies
38+
# sort!(transformer)
39+
40+
return transformer
3641
end
3742

3843
const _GenericTransformerData{T,N} = Tuple{Matrix{T},Vector{StridedStructure{N}},
@@ -87,25 +92,17 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc)
8792
structure_src.fusiontreestructure[ids_src])
8893
end
8994

90-
# sort by (approximate) weight to make the buffers happy
91-
# and use round-robin strategy for multi-threading
92-
sort!(data; by=_transformer_weight, rev=true)
93-
9495
@debug("TreeTransformer for $Vsrc to $Vdst via $p",
9596
nblocks = length(data),
9697
sz_median = size(data[end ÷ 2][1], 1),
9798
sz_max = size(data[1][1], 1))
9899

99-
return GenericTreeTransformer{T,N}(data)
100-
end
100+
transformer = GenericTreeTransformer{T,N}(data)
101101

102-
# Cost model for transforming a set of subblocks with fixed uncoupled sectors:
103-
# L x L x length(subblock) where L is the number of subblocks
104-
# this is L input blocks each going to L output blocks of given length
105-
# Note that it might be the case that the permutations are dominant, in which case the
106-
# actual cost model would scale like L x length(subblock)
107-
function _transformer_weight((matrix, structures_dst, structures_src))
108-
return length(matrix) * prod(structures_dst[1][1])
102+
# sort by (approximate) weight to facilitate multi-threading strategies
103+
# sort!(transformer)
104+
105+
return transformer
109106
end
110107

111108
function buffersize(transformer::GenericTreeTransformer)
@@ -174,3 +171,24 @@ for (transform, treetransformer) in
174171
end
175172

176173
# default cachestyle is GlobalLRUCache
174+
175+
# Sorting based on cost model
176+
# ---------------------------
177+
function Base.sort!(transformer::Union{AbelianTreeTransformer,GenericTreeTransformer};
178+
by=_transformer_weight, rev::Bool=true)
179+
sort!(transformer.data; by, rev)
180+
return transformer
181+
end
182+
183+
function _transformer_weight((coeff, struct_dst, struct_src)::_AbelianTransformerData)
184+
return prod(struct_dst[1])
185+
end
186+
187+
# Cost model for transforming a set of subblocks with fixed uncoupled sectors:
188+
# L x L x length(subblock) where L is the number of subblocks
189+
# this is L input blocks each going to L output blocks of given length
190+
# Note that it might be the case that the permutations are dominant, in which case the
191+
# actual cost model would scale like L x length(subblock)
192+
function _transformer_weight((mat, structs_dst, structs_src)::_GenericTransformerData)
193+
return length(mat) * prod(structs_dst[1][1])
194+
end

0 commit comments

Comments
 (0)