Skip to content

Commit 4c39252

Browse files
committed
Refactor and add multithreading
1 parent 273b3e7 commit 4c39252

File tree

3 files changed

+112
-27
lines changed

3 files changed

+112
-27
lines changed

src/TensorKit.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,21 @@ include("fusiontrees/fusiontrees.jl")
183183
#-------------------------------------------
184184
include("spaces/vectorspaces.jl")
185185

186+
# Multithreading settings
187+
#-------------------------
188+
"""
    TRANSFORMER_THREADS

Module-level setting holding the number of tasks used when a tree transformer is
applied in parallel. Read it with [`get_num_transformer_threads`](@ref) and change it
with [`set_num_transformer_threads`](@ref).
"""
const TRANSFORMER_THREADS = Ref(1)

"""
    get_num_transformer_threads()

Return the number of tasks currently configured for tree transformations.
"""
get_num_transformer_threads() = TRANSFORMER_THREADS[]

"""
    set_num_transformer_threads(n::Int)

Set the number of tasks used for tree transformations and return the value that was
actually stored.

The request is clamped to the range `1:Threads.nthreads()`; a warning is emitted
whenever the stored value differs from `n`.
"""
function set_num_transformer_threads(n::Int)
    N = Base.Threads.nthreads()
    if n > N
        # More tasks than Julia threads brings no extra parallelism.
        @warn "Requested $n transformer threads, but Julia has only $N threads; using $N"
        n = N
    elseif n < 1
        # A non-positive value would become the default `ntasks` of the threaded
        # kernels, whose task loop `1:min(ntasks, nblocks)` would then be empty.
        @warn "Number of transformer threads must be at least 1 (got $n); using 1"
        n = 1
    end
    return TRANSFORMER_THREADS[] = n
end
200+
186201
# Definitions and methods for tensors
187202
#-------------------------------------
188203
# general definitions
@@ -218,6 +233,8 @@ include("auxiliary/deprecate.jl")
218233
# ----------
219234
# Module initialization hook (Julia runs `__init__` when the module is loaded).
# Returns `nothing`.
function __init__()
220235
# NOTE(review): `@require_extensions` is defined elsewhere; presumably it wires up
# optional package extensions — confirm.
@require_extensions
236+
# Default the transformer thread count to the number of Julia threads.
set_num_transformer_threads(Threads.nthreads())
237+
return nothing
221238
end
222239

223240
end

src/tensors/indexmanipulations.jl

