Skip to content

Commit d90034b

Browse files
committed
Fix transformer weight
1 parent 00bfd04 commit d90034b

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

src/tensors/treetransformers.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,11 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc)
9595
return GenericTreeTransformer{T,N}(data)
9696
end
9797

98+
# Cost model for transforming a set of subblocks with fixed uncoupled sectors:
99+
# L x L x size(subblock) where L is the number of subblocks
100+
# this is L input blocks each going to L output blocks of given size
98101
function _transformer_weight((matrix, structures_dst, structures_src))
99-
return size(matrix, 1) * prod(structures_dst[1][1])
102+
return length(matrix) * prod(structures_dst[1][1])
100103
end
101104

102105
function buffersize(transformer::GenericTreeTransformer)

0 commit comments

Comments
 (0)