@@ -473,38 +473,68 @@ function add_transform!(tdst::AbstractTensorMap,
473473 return tdst
474474end
475475
476- function add_transform_kernel! (tdst:: TensorMap ,
477- tsrc:: TensorMap ,
478- (p₁, p₂):: Index2Tuple ,
479- :: TrivialTreeTransformer ,
480- α:: Number ,
481- β:: Number ,
482- backend:: AbstractBackend... )
483- return TO. tensoradd! (tdst[], tsrc[], (p₁, p₂), false , α, β, backend... )
"""
    use_threaded_transform(t::TensorMap, transformer::TreeTransformer)

Decide whether a tree transformation should be dispatched to the multithreaded kernel.

Returns `true` exactly when `get_num_transformer_threads()` is larger than one. The
arguments `t` and `transformer` are currently unused (see the TODO in the body).
"""
function use_threaded_transform(t::TensorMap, transformer::TreeTransformer)
    # TODO: heuristic for not threading over small tensors
    nthreads = get_num_transformer_threads()
    return nthreads > 1
end
485480
"""
    add_transform_kernel!(tdst::TensorMap, tsrc::TensorMap, p::Index2Tuple,
                          transformer::TreeTransformer, α::Number, β::Number,
                          backend::AbstractBackend...)

Dispatch a tree transformation to the threaded or non-threaded implementation.

`use_threaded_transform(tsrc, transformer)` selects the path; all arguments are then
forwarded unchanged to `_add_transform_threaded!` or `_add_transform_nonthreaded!`.
Always returns `nothing`.
"""
function add_transform_kernel!(tdst::TensorMap,
                               tsrc::TensorMap,
                               p::Index2Tuple,
                               transformer::TreeTransformer,
                               α::Number,
                               β::Number,
                               backend::AbstractBackend...)
    # Guard clause: handle the serial path first and leave.
    if !use_threaded_transform(tsrc, transformer)
        _add_transform_nonthreaded!(tdst, tsrc, p, transformer, α, β, backend...)
        return nothing
    end

    _add_transform_threaded!(tdst, tsrc, p, transformer, α, β, backend...)
    return nothing
end
496+
497+ # Trivial implementation
498+ # ----------------------
499+ # Hijack before threading is used
"""
    add_transform_kernel!(tdst::TensorMap, tsrc::TensorMap, (p₁, p₂)::Index2Tuple,
                          ::TrivialTreeTransformer, α::Number, β::Number,
                          backend::AbstractBackend...)

Kernel for a `TrivialTreeTransformer`: a single `TO.tensoradd!` call on `tdst[]` and
`tsrc[]` with index permutation `(p₁, p₂)`, no conjugation flag (`false`), and the scalars
`α` and `β`. This method performs no threading decision of its own. Returns `nothing`.
"""
function add_transform_kernel!(tdst::TensorMap, tsrc::TensorMap, (p₁, p₂)::Index2Tuple,
                               ::TrivialTreeTransformer,
                               α::Number, β::Number, backend::AbstractBackend...)
    dst = tdst[]
    src = tsrc[]
    pperm = (p₁, p₂)
    TO.tensoradd!(dst, src, pperm, false, α, β, backend...)
    return nothing
end
506+
507+ # Abelian implementations
508+ # -----------------------
"""
    _add_transform_nonthreaded!(tdst, tsrc, p, transformer::AbelianTreeTransformer,
                                α, β, backend...)

Serial kernel for an `AbelianTreeTransformer`: calls `_add_transform_single!` once for
every entry of `transformer.data`, in iteration order. Returns `nothing`.
"""
function _add_transform_nonthreaded!(tdst, tsrc, p, transformer::AbelianTreeTransformer,
                                     α, β, backend...)
    foreach(transformer.data) do subtransformer
        _add_transform_single!(tdst, tsrc, p, subtransformer, α, β, backend...)
        return nothing
    end
    return nothing
end
497516
"""
    _add_transform_threaded!(tdst, tsrc, p, transformer::AbelianTreeTransformer,
                             α, β, backend...; ntasks=get_num_transformer_threads())

Multithreaded kernel for an `AbelianTreeTransformer`.

Spawns worker tasks that repeatedly claim the next index into `transformer.data` from a
shared atomic counter and call `_add_transform_single!` on that entry, until all entries
have been claimed.

# Keywords
- `ntasks::Int`: requested number of worker tasks. The number actually spawned is limited
  to `length(transformer.data)` and is never less than one, so a non-positive value still
  processes every entry instead of silently doing nothing.

Returns `nothing`.
"""
function _add_transform_threaded!(tdst, tsrc, p, transformer::AbelianTreeTransformer, α, β,
                                  backend...; ntasks::Int=get_num_transformer_threads())
    nblocks = length(transformer.data)
    nblocks == 0 && return nothing

    # At least one worker: `1:min(ntasks, nblocks)` would be empty for `ntasks <= 0`
    # and the transformation would be skipped without any error.
    nworkers = max(1, min(ntasks, nblocks))

    # NOTE(review): assumes distinct entries of `transformer.data` write to
    # non-overlapping parts of `tdst` — confirm against how the transformer is built.
    counter = Threads.Atomic{Int}(1)
    Threads.@sync for _ in 1:nworkers
        Threads.@spawn begin
            while true
                # `atomic_add!` returns the value held *before* the increment, so each
                # index in `1:nblocks` is handed to exactly one task.
                local_counter = Threads.atomic_add!(counter, 1)
                local_counter > nblocks && break
                @inbounds subtransformer = transformer.data[local_counter]
                _add_transform_single!(tdst, tsrc, p, subtransformer, α, β, backend...)
            end
        end
    end
    return nothing
end
500533
501- function add_transform_kernel! (tdst:: TensorMap ,
502- tsrc:: TensorMap ,
503- p:: Index2Tuple ,
504- transformer:: GenericTreeTransformer ,
505- α:: Number ,
506- β:: Number ,
507- backend:: AbstractBackend... )
534+ # Non-abelian implementations
535+ # ---------------------------
536+ function _add_transform_nonthreaded! (tdst, tsrc, p, transformer:: GenericTreeTransformer ,
537+ α, β, backend... )
508538 # preallocate buffers
509539 buffersize = maximum (transformer. data) do (_, structures_dst, _)
510540 return prod (structures_dst[1 ][1 ])
@@ -522,24 +552,59 @@ function add_transform_kernel!(tdst::TensorMap,
522552 (buffer1, buffer2), backend... )
523553 end
524554 end
525- return tdst
555+ return nothing
526556end
527557
528- function _add_transform_single! (tdst, tsrc, p, α, β,
"""
    _add_transform_threaded!(tdst, tsrc, p, transformer::GenericTreeTransformer,
                             α, β, backend...; ntasks=get_num_transformer_threads())

Multithreaded kernel for a `GenericTreeTransformer`.

Spawns worker tasks that repeatedly claim the next index into `transformer.data` from a
shared atomic counter. An entry whose first component has length one is handled by
`_add_transform_single!`; every other entry goes through `_add_transform_multi!` using two
scratch buffers that are allocated once per task.

# Keywords
- `ntasks::Int`: requested number of worker tasks. The number actually spawned is limited
  to `length(transformer.data)` and is never less than one, so a non-positive value still
  processes every entry instead of silently doing nothing.

Returns `nothing`. When `transformer.data` is empty the function returns immediately.
"""
function _add_transform_threaded!(tdst, tsrc, p, transformer::GenericTreeTransformer,
                                  α, β, backend...;
                                  ntasks::Int=get_num_transformer_threads())
    nblocks = length(transformer.data)
    # `maximum` throws on an empty collection; with no entries there is nothing to do.
    nblocks == 0 && return nothing

    # One buffer length that is large enough for every entry.
    buffersize = maximum(transformer.data) do (_, structures_dst, _)
        return prod(structures_dst[1][1])
    end

    # At least one worker: `1:min(ntasks, nblocks)` would be empty for `ntasks <= 0`
    # and the transformation would be skipped without any error.
    nworkers = max(1, min(ntasks, nblocks))

    # NOTE(review): assumes distinct entries of `transformer.data` write to
    # non-overlapping parts of `tdst` — confirm against how the transformer is built.
    counter = Threads.Atomic{Int}(1)
    Threads.@sync for _ in 1:nworkers
        Threads.@spawn begin
            # preallocate buffers for each task, so tasks never share scratch space
            buffer1 = similar(tsrc.data, buffersize)
            buffer2 = similar(tdst.data, buffersize)

            while true
                # `atomic_add!` returns the value held *before* the increment, so each
                # index in `1:nblocks` is handed to exactly one task.
                local_counter = Threads.atomic_add!(counter, 1)
                local_counter > nblocks && break
                @inbounds subtransformer = transformer.data[local_counter]
                if length(subtransformer[1]) == 1
                    _add_transform_single!(tdst, tsrc, p, subtransformer, α, β, backend...)
                else
                    _add_transform_multi!(tdst, tsrc, p, subtransformer, (buffer1, buffer2),
                                          α, β, backend...)
                end
            end
        end
    end

    return nothing
end
589+
590+ # Kernels
591+ # -------
"""
    _add_transform_single!(tdst, tsrc, p, subtransformer, α, β, backend...)

Apply one sub-transformation that maps a single source block onto a single destination
block.

`subtransformer` is destructured as `(basistransform, structures_dst, structures_src)`.
A structure given as a `Vector` must contain exactly one element, and a `basistransform`
that is not a `Number` must contain exactly one element (`only` throws otherwise). The
source and destination blocks are wrapped as `StridedView`s of `tsrc.data` / `tdst.data`
and combined with `TO.tensoradd!`, using `α` multiplied by the single coefficient and `β`.
Returns `nothing`.
"""
function _add_transform_single!(tdst, tsrc, p, subtransformer, α, β, backend...)
    basistransform, structures_dst, structures_src = subtransformer

    # Unwrap one-element containers into their sole entry.
    if structures_dst isa Vector
        dst_structure = only(structures_dst)
    else
        dst_structure = structures_dst
    end
    if structures_src isa Vector
        src_structure = only(structures_src)
    else
        src_structure = structures_src
    end
    if basistransform isa Number
        coefficient = basistransform
    else
        coefficient = only(basistransform)
    end

    dst_view = StridedView(tdst.data, dst_structure...)
    src_view = StridedView(tsrc.data, src_structure...)
    TO.tensoradd!(dst_view, src_view, p, false, α * coefficient, β, backend...)
    return nothing
end
539604
540- function _add_transform_multi! (tdst, tsrc, p, α, β,
605+ function _add_transform_multi! (tdst, tsrc, p,
541606 (basistransform, structures_dst, structures_src),
542- (buffer1, buffer2), backend... )
607+ (buffer1, buffer2), α, β, backend... )
543608 rows, cols = size (basistransform)
544609 sz_src = first (first (structures_src))
545610 blocksize = prod (sz_src)
@@ -553,7 +618,7 @@ function _add_transform_multi!(tdst, tsrc, p, α, β,
553618
554619 # Resummation into a second buffer using BLAS
555620 buffer_dst = StridedView (buffer2, (blocksize, rows), (1 , blocksize), 0 )
556- mul! (buffer_dst, buffer_src, basistransform, α)
621+ mul! (buffer_dst, buffer_src, basistransform, α, Zero () )
557622
558623 # Filling up the output
559624 for (i, structure_dst) in enumerate (structures_dst)
@@ -565,6 +630,9 @@ function _add_transform_multi!(tdst, tsrc, p, α, β,
565630 return nothing
566631end
567632
633+ # Other implementations
634+ # ---------------------
635+
568636function add_transform_kernel! (tdst:: AbstractTensorMap ,
569637 tsrc:: AbstractTensorMap ,
570638 (p₁, p₂):: Index2Tuple ,
0 commit comments