Skip to content

Commit ed9e260

Browse files
committed
fix multithreaded implementation
1 parent 9247e38 commit ed9e260

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

src/tensors/treetransformers.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,11 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc)
8585
push!(fusiontreeblocks, fs_src)
8686
end
8787
end
88+
nblocks = length(fusiontreeblocks)
8889

89-
resize!(data, length(fusiontreeblocks))
90+
resize!(data, nblocks)
9091
counter = Threads.Atomic{Int}(1)
91-
Threads.@sync for _ in 1:min(nthreads, length(fusiontreeblocks))
92+
Threads.@sync for _ in 1:min(nthreads, nblocks)
9293
Threads.@spawn begin
9394
while true
9495
local_counter = Threads.atomic_add!(counter, 1)
@@ -97,6 +98,7 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc)
9798
fs_dst, U = transform(fs_src)
9899
matrix = copy(transpose(U)) # TODO: should we avoid this
99100

101+
trees_src = fusiontrees(fs_src)
100102
inds_src = map(Base.Fix1(getindex, structure_src.fusiontreeindices),
101103
trees_src)
102104
trees_dst = fusiontrees(fs_dst)
@@ -110,7 +112,7 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc)
110112
sz_dst, newstructs_dst = repack_transformer_structure(fusionstructure_dst,
111113
inds_dst)
112114

113-
data1[local_counter] = (matrix, (sz_dst, newstructs_dst),
115+
data[local_counter] = (matrix, (sz_dst, newstructs_dst),
114116
(sz_src, newstructs_src))
115117
end
116118
end

0 commit comments

Comments
 (0)