Skip to content

Commit 9b61d37

Browse files
committed
Better performance and type stability
1 parent 83e664b commit 9b61d37

File tree

6 files changed

+80
-75
lines changed

6 files changed

+80
-75
lines changed

NDTensors/src/NDTensors.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,7 @@ include("blocksparse/contract.jl")
7272
include("blocksparse/contract_utilities.jl")
7373
include("blocksparse/contract_generic.jl")
7474
include("blocksparse/contract_sequential.jl")
75-
include("blocksparse/contract_folds.jl")
76-
include("blocksparse/contract_threads.jl")
75+
include("blocksparse/contract_threaded.jl")
7776
include("blocksparse/diagblocksparse.jl")
7877
include("blocksparse/similar.jl")
7978
include("blocksparse/combiner.jl")

NDTensors/src/blocksparse/contract_folds.jl

Lines changed: 0 additions & 15 deletions
This file was deleted.

NDTensors/src/blocksparse/contract_generic.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@ function contract_blockoffsets(
1111
indsR,
1212
labelsR,
1313
)
14-
N1 = length(blocktype(boffs1))
15-
N2 = length(blocktype(boffs2))
1614
NR = length(labelsR)
1715
ValNR = ValLength(labelsR)
1816
labels1_to_labels2, labels1_to_labelsR, labels2_to_labelsR = contract_labels(
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
using .Expose: expose
2+
function contract_blocks(
3+
alg::Algorithm"threaded_threads",
4+
boffs1,
5+
boffs2,
6+
labels1_to_labels2,
7+
labels1_to_labelsR,
8+
labels2_to_labelsR,
9+
ValNR::Val{NR},
10+
) where {NR}
11+
N1 = length(blocktype(boffs1))
12+
N2 = length(blocktype(boffs2))
13+
blocks1 = keys(boffs1)
14+
blocks2 = keys(boffs2)
15+
return if length(blocks1) > length(blocks2)
16+
tasks = map(
17+
Iterators.partition(blocks1, max(1, length(blocks1) ÷ nthreads()))
18+
) do blocks1_partition
19+
@spawn begin
20+
block_contractions = Tuple{Block{N1},Block{N2},Block{NR}}[]
21+
foreach(Iterators.product(blocks1_partition, blocks2)) do (block1, block2)
22+
block_contraction = maybe_contract_blocks(
23+
block1,
24+
block2,
25+
labels1_to_labels2,
26+
labels1_to_labelsR,
27+
labels2_to_labelsR,
28+
ValNR,
29+
)
30+
if !isnothing(block_contraction)
31+
push!(block_contractions, block_contraction)
32+
end
33+
end
34+
return block_contractions
35+
end
36+
end
37+
mapreduce(fetch, vcat, tasks)
38+
else
39+
tasks = map(
40+
Iterators.partition(blocks2, max(1, length(blocks2) ÷ nthreads()))
41+
) do blocks2_partition
42+
@spawn begin
43+
block_contractions = Tuple{Block{N1},Block{N2},Block{NR}}[]
44+
foreach(Iterators.product(blocks1, blocks2_partition)) do (block1, block2)
45+
block_contraction = maybe_contract_blocks(
46+
block1,
47+
block2,
48+
labels1_to_labels2,
49+
labels1_to_labelsR,
50+
labels2_to_labelsR,
51+
ValNR,
52+
)
53+
if !isnothing(block_contraction)
54+
push!(block_contractions, block_contraction)
55+
end
56+
end
57+
return block_contractions
58+
end
59+
end
60+
mapreduce(fetch, vcat, tasks)
61+
end
62+
end
63+
64+
function contract!(
65+
::Algorithm"threaded_folds",
66+
R::BlockSparseTensor,
67+
labelsR,
68+
tensor1::BlockSparseTensor,
69+
labelstensor1,
70+
tensor2::BlockSparseTensor,
71+
labelstensor2,
72+
contraction_plan,
73+
)
74+
executor = ThreadedEx()
75+
return contract!(
76+
R, labelsR, tensor1, labelstensor1, tensor2, labelstensor2, contraction_plan, executor
77+
)
78+
end

NDTensors/src/blocksparse/contract_threads.jl

Lines changed: 0 additions & 54 deletions
This file was deleted.

test/threading/test_threading.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
using Compat
21
using ITensors
32
using Test
43
using LinearAlgebra
@@ -8,7 +7,7 @@ if isone(Threads.nthreads())
87
end
98

109
@testset "Threading" begin
11-
blas_num_threads = Compat.get_num_threads()
10+
blas_num_threads = BLAS.get_num_threads()
1211
strided_num_threads = ITensors.NDTensors.Strided.get_num_threads()
1312

1413
BLAS.set_num_threads(1)

0 commit comments

Comments
 (0)