@@ -12,52 +12,65 @@ function contract_blocks(
1212 N2 = length (blocktype (boffs2))
1313 blocks1 = keys (boffs1)
1414 blocks2 = keys (boffs2)
15+ T = Tuple{Block{N1},Block{N2},Block{NR}}
1516 return if length (blocks1) > length (blocks2)
1617 tasks = map (
1718 Iterators. partition (blocks1, max (1 , length (blocks1) ÷ nthreads ()))
1819 ) do blocks1_partition
1920 @spawn begin
20- block_contractions = Tuple{Block{N1},Block{N2},Block{NR}}[]
21- foreach (Iterators. product (blocks1_partition, blocks2)) do (block1, block2)
22- block_contraction = maybe_contract_blocks (
23- block1,
24- block2,
25- labels1_to_labels2,
26- labels1_to_labelsR,
27- labels2_to_labelsR,
28- ValNR,
29- )
30- if ! isnothing (block_contraction)
31- push! (block_contractions, block_contraction)
21+ block_contractions = T[]
22+ for block1 in blocks1_partition
23+ for block2 in blocks2
24+ block_contraction = maybe_contract_blocks (
25+ block1,
26+ block2,
27+ labels1_to_labels2,
28+ labels1_to_labelsR,
29+ labels2_to_labelsR,
30+ ValNR,
31+ )
32+ if ! isnothing (block_contraction)
33+ push! (block_contractions, block_contraction)
34+ end
3235 end
3336 end
3437 return block_contractions
3538 end
3639 end
37- mapreduce (fetch, vcat, tasks)
40+ all_block_contractions = T[]
41+ for task in tasks
42+ append! (all_block_contractions, fetch (task))
43+ end
44+ return all_block_contractions
3845 else
3946 tasks = map (
4047 Iterators. partition (blocks2, max (1 , length (blocks2) ÷ nthreads ()))
4148 ) do blocks2_partition
4249 @spawn begin
43- block_contractions = Tuple{Block{N1},Block{N2},Block{NR}}[]
44- foreach (Iterators. product (blocks1, blocks2_partition)) do (block1, block2)
45- block_contraction = maybe_contract_blocks (
46- block1,
47- block2,
48- labels1_to_labels2,
49- labels1_to_labelsR,
50- labels2_to_labelsR,
51- ValNR,
52- )
53- if ! isnothing (block_contraction)
54- push! (block_contractions, block_contraction)
50+ block_contractions = T[]
51+ for block2 in blocks2_partition
52+ for block1 in blocks1
53+ block_contraction = maybe_contract_blocks (
54+ block1,
55+ block2,
56+ labels1_to_labels2,
57+ labels1_to_labelsR,
58+ labels2_to_labelsR,
59+ ValNR,
60+ )
61+ if ! isnothing (block_contraction)
62+ push! (block_contractions, block_contraction)
63+ end
5564 end
5665 end
5766 return block_contractions
5867 end
5968 end
60- mapreduce (fetch, vcat, tasks)
69+ all_block_contractions = T[]
70+ for task in tasks
71+ append! (all_block_contractions, fetch (task))
72+ end
73+ return all_block_contractions
6174 end
6275end
6376
0 commit comments