@@ -328,20 +328,24 @@ See also [`transpose`](@ref), [`transpose!`](@ref), [`add_permute!`](@ref), [`ad
328328 return add_transform! (tdst, tsrc, p, transformer, α, β, backend... )
329329end
330330
331- function add_transform! (tdst:: AbstractTensorMap{T,S,N₁,N₂} ,
331+ function add_transform! (tdst:: AbstractTensorMap ,
332332 tsrc:: AbstractTensorMap ,
333- (p₁, p₂):: Index2Tuple{N₁,N₂} ,
333+ (p₁, p₂):: Index2Tuple ,
334334 transformer,
335335 α:: Number ,
336336 β:: Number ,
337- backend:: AbstractBackend... ) where {T,S,N₁,N₂}
337+ backend:: AbstractBackend... )
338338 @boundscheck begin
339339 permute (space (tsrc), (p₁, p₂)) == space (tdst) ||
340340 throw (SpaceMismatch (" source = $(codomain (tsrc)) ←$(domain (tsrc)) ,
341341 dest = $(codomain (tdst)) ←$(domain (tdst)) , p₁ = $(p₁) , p₂ = $(p₂) " ))
342342 end
343343
344- add_transform_kernel! (tdst, tsrc, (p₁, p₂), transformer, α, β, backend... )
344+ if p₁ === codomainind (tsrc) && p₂ === domainind (tsrc)
345+ add! (tdst, tsrc, α, β)
346+ else
347+ add_transform_kernel! (tdst, tsrc, (p₁, p₂), transformer, α, β, backend... )
348+ end
345349
346350 return tdst
347351end
@@ -417,3 +421,87 @@ function add_transform_kernel!(tdst::TensorMap,
417421
418422 return tdst
419423end
424+
425+ function add_transform_kernel! (tdst:: AbstractTensorMap ,
426+ tsrc:: AbstractTensorMap ,
427+ (p₁, p₂):: Index2Tuple ,
428+ fusiontreetransform:: Function ,
429+ α:: Number ,
430+ β:: Number ,
431+ backend:: AbstractBackend... )
432+ I = sectortype (spacetype (tdst))
433+
434+ if I === Trivial
435+ _add_trivial_kernel! (tdst, tsrc, (p₁, p₂), fusiontreetransform, α, β, backend... )
436+ elseif FusionStyle (I) isa UniqueFusion
437+ _add_abelian_kernel! (tdst, tsrc, (p₁, p₂), fusiontreetransform, α, β, backend... )
438+ else
439+ _add_general_kernel! (tdst, tsrc, (p₁, p₂), fusiontreetransform, α, β, backend... )
440+ end
441+
442+ return nothing
443+ end
444+
445+ # internal methods: no argument types
446+ function _add_trivial_kernel! (tdst, tsrc, p, fusiontreetransform, α, β, backend... )
447+ TO. tensoradd! (tdst[], tsrc[], p, false , α, β, backend... )
448+ return nothing
449+ end
450+
451+ function _add_abelian_kernel! (tdst, tsrc, p, fusiontreetransform, α, β, backend... )
452+ if Threads. nthreads () > 1
453+ Threads. @sync for (f₁, f₂) in fusiontrees (tsrc)
454+ Threads. @spawn _add_abelian_block! (tdst, tsrc, p, fusiontreetransform,
455+ f₁, f₂, α, β, backend... )
456+ end
457+ else
458+ for (f₁, f₂) in fusiontrees (tsrc)
459+ _add_abelian_block! (tdst, tsrc, p, fusiontreetransform,
460+ f₁, f₂, α, β, backend... )
461+ end
462+ end
463+ return nothing
464+ end
465+
466+ function _add_abelian_block! (tdst, tsrc, p, fusiontreetransform, f₁, f₂, α, β, backend... )
467+ (f₁′, f₂′), coeff = first (fusiontreetransform (f₁, f₂))
468+ @inbounds TO. tensoradd! (tdst[f₁′, f₂′], tsrc[f₁, f₂], p, false , α * coeff, β,
469+ backend... )
470+ return nothing
471+ end
472+
473+ function _add_general_kernel! (tdst, tsrc, p, fusiontreetransform, α, β, backend... )
474+ if iszero (β)
475+ tdst = zerovector! (tdst)
476+ elseif β != 1
477+ tdst = scale! (tdst, β)
478+ end
479+ β′ = One ()
480+ if Threads. nthreads () > 1
481+ Threads. @sync for s₁ in sectors (codomain (tsrc)), s₂ in sectors (domain (tsrc))
482+ Threads. @spawn _add_nonabelian_sector! (tdst, tsrc, p, fusiontreetransform, s₁,
483+ s₂, α, β′, backend... )
484+ end
485+ else
486+ for (f₁, f₂) in fusiontrees (tsrc)
487+ for ((f₁′, f₂′), coeff) in fusiontreetransform (f₁, f₂)
488+ @inbounds TO. tensoradd! (tdst[f₁′, f₂′], tsrc[f₁, f₂], p, false , α * coeff,
489+ β′, backend... )
490+ end
491+ end
492+ end
493+ return nothing
494+ end
495+
496+ # TODO : β argument is weird here because it has to be 1
497+ function _add_nonabelian_sector! (tdst, tsrc, p, fusiontreetransform, s₁, s₂, α, β,
498+ backend... )
499+ for (f₁, f₂) in fusiontrees (tsrc)
500+ (f₁. uncoupled == s₁ && f₂. uncoupled == s₂) || continue
501+ for ((f₁′, f₂′), coeff) in fusiontreetransform (f₁, f₂)
502+ @inbounds TO. tensoradd! (tdst[f₁′, f₂′], tsrc[f₁, f₂], p, false , α * coeff, β,
503+ backend... )
504+ end
505+ end
506+ return nothing
507+ end
0 commit comments