Lines changed: 92 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -473,38 +473,68 @@ function add_transform!(tdst::AbstractTensorMap,
473473
return tdst
474474
end
475475

476-
function add_transform_kernel!(tdst::TensorMap,
477-
tsrc::TensorMap,
478-
(p₁, p₂)::Index2Tuple,
479-
::TrivialTreeTransformer,
480-
α::Number,
481-
β::Number,
482-
backend::AbstractBackend...)
483-
return TO.tensoradd!(tdst[], tsrc[], (p₁, p₂), false, α, β, backend...)
476+
# Decide whether a transform should take the threaded code path.
# Currently this is `true` exactly when more than one transformer thread is configured;
# the arguments `t` and `transformer` are not inspected yet (see the TODO below).
function use_threaded_transform(t::TensorMap, transformer::TreeTransformer)
477+
# TODO: heuristic for not threading over small tensors
478+
return get_num_transformer_threads() > 1
484479
end
485480

486481
# Entry point for `TensorMap` arguments and a general `TreeTransformer`.
# Chooses between `_add_transform_threaded!` and `_add_transform_nonthreaded!` based on
# `use_threaded_transform(tsrc, transformer)`, forwarding `p`, `α`, `β` and the optional
# `backend` unchanged. Always returns `nothing`.
function add_transform_kernel!(tdst::TensorMap,
487482
tsrc::TensorMap,
488483
p::Index2Tuple,
489-
transformer::AbelianTreeTransformer,
484+
transformer::TreeTransformer,
490485
α::Number,
491486
β::Number,
492487
backend::AbstractBackend...)
493-
# TODO: this could be multithreaded
488+
# Only the source tensor and the transformer are passed to the threading decision.
if use_threaded_transform(tsrc, transformer)
489+
_add_transform_threaded!(tdst, tsrc, p, transformer, α, β, backend...)
490+
else
491+
_add_transform_nonthreaded!(tdst, tsrc, p, transformer, α, β, backend...)
492+
end
493+
494+
return nothing
495+
end
496+
497+
# Trivial implementation
498+
# ----------------------
499+
# Dispatches on TrivialTreeTransformer, so the trivial case never reaches the threading check
500+
# Kernel for a `TrivialTreeTransformer`: a single `TO.tensoradd!` call on `tdst[]` and
# `tsrc[]` with the index tuple `(p₁, p₂)` and the scalars `α`, `β`; no threading
# decision is made here. Returns `nothing`.
# NOTE(review): the literal `false` is presumably the conjugation flag of
# `tensoradd!` — confirm against the `TO` API.
function add_transform_kernel!(tdst::TensorMap, tsrc::TensorMap, (p₁, p₂)::Index2Tuple,
501+
::TrivialTreeTransformer,
502+
α::Number, β::Number, backend::AbstractBackend...)
503+
TO.tensoradd!(tdst[], tsrc[], (p₁, p₂), false, α, β, backend...)
504+
return nothing
505+
end
506+
507+
# Abelian implementations
508+
# -----------------------
509+
# Serial implementation for an `AbelianTreeTransformer`: visits every entry of
# `transformer.data` in order and applies it with `_add_transform_single!`.
# Returns `nothing`.
function _add_transform_nonthreaded!(tdst, tsrc, p, transformer::AbelianTreeTransformer,
510+
α, β, backend...)
494511
for subtransformer in transformer.data
495-
_add_transform_single!(tdst, tsrc, p, α, β, subtransformer, backend...)
512+
_add_transform_single!(tdst, tsrc, p, subtransformer, α, β, backend...)
496513
end
514+
return nothing
515+
end
497516

517+
# Task-parallel implementation for an `AbelianTreeTransformer`.
# Spawns at most `min(ntasks, nblocks)` tasks; each task repeatedly claims the next
# unprocessed entry of `transformer.data` through a shared atomic counter and applies it
# with `_add_transform_single!`. `ntasks` defaults to `get_num_transformer_threads()`.
# Returns `nothing` once all tasks have finished.
# NOTE(review): assumes different entries write to non-overlapping parts of `tdst`,
# otherwise concurrent tasks would race — confirm.
function _add_transform_threaded!(tdst, tsrc, p, transformer::AbelianTreeTransformer, α, β,
518+
backend...; ntasks::Int=get_num_transformer_threads())
519+
nblocks = length(transformer.data)
520+
# Shared work index, starting at the first entry.
counter = Threads.Atomic{Int}(1)
521+
# `@sync` blocks until every task spawned in the loop has completed.
Threads.@sync for _ in 1:min(ntasks, nblocks)
522+
Threads.@spawn begin
523+
while true
524+
# `atomic_add!` returns the value before the increment, so every task
# obtains a distinct index.
local_counter = Threads.atomic_add!(counter, 1)
525+
local_counter > nblocks && break
526+
# The index is within 1:nblocks here: the counter starts at 1 and the
# check above rejects anything past the end.
@inbounds subtransformer = transformer.data[local_counter]
527+
_add_transform_single!(tdst, tsrc, p, subtransformer, α, β, backend...)
528+
end
529+
end
530+
end
498531
return nothing
499532
end
500533

501-
function add_transform_kernel!(tdst::TensorMap,
502-
tsrc::TensorMap,
503-
p::Index2Tuple,
504-
transformer::GenericTreeTransformer,
505-
α::Number,
506-
β::Number,
507-
backend::AbstractBackend...)
534+
# Non-abelian implementations
535+
# ---------------------------
536+
function _add_transform_nonthreaded!(tdst, tsrc, p, transformer::GenericTreeTransformer,
537+
α, β, backend...)
508538
# preallocate buffers
509539
buffersize = maximum(transformer.data) do (_, structures_dst, _)
510540
return prod(structures_dst[1][1])
@@ -522,24 +552,59 @@ function add_transform_kernel!(tdst::TensorMap,
522552
(buffer1, buffer2), backend...)
523553
end
524554
end
525-
return tdst
555+
return nothing
526556
end
527557

528-
function _add_transform_single!(tdst, tsrc, p, α, β,
558+
# Task-parallel implementation for a `GenericTreeTransformer`.
# Spawns at most `min(ntasks, nblocks)` tasks that pull entries of `transformer.data`
# from a shared atomic counter. Each task allocates its own pair of scratch buffers, so
# buffers are never shared between tasks. `ntasks` defaults to
# `get_num_transformer_threads()`. Returns `nothing` once all tasks have finished.
# NOTE(review): assumes different entries write to non-overlapping parts of `tdst`,
# otherwise concurrent tasks would race — confirm.
function _add_transform_threaded!(tdst, tsrc, p, transformer::GenericTreeTransformer,
559+
α, β, backend...;
560+
ntasks::Int=get_num_transformer_threads())
561+
# Buffer length: the largest `prod(structures_dst[1][1])` over all entries.
# NOTE(review): `maximum` throws on an empty collection, so this presumes
# `transformer.data` is non-empty — confirm.
buffersize = maximum(transformer.data) do (_, structures_dst, _)
562+
return prod(structures_dst[1][1])
563+
end
564+
nblocks = length(transformer.data)
565+
566+
# Shared work index, starting at the first entry.
counter = Threads.Atomic{Int}(1)
567+
# `@sync` blocks until every task spawned in the loop has completed.
Threads.@sync for _ in 1:min(ntasks, nblocks)
568+
Threads.@spawn begin
569+
# preallocate buffers for each task
570+
buffer1 = similar(tsrc.data, buffersize)
571+
buffer2 = similar(tdst.data, buffersize)
572+
573+
while true
574+
# `atomic_add!` returns the value before the increment, so every task
# obtains a distinct index.
local_counter = Threads.atomic_add!(counter, 1)
575+
local_counter > nblocks && break
576+
@inbounds subtransformer = transformer.data[local_counter]
577+
# Entries whose first component has length 1 take the single-block kernel;
# all others go through the buffered multi-block kernel.
if length(subtransformer[1]) == 1
578+
_add_transform_single!(tdst, tsrc, p, subtransformer, α, β, backend...)
579+
else
580+
_add_transform_multi!(tdst, tsrc, p, subtransformer, (buffer1, buffer2),
581+
α, β, backend...)
582+
end
583+
end
584+
end
585+
end
586+
587+
return nothing
588+
end
589+
590+
# Kernels
591+
# -------
592+
# Apply one subtransformer that maps a single source block to a single destination block.
# The fourth positional argument is destructured into
# `(basistransform, structures_dst, structures_src)`. Builds strided views into
# `tdst.data` and `tsrc.data` and combines them with `TO.tensoradd!`, scaling the source
# by `α * coeff` and passing `β` through. Returns `nothing`.
function _add_transform_single!(tdst, tsrc, p,
529593
(basistransform, structures_dst, structures_src),
530-
backend...)
594+
α, β, backend...)
531595
# Each component may arrive bare or wrapped in a container; `only` unwraps it and
# throws unless there is exactly one element.
structure_dst = structures_dst isa Vector ? only(structures_dst) : structures_dst
532596
structure_src = structures_src isa Vector ? only(structures_src) : structures_src
533597
coeff = basistransform isa Number ? basistransform : only(basistransform)
598+
534599
# The structure tuples are splatted as the remaining `StridedView` arguments.
subblock_dst = StridedView(tdst.data, structure_dst...)
535600
subblock_src = StridedView(tsrc.data, structure_src...)
536601
TO.tensoradd!(subblock_dst, subblock_src, p, false, α * coeff, β, backend...)
537602
return nothing
538603
end
539604

540-
function _add_transform_multi!(tdst, tsrc, p, α, β,
605+
function _add_transform_multi!(tdst, tsrc, p,
541606
(basistransform, structures_dst, structures_src),
542-
(buffer1, buffer2), backend...)
607+
(buffer1, buffer2), α, β, backend...)
543608
rows, cols = size(basistransform)
544609
sz_src = first(first(structures_src))
545610
blocksize = prod(sz_src)
@@ -553,7 +618,7 @@ function _add_transform_multi!(tdst, tsrc, p, α, β,
553618

554619
# Resummation into a second buffer using BLAS
555620
buffer_dst = StridedView(buffer2, (blocksize, rows), (1, blocksize), 0)
556-
mul!(buffer_dst, buffer_src, basistransform, α)
621+
mul!(buffer_dst, buffer_src, basistransform, α, Zero())
557622

558623
# Filling up the output
559624
for (i, structure_dst) in enumerate(structures_dst)
@@ -565,6 +630,9 @@ function _add_transform_multi!(tdst, tsrc, p, α, β,
565630
return nothing
566631
end
567632

633+
# Other implementations
634+
# ---------------------
635+
568636
function add_transform_kernel!(tdst::AbstractTensorMap,
569637
tsrc::AbstractTensorMap,
570638
(p₁, p₂)::Index2Tuple,

src/tensors/treetransformers.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ abstract type TreeTransformer end
88
# Field-less `TreeTransformer` subtype, used purely for dispatch.
struct TrivialTreeTransformer <: TreeTransformer end
99

1010
# `TreeTransformer` that stores one tuple per entry in `data`.
# The constructor in this file fills each tuple as
# `(coeff, stridestructure_dst, stridestructure_src)`, i.e. a coefficient of type `T`
# followed by the destination and source `StridedStructure{N}`.
struct AbelianTreeTransformer{T,N} <: TreeTransformer
11-
data::Vector{Tuple{StridedStructure{N},StridedStructure{N},T}}
11+
data::Vector{Tuple{T,StridedStructure{N},StridedStructure{N}}}
1212
end
1313

1414
function AbelianTreeTransformer(transform, p, Vsrc, Vdst)
@@ -19,15 +19,15 @@ function AbelianTreeTransformer(transform, p, Vsrc, Vdst)
1919
L = length(structure_src.fusiontreelist)
2020
T = sectorscalartype(sectortype(Vdst))
2121
N = numind(Vsrc)
22-
data = Vector{Tuple{StridedStructure{N},StridedStructure{N},T}}(undef, L)
22+
data = Vector{Tuple{T,StridedStructure{N},StridedStructure{N}}}(undef, L)
2323

2424
for i in 1:L
2525
f₁, f₂ = structure_src.fusiontreelist[i]
2626
(f₃, f₄), coeff = only(transform(f₁, f₂))
2727
j = structure_dst.fusiontreeindices[(f₃, f₄)]
2828
stridestructure_dst = structure_dst.fusiontreestructure[j]
2929
stridestructure_src = structure_src.fusiontreestructure[i]
30-
data[i] = (stridestructure_dst, stridestructure_src, coeff)
30+
data[i] = (coeff, stridestructure_dst, stridestructure_src)
3131
end
3232

3333
return AbelianTreeTransformer(data)

0 commit comments

Comments
 (0)