Skip to content

Commit 99c8ae5

Browse files
committed
separate treemanipulation threads
1 parent aae6602 commit 99c8ae5

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

src/TensorKit.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,19 @@ function set_num_transformer_threads(n::Int)
198198
return TRANSFORMER_THREADS[] = n
199199
end
200200

201+
const TREEMANIPULATION_THREADS = Ref(1)
202+
203+
get_num_manipulation_threads() = TREEMANIPULATION_THREADS[]
204+
205+
function set_num_transformer_threads(n::Int)
206+
N = Base.Threads.nthreads()
207+
if n > N
208+
n = N
209+
Strided._set_num_threads_warn(n)
210+
end
211+
return TREEMANIPULATION_THREADS[] = n
212+
end
213+
201214
# Definitions and methods for tensors
202215
#-------------------------------------
203216
# general definitions

src/tensors/treetransformers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc)
7373

7474
data = Vector{_GenericTransformerData{T,N}}()
7575

76-
nthreads = get_num_transformer_threads()
76+
nthreads = get_num_manipulation_threads()
7777
if nthreads > 1
7878
fusiontreeblocks = Vector{FusionTreeBlock{I,N₁,N₂,fusiontreetype(I, N₁, N₂)}}()
7979
for cod_uncoupled_src in sectors(codomain(Vsrc)),

0 commit comments

Comments
 (0)