Skip to content

Commit 722610d

Browse files
committed
Overhaul indexmanipulations to reuse multithreading criterion
1 parent 0b13b7b commit 722610d

File tree

1 file changed

+96
-121
lines changed

1 file changed

+96
-121
lines changed

src/tensors/indexmanipulations.jl

Lines changed: 96 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -453,68 +453,71 @@ end
453453

454454
function add_transform!(tdst::AbstractTensorMap,
455455
tsrc::AbstractTensorMap,
456-
(p₁, p₂)::Index2Tuple,
456+
p::Index2Tuple,
457457
transformer,
458458
α::Number,
459459
β::Number,
460460
backend::AbstractBackend...)
461461
@boundscheck begin
462-
permute(space(tsrc), (p₁, p₂)) == space(tdst) ||
462+
permute(space(tsrc), p) == space(tdst) ||
463463
throw(SpaceMismatch("source = $(codomain(tsrc))←$(domain(tsrc)),
464-
dest = $(codomain(tdst))←$(domain(tdst)), p₁ = $(p₁), p₂ = $(p₂)"))
464+
dest = $(codomain(tdst))←$(domain(tdst)), p₁ = $(p[1]), p₂ = $(p[2])"))
465465
end
466466

467-
if p₁ === codomainind(tsrc) && p₂ === domainind(tsrc)
467+
if p[1] === codomainind(tsrc) && p[2] === domainind(tsrc)
468468
add!(tdst, tsrc, α, β)
469469
else
470-
add_transform_kernel!(tdst, tsrc, (p₁, p₂), transformer, α, β, backend...)
470+
I = sectortype(tdst)
471+
if I === Trivial
472+
_add_trivial_kernel!(tdst, tsrc, p, transformer, α, β, backend...)
473+
elseif FusionStyle(I) === UniqueFusion()
474+
if use_threaded_transform(tdst, transformer)
475+
_add_abelian_kernel_threaded!(tdst, tsrc, p, transformer, α, β, backend...)
476+
else
477+
_add_abelian_kernel_nonthreaded!(tdst, tsrc, p, transformer, α, β,
478+
backend...)
479+
end
480+
else
481+
if use_threaded_transform(tdst, transformer)
482+
_add_general_kernel_threaded!(tdst, tsrc, p, transformer, α, β, backend...)
483+
else
484+
_add_general_kernel_nonthreaded!(tdst, tsrc, p, transformer, α, β,
485+
backend...)
486+
end
487+
end
471488
end
472489

473490
return tdst
474491
end
475492

476-
function use_threaded_transform(t::TensorMap, transformer::TreeTransformer)
493+
function use_threaded_transform(t::TensorMap, transformer)
477494
return get_num_transformer_threads() > 1 && length(t.data) > Strided.MINTHREADLENGTH
478495
end
479-
480-
function add_transform_kernel!(tdst::TensorMap,
481-
tsrc::TensorMap,
482-
p::Index2Tuple,
483-
transformer::TreeTransformer,
484-
α::Number,
485-
β::Number,
486-
backend::AbstractBackend...)
487-
if use_threaded_transform(tsrc, transformer)
488-
_add_transform_threaded!(tdst, tsrc, p, transformer, α, β, backend...)
489-
else
490-
_add_transform_nonthreaded!(tdst, tsrc, p, transformer, α, β, backend...)
491-
end
492-
493-
return nothing
496+
function use_threaded_transform(t::AbstractTensorMap, transformer)
497+
return get_num_transformer_threads() > 1 && dim(space(t)) > Strided.MINTHREADLENGTH
494498
end
495499

496-
# Trivial implementation
497-
# ----------------------
498-
# Hijack before threading is used
499-
function add_transform_kernel!(tdst::TensorMap, tsrc::TensorMap, (p₁, p₂)::Index2Tuple,
500-
::TrivialTreeTransformer,
501-
α::Number, β::Number, backend::AbstractBackend...)
502-
TO.tensoradd!(tdst[], tsrc[], (p₁, p₂), false, α, β, backend...)
500+
# Trivial implementations
501+
# -----------------------
502+
function _add_trivial_kernel!(tdst, tsrc, p, transformer, α, β, backend...)
503+
TO.tensoradd!(tdst[], tsrc[], p, false, α, β, backend...)
503504
return nothing
504505
end
505506

506507
# Abelian implementations
507508
# -----------------------
508-
function _add_transform_nonthreaded!(tdst, tsrc, p, transformer::AbelianTreeTransformer,
509-
α, β, backend...)
509+
function _add_abelian_kernel_nonthreaded!(tdst, tsrc, p,
510+
transformer::AbelianTreeTransformer,
511+
α, β, backend...)
510512
for subtransformer in transformer.data
511513
_add_transform_single!(tdst, tsrc, p, subtransformer, α, β, backend...)
512514
end
513515
return nothing
514516
end
515517

516-
function _add_transform_threaded!(tdst, tsrc, p, transformer::AbelianTreeTransformer, α, β,
517-
backend...; ntasks::Int=get_num_transformer_threads())
518+
function _add_abelian_kernel_threaded!(tdst, tsrc, p, transformer::AbelianTreeTransformer,
519+
α, β, backend...;
520+
ntasks::Int=get_num_transformer_threads())
518521
nblocks = length(transformer.data)
519522
counter = Threads.Atomic{Int}(1)
520523
Threads.@sync for _ in 1:min(ntasks, nblocks)
@@ -530,14 +533,49 @@ function _add_transform_threaded!(tdst, tsrc, p, transformer::AbelianTreeTransfo
530533
return nothing
531534
end
532535

536+
function _add_transform_single!(tdst, tsrc, p,
537+
(basistransform, structures_dst, structures_src),
538+
α, β, backend...)
539+
structure_dst = structures_dst isa Vector ? only(structures_dst) : structures_dst
540+
structure_src = structures_src isa Vector ? only(structures_src) : structures_src
541+
coeff = basistransform isa Number ? basistransform : only(basistransform)
542+
543+
subblock_dst = StridedView(tdst.data, structure_dst...)
544+
subblock_src = StridedView(tsrc.data, structure_src...)
545+
TO.tensoradd!(subblock_dst, subblock_src, p, false, α * coeff, β, backend...)
546+
return nothing
547+
end
548+
549+
function _add_abelian_kernel_nonthreaded!(tdst, tsrc, p, transformer, α, β, backend...)
550+
for (f₁, f₂) in fusiontrees(tsrc)
551+
_add_abelian_block!(tdst, tsrc, p, transformer, f₁, f₂, α, β, backend...)
552+
end
553+
return nothing
554+
end
555+
556+
function _add_abelian_kernel_threaded!(tdst, tsrc, p, transformer, α, β, backend...)
557+
Threads.@sync for (f₁, f₂) in fusiontrees(tsrc)
558+
Threads.@spawn _add_abelian_block!(tdst, tsrc, p, transformer, f₁, f₂, α, β,
559+
backend...)
560+
end
561+
return nothing
562+
end
563+
564+
function _add_abelian_block!(tdst, tsrc, p, transformer, f₁, f₂, α, β, backend...)
565+
(f₁′, f₂′), coeff = first(transformer(f₁, f₂))
566+
@inbounds TO.tensoradd!(tdst[f₁′, f₂′], tsrc[f₁, f₂], p, false, α * coeff, β,
567+
backend...)
568+
return nothing
569+
end
570+
533571
# Non-abelian implementations
534572
# ---------------------------
535-
function _add_transform_nonthreaded!(tdst, tsrc, p, transformer::GenericTreeTransformer,
536-
α, β, backend...)
573+
function _add_general_kernel_nonthreaded!(tdst, tsrc, p,
574+
transformer::GenericTreeTransformer,
575+
α, β, backend...)
537576
# preallocate buffers
538577
buffers = allocate_buffers(tdst, tsrc, transformer)
539578

540-
# TODO: this could be multithreaded
541579
for subtransformer in transformer.data
542580
# Special case without intermediate buffers whenever there is only a single block
543581
if length(subtransformer[1]) == 1
@@ -549,9 +587,9 @@ function _add_transform_nonthreaded!(tdst, tsrc, p, transformer::GenericTreeTran
549587
return nothing
550588
end
551589

552-
function _add_transform_threaded!(tdst, tsrc, p, transformer::GenericTreeTransformer,
553-
α, β, backend...;
554-
ntasks::Int=get_num_transformer_threads())
590+
function _add_general_kernel_threaded!(tdst, tsrc, p, transformer::GenericTreeTransformer,
591+
α, β, backend...;
592+
ntasks::Int=get_num_transformer_threads())
555593
nblocks = length(transformer.data)
556594

557595
counter = Threads.Atomic{Int}(1)
@@ -577,18 +615,18 @@ function _add_transform_threaded!(tdst, tsrc, p, transformer::GenericTreeTransfo
577615
return nothing
578616
end
579617

580-
# Kernels
581-
# -------
582-
function _add_transform_single!(tdst, tsrc, p,
583-
(basistransform, structures_dst, structures_src),
584-
α, β, backend...)
585-
structure_dst = structures_dst isa Vector ? only(structures_dst) : structures_dst
586-
structure_src = structures_src isa Vector ? only(structures_src) : structures_src
587-
coeff = basistransform isa Number ? basistransform : only(basistransform)
588-
589-
subblock_dst = StridedView(tdst.data, structure_dst...)
590-
subblock_src = StridedView(tsrc.data, structure_src...)
591-
TO.tensoradd!(subblock_dst, subblock_src, p, false, α * coeff, β, backend...)
618+
function _add_general_kernel_nonthreaded!(tdst, tsrc, p, transformer, α, β, backend...)
619+
if iszero(β)
620+
tdst = zerovector!(tdst)
621+
elseif !isone(β)
622+
tdst = scale!(tdst, β)
623+
end
624+
for (f₁, f₂) in fusiontrees(tsrc)
625+
for ((f₁′, f₂′), coeff) in transformer(f₁, f₂)
626+
@inbounds TO.tensoradd!(tdst[f₁′, f₂′], tsrc[f₁, f₂], p, false, α * coeff,
627+
One(), backend...)
628+
end
629+
end
592630
return nothing
593631
end
594632

@@ -620,88 +658,25 @@ function _add_transform_multi!(tdst, tsrc, p,
620658
return nothing
621659
end
622660

623-
# Other implementations
624-
# ---------------------
625-
626-
function add_transform_kernel!(tdst::AbstractTensorMap,
627-
tsrc::AbstractTensorMap,
628-
(p₁, p₂)::Index2Tuple,
629-
fusiontreetransform::Function,
630-
α::Number,
631-
β::Number,
632-
backend::AbstractBackend...)
633-
I = sectortype(spacetype(tdst))
634-
635-
if I === Trivial
636-
_add_trivial_kernel!(tdst, tsrc, (p₁, p₂), fusiontreetransform, α, β, backend...)
637-
elseif FusionStyle(I) isa UniqueFusion
638-
_add_abelian_kernel!(tdst, tsrc, (p₁, p₂), fusiontreetransform, α, β, backend...)
639-
else
640-
_add_general_kernel!(tdst, tsrc, (p₁, p₂), fusiontreetransform, α, β, backend...)
641-
end
642-
643-
return nothing
644-
end
645-
646-
# internal methods: no argument types
647-
function _add_trivial_kernel!(tdst, tsrc, p, fusiontreetransform, α, β, backend...)
648-
TO.tensoradd!(tdst[], tsrc[], p, false, α, β, backend...)
649-
return nothing
650-
end
651-
652-
function _add_abelian_kernel!(tdst, tsrc, p, fusiontreetransform, α, β, backend...)
653-
if Threads.nthreads() > 1
654-
Threads.@sync for (f₁, f₂) in fusiontrees(tsrc)
655-
Threads.@spawn _add_abelian_block!(tdst, tsrc, p, fusiontreetransform,
656-
f₁, f₂, α, β, backend...)
657-
end
658-
else
659-
for (f₁, f₂) in fusiontrees(tsrc)
660-
_add_abelian_block!(tdst, tsrc, p, fusiontreetransform,
661-
f₁, f₂, α, β, backend...)
662-
end
663-
end
664-
return nothing
665-
end
666-
667-
function _add_abelian_block!(tdst, tsrc, p, fusiontreetransform, f₁, f₂, α, β, backend...)
668-
(f₁′, f₂′), coeff = first(fusiontreetransform(f₁, f₂))
669-
@inbounds TO.tensoradd!(tdst[f₁′, f₂′], tsrc[f₁, f₂], p, false, α * coeff, β,
670-
backend...)
671-
return nothing
672-
end
673-
674-
function _add_general_kernel!(tdst, tsrc, p, fusiontreetransform, α, β, backend...)
661+
function _add_general_kernel_threaded!(tdst, tsrc, p, transformer, α, β, backend...)
675662
if iszero(β)
676663
tdst = zerovector!(tdst)
677-
elseif β != 1
664+
elseif !isone(β)
678665
tdst = scale!(tdst, β)
679666
end
680-
β′ = One()
681-
if Threads.nthreads() > 1
682-
Threads.@sync for s₁ in sectors(codomain(tsrc)), s₂ in sectors(domain(tsrc))
683-
Threads.@spawn _add_nonabelian_sector!(tdst, tsrc, p, fusiontreetransform, s₁,
684-
s₂, α, β′, backend...)
685-
end
686-
else
687-
for (f₁, f₂) in fusiontrees(tsrc)
688-
for ((f₁′, f₂′), coeff) in fusiontreetransform(f₁, f₂)
689-
@inbounds TO.tensoradd!(tdst[f₁′, f₂′], tsrc[f₁, f₂], p, false, α * coeff,
690-
β′, backend...)
691-
end
692-
end
667+
Threads.@sync for s₁ in sectors(codomain(tsrc)), s₂ in sectors(domain(tsrc))
668+
Threads.@spawn _add_nonabelian_sector!(tdst, tsrc, p, transformer, s₁, s₂, α,
669+
backend...)
693670
end
694671
return nothing
695672
end
696673

697-
# TODO: β argument is weird here because it has to be 1
698-
function _add_nonabelian_sector!(tdst, tsrc, p, fusiontreetransform, s₁, s₂, α, β,
699-
backend...)
674+
function _add_nonabelian_sector!(tdst, tsrc, p, fusiontreetransform, s₁, s₂, α, backend...)
700675
for (f₁, f₂) in fusiontrees(tsrc)
701676
(f₁.uncoupled == s₁ && f₂.uncoupled == s₂) || continue
702677
for ((f₁′, f₂′), coeff) in fusiontreetransform(f₁, f₂)
703-
@inbounds TO.tensoradd!(tdst[f₁′, f₂′], tsrc[f₁, f₂], p, false, α * coeff, β,
704-
backend...)
678+
@inbounds TO.tensoradd!(tdst[f₁′, f₂′], tsrc[f₁, f₂], p, false, α * coeff,
679+
One(), backend...)
705680
end
706681
end
707682
return nothing

0 commit comments

Comments
 (0)