377377#-------------------------------------
378378"""
379379 add_permute!(tdst:: AbstractTensorMap , tsrc:: AbstractTensorMap , (p₁, p₂):: Index2Tuple ,
380- α:: Number , β:: Number , backend:: AbstractBackend ... )
380+ α:: Number , β:: Number , backend... )
381381
382382Return the updated ` tdst` , which is the result of adding ` α * tsrc` to ` tdst` after permuting
383383the indices of ` tsrc` according to ` (p₁, p₂)` .
@@ -389,14 +389,14 @@ See also [`permute`](@ref), [`permute!`](@ref), [`add_braid!`](@ref), [`add_tran
389389 p::Index2Tuple,
390390 α::Number,
391391 β::Number,
392- backend::AbstractBackend ...)
392+ backend...)
393393 transformer = treepermuter(tdst, tsrc, p)
394394 return add_transform!(tdst, tsrc, p, transformer, α, β, backend...)
395395end
396396
397397"""
398398 add_braid!(tdst:: AbstractTensorMap , tsrc:: AbstractTensorMap , (p₁, p₂):: Index2Tuple ,
399- levels:: IndexTuple , α:: Number , β:: Number , backend:: AbstractBackend ... )
399+ levels:: IndexTuple , α:: Number , β:: Number , backend... )
400400
401401Return the updated ` tdst` , which is the result of adding ` α * tsrc` to ` tdst` after braiding
402402the indices of ` tsrc` according to ` (p₁, p₂)` and ` levels` .
@@ -409,7 +409,7 @@ See also [`braid`](@ref), [`braid!`](@ref), [`add_permute!`](@ref), [`add_transp
409409 levels::IndexTuple,
410410 α::Number,
411411 β::Number,
412- backend::AbstractBackend ...)
412+ backend...)
413413 length(levels) == numind(tsrc) ||
414414 throw(ArgumentError("incorrect levels $levels for tensor map $(codomain(tsrc)) ← $(domain(tsrc)) "))
415415
422422
423423"""
424424 add_transpose!(tdst:: AbstractTensorMap , tsrc:: AbstractTensorMap , (p₁, p₂):: Index2Tuple ,
425- α:: Number , β:: Number , backend:: AbstractBackend ... )
425+ α:: Number , β:: Number , backend... )
426426
427427Return the updated ` tdst` , which is the result of adding ` α * tsrc` to ` tdst` after transposing
428428the indices of ` tsrc` according to ` (p₁, p₂)` .
@@ -434,18 +434,54 @@ See also [`transpose`](@ref), [`transpose!`](@ref), [`add_permute!`](@ref), [`ad
434434 p::Index2Tuple,
435435 α::Number,
436436 β::Number,
437- backend::AbstractBackend ...)
437+ backend...)
438438 transformer = treetransposer(tdst, tsrc, p)
439439 return add_transform!(tdst, tsrc, p, transformer, α, β, backend...)
440440end
441441
442+ # Implementation
443+ # --------------
444+ """
445+ add_transform!(C, A, pA, transformer, α, β, [backend], [allocator])
446+
447+ Return the updated ` C` , which is the result of adding ` α * A` to ` β * B` ,
448+ permuting the data with ` pA` while transforming the fusiontrees with ` transformer` .
449+ """
450+ function add_transform!(C::AbstractTensorMap, A::AbstractTensorMap, pA::Index2Tuple,
451+ transformer, α::Number, β::Number)
452+ return add_transform!(C, A, pA, transformer, α, β, TO.DefaultBackend())
453+ end
454+ function add_transform!(C::AbstractTensorMap, A::AbstractTensorMap, pA::Index2Tuple,
455+ transformer, α::Number, β::Number, backend)
456+ return add_transform!(C, A, pA, transformer, α, β, backend, TO.DefaultAllocator())
457+ end
458+ function add_transform!(C::AbstractTensorMap, A::AbstractTensorMap, pA::Index2Tuple,
459+ transformer, α::Number, β::Number, backend, allocator)
460+ if backend isa TO.DefaultBackend
461+ newbackend = TO.select_backend(add_transform!, C, A)
462+ return add_transform!(C, A, pA, transformer, α, β, newbackend, allocator)
463+ elseif backend isa TO.NoBackend # error for missing backend
464+ TC = typeof(C)
465+ TA = typeof(A)
466+ throw(ArgumentError("No suitable backend found for `add_transform!` and tensor types $TC and $TA "))
467+ else # error for unknown backend
468+ TC = typeof(C)
469+ TA = typeof(A)
470+ throw(ArgumentError("Unknown backend $backend for `add_transform!` and tensor types $TC and $TA "))
471+ end
472+ end
473+ function TO.select_backend(::typeof(add_transform!), C::AbstractTensorMap,
474+ A::AbstractTensorMap)
475+ return TensorKitBackend()
476+ end
477+
442478function add_transform!(tdst::AbstractTensorMap,
443479 tsrc::AbstractTensorMap,
444480 (p₁, p₂)::Index2Tuple,
445481 transformer,
446482 α::Number,
447483 β::Number,
448- backend::AbstractBackend... )
484+ backend::TensorKitBackend, allocator )
449485 @boundscheck begin
450486 permute(space(tsrc), (p₁, p₂)) == space(tdst) ||
451487 throw(SpaceMismatch("source = $(codomain(tsrc)) ←$(domain(tsrc)) ,
@@ -455,7 +491,7 @@ function add_transform!(tdst::AbstractTensorMap,
455491 if p₁ === codomainind(tsrc) && p₂ === domainind(tsrc)
456492 add!(tdst, tsrc, α, β)
457493 else
458- add_transform_kernel!(tdst, tsrc, (p₁, p₂), transformer, α, β, backend... )
494+ add_transform_kernel!(tdst, tsrc, (p₁, p₂), transformer, α, β, backend, allocator )
459495 end
460496
461497 return tdst
@@ -467,8 +503,9 @@ function add_transform_kernel!(tdst::TensorMap,
467503 ::TrivialTreeTransformer,
468504 α::Number,
469505 β::Number,
470- backend::AbstractBackend...)
471- return TO.tensoradd!(tdst[], tsrc[], (p₁, p₂), false, α, β, backend...)
506+ backend::TensorKitBackend, allocator)
507+ return TO.tensoradd!(tdst[], tsrc[], (p₁, p₂), false, α, β, backend.arraybackend,
508+ allocator)
472509end
473510
474511function add_transform_kernel!(tdst::TensorMap,
@@ -477,19 +514,20 @@ function add_transform_kernel!(tdst::TensorMap,
477514 transformer::AbelianTreeTransformer,
478515 α::Number,
479516 β::Number,
480- backend::AbstractBackend... )
517+ backend::TensorKitBackend, allocator )
481518 structure_dst = transformer.structure_dst.fusiontreestructure
482519 structure_src = transformer.structure_src.fusiontreestructure
483520
484- # TODO: this could be multithreaded
485- for (row, col, val) in zip(transformer.rows, transformer.cols, transformer.vals)
521+ tforeach(transformer.rows, transformer.cols, transformer.vals;
522+ backend.scheduler) do row, col, val
486523 sz_dst, str_dst, offset_dst = structure_dst[col]
487524 subblock_dst = StridedView(tdst.data, sz_dst, str_dst, offset_dst)
488525
489526 sz_src, str_src, offset_src = structure_src[row]
490527 subblock_src = StridedView(tsrc.data, sz_src, str_src, offset_src)
491528
492- TO.tensoradd!(subblock_dst, subblock_src, (p₁, p₂), false, α * val, β, backend...)
529+ return TO.tensoradd!(subblock_dst, subblock_src, (p₁, p₂), false, α * val, β,
530+ backend.arraybackend, allocator)
493531 end
494532
495533 return nothing
@@ -501,15 +539,14 @@ function add_transform_kernel!(tdst::TensorMap,
501539 transformer::GenericTreeTransformer,
502540 α::Number,
503541 β::Number,
504- backend::AbstractBackend... )
542+ backend::TensorKitBackend, allocator )
505543 structure_dst = transformer.structure_dst.fusiontreestructure
506544 structure_src = transformer.structure_src.fusiontreestructure
507545
508546 rows = rowvals(transformer.matrix)
509547 vals = nonzeros(transformer.matrix)
510548
511- # TODO: this could be multithreaded
512- for j in axes(transformer.matrix, 2)
549+ tforeach(axes(transformer.matrix, 2); backend.scheduler) do j
513550 sz_dst, str_dst, offset_dst = structure_dst[j]
514551 subblock_dst = StridedView(tdst.data, sz_dst, str_dst, offset_dst)
515552 nzrows = nzrange(transformer.matrix, j)
@@ -519,14 +556,14 @@ function add_transform_kernel!(tdst::TensorMap,
519556 subblock_src = StridedView(tsrc.data, sz_src, str_src, offset_src)
520557 TO.tensoradd!(subblock_dst, subblock_src, (p₁, p₂), false, α * vals[first(nzrows)],
521558 β,
522- backend... )
559+ backend.arraybackend, allocator )
523560
524561 # treat remaining entries
525562 for i in @view(nzrows[2:end])
526563 sz_src, str_src, offset_src = structure_src[rows[i]]
527564 subblock_src = StridedView(tsrc.data, sz_src, str_src, offset_src)
528565 TO.tensoradd!(subblock_dst, subblock_src, (p₁, p₂), false, α * vals[i], One(),
529- backend... )
566+ backend.arraybackend, allocator )
530567 end
531568 end
532569
@@ -539,79 +576,79 @@ function add_transform_kernel!(tdst::AbstractTensorMap,
539576 fusiontreetransform::Function,
540577 α::Number,
541578 β::Number,
542- backend::AbstractBackend... )
579+ backend::TensorKitBackend, allocator )
543580 I = sectortype(spacetype(tdst))
544581
545582 if I === Trivial
546- _add_trivial_kernel!(tdst, tsrc, (p₁, p₂), fusiontreetransform, α, β, backend...)
583+ _add_trivial_kernel!(tdst, tsrc, (p₁, p₂), fusiontreetransform, α, β,
584+ backend, allocator)
547585 elseif FusionStyle(I) isa UniqueFusion
548- _add_abelian_kernel!(tdst, tsrc, (p₁, p₂), fusiontreetransform, α, β, backend...)
586+ _add_abelian_kernel!(tdst, tsrc, (p₁, p₂), fusiontreetransform, α, β,
587+ backend, allocator)
549588 else
550- _add_general_kernel!(tdst, tsrc, (p₁, p₂), fusiontreetransform, α, β, backend...)
589+ _add_general_kernel!(tdst, tsrc, (p₁, p₂), fusiontreetransform, α, β,
590+ backend, allocator)
551591 end
552592
553593 return nothing
554594end
555595
556596# internal methods: no argument types
557- function _add_trivial_kernel!(tdst, tsrc, p, fusiontreetransform, α, β, backend...)
558- TO.tensoradd!(tdst[], tsrc[], p, false, α, β, backend...)
597+ function _add_trivial_kernel!(tdst, tsrc, p, fusiontreetransform, α, β,
598+ backend::TensorKitBackend, allocator)
599+ TO.tensoradd!(tdst[], tsrc[], p, false, α, β, backend.arraybackend, allocator)
559600 return nothing
560601end
561602
562- function _add_abelian_kernel!(tdst, tsrc, p, fusiontreetransform, α, β, backend...)
563- if Threads.nthreads() > 1
564- Threads.@sync for (f₁, f₂) in fusiontrees(tsrc)
565- Threads.@spawn _add_abelian_block!(tdst, tsrc, p, fusiontreetransform,
566- f₁, f₂, α, β, backend...)
567- end
568- else
569- for (f₁, f₂) in fusiontrees(tsrc)
570- _add_abelian_block!(tdst, tsrc, p, fusiontreetransform,
571- f₁, f₂, α, β, backend...)
572- end
603+ function _add_abelian_kernel!(tdst, tsrc, p, fusiontreetransform, α, β,
604+ backend::TensorKitBackend, allocator)
605+ tforeach(fusiontrees(tsrc); backend.scheduler) do (f₁, f₂)
606+ return _add_abelian_block!(tdst, tsrc, p, fusiontreetransform,
607+ f₁, f₂, α, β, backend.arraybackend, allocator)
573608 end
574609 return nothing
575610end
576611
577- function _add_abelian_block!(tdst, tsrc, p, fusiontreetransform, f₁, f₂, α, β, backend...)
612+ function _add_abelian_block!(tdst, tsrc, p, fusiontreetransform, f₁, f₂, α, β,
613+ backend, allocator)
578614 (f₁′, f₂′), coeff = first(fusiontreetransform(f₁, f₂))
579615 @inbounds TO.tensoradd!(tdst[f₁′, f₂′], tsrc[f₁, f₂], p, false, α * coeff, β,
580- backend... )
616+ backend, allocator )
581617 return nothing
582618end
583619
584- function _add_general_kernel!(tdst, tsrc, p, fusiontreetransform, α, β, backend... )
620+ function _add_general_kernel!(tdst, tsrc, p, fusiontreetransform, α, β, backend, allocator )
585621 if iszero(β)
586622 tdst = zerovector!(tdst)
587623 elseif β != 1
588624 tdst = scale!(tdst, β)
589625 end
590626 β′ = One()
591- if Threads.nthreads() > 1
592- Threads.@sync for s₁ in sectors(codomain(tsrc)), s₂ in sectors(domain(tsrc))
593- Threads.@spawn _add_nonabelian_sector!(tdst, tsrc, p, fusiontreetransform, s₁,
594- s₂, α, β′, backend...)
595- end
596- else
627+ if backend.scheduler isa SerialScheduler
597628 for (f₁, f₂) in fusiontrees(tsrc)
598629 for ((f₁′, f₂′), coeff) in fusiontreetransform(f₁, f₂)
599630 @inbounds TO.tensoradd!(tdst[f₁′, f₂′], tsrc[f₁, f₂], p, false, α * coeff,
600- β′, backend... )
631+ β′, backend.arraybackend, allocator )
601632 end
602633 end
634+ else
635+ tforeach(Iterators.product(sectors(codomain(tsrc)), sectors(domain(tsrc)))) do (s₁,
636+ s₂)
637+ return _add_nonabelian_sector!(tdts, tsrc, p, fusiontreetransform, s₁, s₂, α,
638+ β′, backend.arraybackend, allocator)
639+ end
603640 end
604641 return nothing
605642end
606643
607644# TODO: β argument is weird here because it has to be 1
608645function _add_nonabelian_sector!(tdst, tsrc, p, fusiontreetransform, s₁, s₂, α, β,
609- backend... )
646+ backend, allocator )
610647 for (f₁, f₂) in fusiontrees(tsrc)
611648 (f₁.uncoupled == s₁ && f₂.uncoupled == s₂) || continue
612649 for ((f₁′, f₂′), coeff) in fusiontreetransform(f₁, f₂)
613650 @inbounds TO.tensoradd!(tdst[f₁′, f₂′], tsrc[f₁, f₂], p, false, α * coeff, β,
614- backend... )
651+ backend, allocator )
615652 end
616653 end
617654 return nothing
0 commit comments