@@ -453,68 +453,71 @@ end
453453
454454function add_transform! (tdst:: AbstractTensorMap ,
455455 tsrc:: AbstractTensorMap ,
456- (p₁, p₂) :: Index2Tuple ,
456+ p :: Index2Tuple ,
457457 transformer,
458458 α:: Number ,
459459 β:: Number ,
460460 backend:: AbstractBackend... )
461461 @boundscheck begin
462- permute (space (tsrc), (p₁, p₂) ) == space (tdst) ||
462+ permute (space (tsrc), p ) == space (tdst) ||
463463 throw (SpaceMismatch (" source = $(codomain (tsrc)) ←$(domain (tsrc)) ,
464- dest = $(codomain (tdst)) ←$(domain (tdst)) , p₁ = $(p₁ ) , p₂ = $(p₂ ) " ))
464+ dest = $(codomain (tdst)) ←$(domain (tdst)) , p₁ = $(p[ 1 ] ) , p₂ = $(p[ 2 ] ) " ))
465465 end
466466
467- if p₁ === codomainind (tsrc) && p₂ === domainind (tsrc)
467+ if p[ 1 ] === codomainind (tsrc) && p[ 2 ] === domainind (tsrc)
468468 add! (tdst, tsrc, α, β)
469469 else
470- add_transform_kernel! (tdst, tsrc, (p₁, p₂), transformer, α, β, backend... )
470+ I = sectortype (tdst)
471+ if I === Trivial
472+ _add_trivial_kernel! (tdst, tsrc, p, transformer, α, β, backend... )
473+ elseif FusionStyle (I) === UniqueFusion ()
474+ if use_threaded_transform (tdst, transformer)
475+ _add_abelian_kernel_threaded! (tdst, tsrc, p, transformer, α, β, backend... )
476+ else
477+ _add_abelian_kernel_nonthreaded! (tdst, tsrc, p, transformer, α, β,
478+ backend... )
479+ end
480+ else
481+ if use_threaded_transform (tdst, transformer)
482+ _add_general_kernel_threaded! (tdst, tsrc, p, transformer, α, β, backend... )
483+ else
484+ _add_general_kernel_nonthreaded! (tdst, tsrc, p, transformer, α, β,
485+ backend... )
486+ end
487+ end
471488 end
472489
473490 return tdst
474491end
475492
476- function use_threaded_transform (t:: TensorMap , transformer:: TreeTransformer )
493+ function use_threaded_transform (t:: TensorMap , transformer)
477494 return get_num_transformer_threads () > 1 && length (t. data) > Strided. MINTHREADLENGTH
478495end
479-
480- function add_transform_kernel! (tdst:: TensorMap ,
481- tsrc:: TensorMap ,
482- p:: Index2Tuple ,
483- transformer:: TreeTransformer ,
484- α:: Number ,
485- β:: Number ,
486- backend:: AbstractBackend... )
487- if use_threaded_transform (tsrc, transformer)
488- _add_transform_threaded! (tdst, tsrc, p, transformer, α, β, backend... )
489- else
490- _add_transform_nonthreaded! (tdst, tsrc, p, transformer, α, β, backend... )
491- end
492-
493- return nothing
496+ function use_threaded_transform (t:: AbstractTensorMap , transformer)
497+ return get_num_transformer_threads () > 1 && dim (space (t)) > Strided. MINTHREADLENGTH
494498end
495499
496- # Trivial implementation
497- # ----------------------
498- # Hijack before threading is used
499- function add_transform_kernel! (tdst:: TensorMap , tsrc:: TensorMap , (p₁, p₂):: Index2Tuple ,
500- :: TrivialTreeTransformer ,
501- α:: Number , β:: Number , backend:: AbstractBackend... )
502- TO. tensoradd! (tdst[], tsrc[], (p₁, p₂), false , α, β, backend... )
500+ # Trivial implementations
501+ # -----------------------
502+ function _add_trivial_kernel! (tdst, tsrc, p, transformer, α, β, backend... )
503+ TO. tensoradd! (tdst[], tsrc[], p, false , α, β, backend... )
503504 return nothing
504505end
505506
506507# Abelian implementations
507508# -----------------------
508- function _add_transform_nonthreaded! (tdst, tsrc, p, transformer:: AbelianTreeTransformer ,
509- α, β, backend... )
509+ function _add_abelian_kernel_nonthreaded! (tdst, tsrc, p,
510+ transformer:: AbelianTreeTransformer ,
511+ α, β, backend... )
510512 for subtransformer in transformer. data
511513 _add_transform_single! (tdst, tsrc, p, subtransformer, α, β, backend... )
512514 end
513515 return nothing
514516end
515517
516- function _add_transform_threaded! (tdst, tsrc, p, transformer:: AbelianTreeTransformer , α, β,
517- backend... ; ntasks:: Int = get_num_transformer_threads ())
518+ function _add_abelian_kernel_threaded! (tdst, tsrc, p, transformer:: AbelianTreeTransformer ,
519+ α, β, backend... ;
520+ ntasks:: Int = get_num_transformer_threads ())
518521 nblocks = length (transformer. data)
519522 counter = Threads. Atomic {Int} (1 )
520523 Threads. @sync for _ in 1 : min (ntasks, nblocks)
@@ -530,14 +533,49 @@ function _add_transform_threaded!(tdst, tsrc, p, transformer::AbelianTreeTransfo
530533 return nothing
531534end
532535
536+ function _add_transform_single! (tdst, tsrc, p,
537+ (basistransform, structures_dst, structures_src),
538+ α, β, backend... )
539+ structure_dst = structures_dst isa Vector ? only (structures_dst) : structures_dst
540+ structure_src = structures_src isa Vector ? only (structures_src) : structures_src
541+ coeff = basistransform isa Number ? basistransform : only (basistransform)
542+
543+ subblock_dst = StridedView (tdst. data, structure_dst... )
544+ subblock_src = StridedView (tsrc. data, structure_src... )
545+ TO. tensoradd! (subblock_dst, subblock_src, p, false , α * coeff, β, backend... )
546+ return nothing
547+ end
548+
549+ function _add_abelian_kernel_nonthreaded! (tdst, tsrc, p, transformer, α, β, backend... )
550+ for (f₁, f₂) in fusiontrees (tsrc)
551+ _add_abelian_block! (tdst, tsrc, p, transformer, f₁, f₂, α, β, backend... )
552+ end
553+ return nothing
554+ end
555+
556+ function _add_abelian_kernel_threaded! (tdst, tsrc, p, transformer, α, β, backend... )
557+ Threads. @sync for (f₁, f₂) in fusiontrees (tsrc)
558+ Threads. @spawn _add_abelian_block! (tdst, tsrc, p, transformer, f₁, f₂, α, β,
559+ backend... )
560+ end
561+ return nothing
562+ end
563+
564+ function _add_abelian_block! (tdst, tsrc, p, transformer, f₁, f₂, α, β, backend... )
565+ (f₁′, f₂′), coeff = first (transformer (f₁, f₂))
566+ @inbounds TO. tensoradd! (tdst[f₁′, f₂′], tsrc[f₁, f₂], p, false , α * coeff, β,
567+ backend... )
568+ return nothing
569+ end
570+
533571# Non-abelian implementations
534572# ---------------------------
535- function _add_transform_nonthreaded! (tdst, tsrc, p, transformer:: GenericTreeTransformer ,
536- α, β, backend... )
573+ function _add_general_kernel_nonthreaded! (tdst, tsrc, p,
574+ transformer:: GenericTreeTransformer ,
575+ α, β, backend... )
537576 # preallocate buffers
538577 buffers = allocate_buffers (tdst, tsrc, transformer)
539578
540- # TODO : this could be multithreaded
541579 for subtransformer in transformer. data
542580 # Special case without intermediate buffers whenever there is only a single block
543581 if length (subtransformer[1 ]) == 1
@@ -549,9 +587,9 @@ function _add_transform_nonthreaded!(tdst, tsrc, p, transformer::GenericTreeTran
549587 return nothing
550588end
551589
552- function _add_transform_threaded ! (tdst, tsrc, p, transformer:: GenericTreeTransformer ,
553- α, β, backend... ;
554- ntasks:: Int = get_num_transformer_threads ())
590+ function _add_general_kernel_threaded ! (tdst, tsrc, p, transformer:: GenericTreeTransformer ,
591+ α, β, backend... ;
592+ ntasks:: Int = get_num_transformer_threads ())
555593 nblocks = length (transformer. data)
556594
557595 counter = Threads. Atomic {Int} (1 )
@@ -577,18 +615,18 @@ function _add_transform_threaded!(tdst, tsrc, p, transformer::GenericTreeTransfo
577615 return nothing
578616end
579617
580- # Kernels
581- # -------
582- function _add_transform_single! ( tdst, tsrc, p,
583- (basistransform, structures_dst, structures_src),
584- α , β, backend ... )
585- structure_dst = structures_dst isa Vector ? only (structures_dst) : structures_dst
586- structure_src = structures_src isa Vector ? only (structures_src) : structures_src
587- coeff = basistransform isa Number ? basistransform : only (basistransform )
588-
589- subblock_dst = StridedView (tdst . data, structure_dst ... )
590- subblock_src = StridedView (tsrc . data, structure_src ... )
591- TO . tensoradd! (subblock_dst, subblock_src, p, false , α * coeff, β, backend ... )
618+ function _add_general_kernel_nonthreaded! (tdst, tsrc, p, transformer, α, β, backend ... )
619+ if iszero (β)
620+ tdst = zerovector! (tdst)
621+ elseif ! isone (β)
622+ tdst = scale! (tdst , β)
623+ end
624+ for (f₁, f₂) in fusiontrees (tsrc)
625+ for ((f₁′, f₂′), coeff) in transformer (f₁, f₂ )
626+ @inbounds TO . tensoradd! (tdst[f₁′, f₂′], tsrc[f₁, f₂], p, false , α * coeff,
627+ One (), backend ... )
628+ end
629+ end
592630 return nothing
593631end
594632
@@ -620,88 +658,25 @@ function _add_transform_multi!(tdst, tsrc, p,
620658 return nothing
621659end
622660
623- # Other implementations
624- # ---------------------
625-
626- function add_transform_kernel! (tdst:: AbstractTensorMap ,
627- tsrc:: AbstractTensorMap ,
628- (p₁, p₂):: Index2Tuple ,
629- fusiontreetransform:: Function ,
630- α:: Number ,
631- β:: Number ,
632- backend:: AbstractBackend... )
633- I = sectortype (spacetype (tdst))
634-
635- if I === Trivial
636- _add_trivial_kernel! (tdst, tsrc, (p₁, p₂), fusiontreetransform, α, β, backend... )
637- elseif FusionStyle (I) isa UniqueFusion
638- _add_abelian_kernel! (tdst, tsrc, (p₁, p₂), fusiontreetransform, α, β, backend... )
639- else
640- _add_general_kernel! (tdst, tsrc, (p₁, p₂), fusiontreetransform, α, β, backend... )
641- end
642-
643- return nothing
644- end
645-
646- # internal methods: no argument types
647- function _add_trivial_kernel! (tdst, tsrc, p, fusiontreetransform, α, β, backend... )
648- TO. tensoradd! (tdst[], tsrc[], p, false , α, β, backend... )
649- return nothing
650- end
651-
652- function _add_abelian_kernel! (tdst, tsrc, p, fusiontreetransform, α, β, backend... )
653- if Threads. nthreads () > 1
654- Threads. @sync for (f₁, f₂) in fusiontrees (tsrc)
655- Threads. @spawn _add_abelian_block! (tdst, tsrc, p, fusiontreetransform,
656- f₁, f₂, α, β, backend... )
657- end
658- else
659- for (f₁, f₂) in fusiontrees (tsrc)
660- _add_abelian_block! (tdst, tsrc, p, fusiontreetransform,
661- f₁, f₂, α, β, backend... )
662- end
663- end
664- return nothing
665- end
666-
667- function _add_abelian_block! (tdst, tsrc, p, fusiontreetransform, f₁, f₂, α, β, backend... )
668- (f₁′, f₂′), coeff = first (fusiontreetransform (f₁, f₂))
669- @inbounds TO. tensoradd! (tdst[f₁′, f₂′], tsrc[f₁, f₂], p, false , α * coeff, β,
670- backend... )
671- return nothing
672- end
673-
674- function _add_general_kernel! (tdst, tsrc, p, fusiontreetransform, α, β, backend... )
661+ function _add_general_kernel_threaded! (tdst, tsrc, p, transformer, α, β, backend... )
675662 if iszero (β)
676663 tdst = zerovector! (tdst)
677- elseif β != 1
664+ elseif ! isone (β)
678665 tdst = scale! (tdst, β)
679666 end
680- β′ = One ()
681- if Threads. nthreads () > 1
682- Threads. @sync for s₁ in sectors (codomain (tsrc)), s₂ in sectors (domain (tsrc))
683- Threads. @spawn _add_nonabelian_sector! (tdst, tsrc, p, fusiontreetransform, s₁,
684- s₂, α, β′, backend... )
685- end
686- else
687- for (f₁, f₂) in fusiontrees (tsrc)
688- for ((f₁′, f₂′), coeff) in fusiontreetransform (f₁, f₂)
689- @inbounds TO. tensoradd! (tdst[f₁′, f₂′], tsrc[f₁, f₂], p, false , α * coeff,
690- β′, backend... )
691- end
692- end
667+ Threads. @sync for s₁ in sectors (codomain (tsrc)), s₂ in sectors (domain (tsrc))
668+ Threads. @spawn _add_nonabelian_sector! (tdst, tsrc, p, transformer, s₁, s₂, α,
669+ backend... )
693670 end
694671 return nothing
695672end
696673
697- # TODO : β argument is weird here because it has to be 1
698- function _add_nonabelian_sector! (tdst, tsrc, p, fusiontreetransform, s₁, s₂, α, β,
699- backend... )
674+ function _add_nonabelian_sector! (tdst, tsrc, p, fusiontreetransform, s₁, s₂, α, backend... )
700675 for (f₁, f₂) in fusiontrees (tsrc)
701676 (f₁. uncoupled == s₁ && f₂. uncoupled == s₂) || continue
702677 for ((f₁′, f₂′), coeff) in fusiontreetransform (f₁, f₂)
703- @inbounds TO. tensoradd! (tdst[f₁′, f₂′], tsrc[f₁, f₂], p, false , α * coeff, β,
704- backend... )
678+ @inbounds TO. tensoradd! (tdst[f₁′, f₂′], tsrc[f₁, f₂], p, false , α * coeff,
679+ One (), backend... )
705680 end
706681 end
707682 return nothing
0 commit comments