Skip to content

Commit 0a4341f

Browse files
committed
Improve type stability and performance
1 parent 9b61d37 commit 0a4341f

File tree

1 file changed

+39
-26
lines changed

1 file changed

+39
-26
lines changed

NDTensors/src/blocksparse/contract_threaded.jl

Lines changed: 39 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -12,52 +12,65 @@ function contract_blocks(
1212
N2 = length(blocktype(boffs2))
1313
blocks1 = keys(boffs1)
1414
blocks2 = keys(boffs2)
15+
T = Tuple{Block{N1},Block{N2},Block{NR}}
1516
return if length(blocks1) > length(blocks2)
1617
tasks = map(
1718
Iterators.partition(blocks1, max(1, length(blocks1) ÷ nthreads()))
1819
) do blocks1_partition
1920
@spawn begin
20-
block_contractions = Tuple{Block{N1},Block{N2},Block{NR}}[]
21-
foreach(Iterators.product(blocks1_partition, blocks2)) do (block1, block2)
22-
block_contraction = maybe_contract_blocks(
23-
block1,
24-
block2,
25-
labels1_to_labels2,
26-
labels1_to_labelsR,
27-
labels2_to_labelsR,
28-
ValNR,
29-
)
30-
if !isnothing(block_contraction)
31-
push!(block_contractions, block_contraction)
21+
block_contractions = T[]
22+
for block1 in blocks1_partition
23+
for block2 in blocks2
24+
block_contraction = maybe_contract_blocks(
25+
block1,
26+
block2,
27+
labels1_to_labels2,
28+
labels1_to_labelsR,
29+
labels2_to_labelsR,
30+
ValNR,
31+
)
32+
if !isnothing(block_contraction)
33+
push!(block_contractions, block_contraction)
34+
end
3235
end
3336
end
3437
return block_contractions
3538
end
3639
end
37-
mapreduce(fetch, vcat, tasks)
40+
all_block_contractions = T[]
41+
for task in tasks
42+
append!(all_block_contractions, fetch(task))
43+
end
44+
return all_block_contractions
3845
else
3946
tasks = map(
4047
Iterators.partition(blocks2, max(1, length(blocks2) ÷ nthreads()))
4148
) do blocks2_partition
4249
@spawn begin
43-
block_contractions = Tuple{Block{N1},Block{N2},Block{NR}}[]
44-
foreach(Iterators.product(blocks1, blocks2_partition)) do (block1, block2)
45-
block_contraction = maybe_contract_blocks(
46-
block1,
47-
block2,
48-
labels1_to_labels2,
49-
labels1_to_labelsR,
50-
labels2_to_labelsR,
51-
ValNR,
52-
)
53-
if !isnothing(block_contraction)
54-
push!(block_contractions, block_contraction)
50+
block_contractions = T[]
51+
for block2 in blocks2_partition
52+
for block1 in blocks1
53+
block_contraction = maybe_contract_blocks(
54+
block1,
55+
block2,
56+
labels1_to_labels2,
57+
labels1_to_labelsR,
58+
labels2_to_labelsR,
59+
ValNR,
60+
)
61+
if !isnothing(block_contraction)
62+
push!(block_contractions, block_contraction)
63+
end
5564
end
5665
end
5766
return block_contractions
5867
end
5968
end
60-
mapreduce(fetch, vcat, tasks)
69+
all_block_contractions = T[]
70+
for task in tasks
71+
append!(all_block_contractions, fetch(task))
72+
end
73+
return all_block_contractions
6174
end
6275
end
6376

0 commit comments

Comments
 (0)