Skip to content

Commit 16d7faf

Browse files
committed
separate treemanipulation threads
1 parent 42fc8ca commit 16d7faf

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

src/TensorKit.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,19 @@ function set_num_transformer_threads(n::Int)
210210
return TRANSFORMER_THREADS[] = n
211211
end
212212

213+
const TREEMANIPULATION_THREADS = Ref(1)
214+
215+
get_num_manipulation_threads() = TREEMANIPULATION_THREADS[]
216+
217+
function set_num_manipulation_threads(n::Int)
218+
N = Base.Threads.nthreads()
219+
if n > N
220+
n = N
221+
Strided._set_num_threads_warn(n)
222+
end
223+
return TREEMANIPULATION_THREADS[] = n
224+
end
225+
213226
# Definitions and methods for tensors
214227
#-------------------------------------
215228
# general definitions

src/tensors/treetransformers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc)
7373

7474
data = Vector{_GenericTransformerData{T, N}}()
7575

76-
nthreads = get_num_transformer_threads()
76+
nthreads = get_num_manipulation_threads()
7777
if nthreads > 1
7878
fusiontreeblocks = Vector{FusionTreeBlock{I, N₁, N₂, fusiontreetype(I, N₁, N₂)}}()
7979
for cod_uncoupled_src in sectors(codomain(Vsrc)),

0 commit comments

Comments
 (0)