Skip to content

Commit 86dbe36

Browse files
committed
Rewrite threaded blockspares to avoid threadid
1 parent d38695e commit 86dbe36

File tree

8 files changed

+41
-214
lines changed

8 files changed

+41
-214
lines changed

NDTensors/src/blocksparse/contract_generic.jl

Lines changed: 2 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -18,40 +18,9 @@ function contract_blockoffsets(
1818
labels1_to_labels2, labels1_to_labelsR, labels2_to_labelsR = contract_labels(
1919
labels1, labels2, labelsR
2020
)
21-
22-
# Contraction plan element type
23-
T = Tuple{Block{N1},Block{N2},Block{NR}}
24-
25-
# Thread-local collections of block contractions.
26-
# Could use:
27-
# ```julia
28-
# FLoops.@reduce(contraction_plans = append!(T[], [(block1, block2, blockR)]))
29-
# ```
30-
# as a simpler alternative but it is slower.
31-
32-
contraction_plans = Vector{T}[T[] for _ in 1:nthreads()]
33-
34-
#
35-
# Reserve some capacity
36-
# In theory the maximum is length(boffs1) * length(boffs2)
37-
# but in practice that is too much
38-
#for contraction_plan in contraction_plans
39-
# sizehint!(contraction_plan, max(length(boffs1), length(boffs2)))
40-
#end
41-
#
42-
43-
contract_blocks!(
44-
alg,
45-
contraction_plans,
46-
boffs1,
47-
boffs2,
48-
labels1_to_labels2,
49-
labels1_to_labelsR,
50-
labels2_to_labelsR,
51-
ValNR,
21+
contraction_plan = contract_blocks(
22+
alg, boffs1, boffs2, labels1_to_labels2, labels1_to_labelsR, labels2_to_labelsR, ValNR
5223
)
53-
54-
contraction_plan = reduce(vcat, contraction_plans)
5524
blockoffsetsR = BlockOffsets{NR}()
5625
nnzR = 0
5726
for (_, _, blockR) in contraction_plan
@@ -60,7 +29,6 @@ function contract_blockoffsets(
6029
nnzR += blockdim(indsR, blockR)
6130
end
6231
end
63-
6432
return blockoffsetsR, contraction_plan
6533
end
6634

Lines changed: 36 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
using .Expose: expose
2-
function contract_blocks!(
2+
function contract_blocks(
33
alg::Algorithm"threaded_threads",
4-
contraction_plans,
54
boffs1,
65
boffs2,
76
labels1_to_labels2,
@@ -11,40 +10,45 @@ function contract_blocks!(
1110
)
1211
blocks1 = keys(boffs1)
1312
blocks2 = keys(boffs2)
14-
if length(blocks1) > length(blocks2)
15-
@sync for blocks1_partition in
16-
Iterators.partition(blocks1, max(1, length(blocks1) ÷ nthreads()))
17-
@spawn for block1 in blocks1_partition
18-
for block2 in blocks2
19-
maybe_contract_blocks!(
20-
contraction_plans[threadid()],
21-
block1,
22-
block2,
23-
labels1_to_labels2,
24-
labels1_to_labelsR,
25-
labels2_to_labelsR,
26-
ValNR,
27-
)
28-
end
13+
return if length(blocks1) > length(blocks2)
14+
tasks = map(
15+
Iterators.partition(blocks1, max(1, length(blocks1) ÷ nthreads()))
16+
) do blocks1_partition
17+
@spawn begin
18+
block_contractions =
19+
map(Iterators.product(blocks1_partition, blocks2)) do (block1, block2)
20+
maybe_contract_blocks(
21+
block1,
22+
block2,
23+
labels1_to_labels2,
24+
labels1_to_labelsR,
25+
labels2_to_labelsR,
26+
ValNR,
27+
)
28+
end
29+
block_contractions = filter(!isnothing, block_contractions)
2930
end
3031
end
32+
mapreduce(fetch, vcat, tasks)
3133
else
32-
@sync for blocks2_partition in
33-
Iterators.partition(blocks2, max(1, length(blocks2) ÷ nthreads()))
34-
@spawn for block2 in blocks2_partition
35-
for block1 in blocks1
36-
maybe_contract_blocks!(
37-
contraction_plans[threadid()],
38-
block1,
39-
block2,
40-
labels1_to_labels2,
41-
labels1_to_labelsR,
42-
labels2_to_labelsR,
43-
ValNR,
44-
)
45-
end
34+
tasks = map(
35+
Iterators.partition(blocks2, max(1, length(blocks2) ÷ nthreads()))
36+
) do blocks2_partition
37+
@spawn begin
38+
block_contractions =
39+
map(Iterators.product(blocks1, blocks2_partition)) do (block1, block2)
40+
maybe_contract_blocks(
41+
block1,
42+
block2,
43+
labels1_to_labels2,
44+
labels1_to_labelsR,
45+
labels2_to_labelsR,
46+
ValNR,
47+
)
48+
end
49+
block_contractions = filter(!isnothing, block_contractions)
4650
end
4751
end
52+
mapreduce(fetch, vcat, tasks)
4853
end
49-
return nothing
5054
end

NDTensors/src/blocksparse/contract_utilities.jl

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,12 @@ function compute_alpha(
1414
return one(ElR)
1515
end
1616

17-
function maybe_contract_blocks!(
18-
contraction_plan,
19-
block1,
20-
block2,
21-
labels1_to_labels2,
22-
labels1_to_labelsR,
23-
labels2_to_labelsR,
24-
ValNR,
17+
function maybe_contract_blocks(
18+
block1, block2, labels1_to_labels2, labels1_to_labelsR, labels2_to_labelsR, ValNR
2519
)
2620
if are_blocks_contracted(block1, block2, labels1_to_labels2)
2721
blockR = contract_blocks(block1, labels1_to_labelsR, block2, labels2_to_labelsR, ValNR)
28-
push!(contraction_plan, (block1, block2, blockR))
22+
return block1, block2, blockR
2923
end
3024
return nothing
3125
end

NDTensors/test/backup/arraytensor/Project.toml

Lines changed: 0 additions & 3 deletions
This file was deleted.

NDTensors/test/backup/arraytensor/array.jl

Lines changed: 0 additions & 51 deletions
This file was deleted.

NDTensors/test/backup/arraytensor/blocksparsearray.jl

Lines changed: 0 additions & 52 deletions
This file was deleted.

NDTensors/test/backup/arraytensor/diagonalarray.jl

Lines changed: 0 additions & 25 deletions
This file was deleted.

NDTensors/test/backup/arraytensor/runtests.jl

Lines changed: 0 additions & 8 deletions
This file was deleted.

0 commit comments

Comments
 (0)