Skip to content

Commit 9f2c273

Browse files
committed
Add backend/allocator support in add_transform!
1 parent a4d0e9d commit 9f2c273

File tree

3 files changed

+90
-53
lines changed

3 files changed

+90
-53
lines changed

src/tensors/braidingtensor.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,9 +158,9 @@ function add_transform!(tdst::AbstractTensorMap,
158158
fusiontreetransform,
159159
α::Number,
160160
β::Number,
161-
backend::AbstractBackend...)
161+
backend::TensorKitBackend, allocator)
162162
return add_transform!(tdst, TensorMap(tsrc), (p₁, p₂), fusiontreetransform, α, β,
163-
backend...)
163+
backend, allocator)
164164
end
165165

166166
# VectorInterface
@@ -173,8 +173,8 @@ end
173173

174174
function TO.tensoradd!(C::AbstractTensorMap,
175175
A::BraidingTensor, pA::Index2Tuple, conjA::Symbol,
176-
α::Number, β::Number, backend=TO.DefaultBackend(),
177-
allocator=TO.DefaultAllocator())
176+
α::Number, β::Number, backend::AbstractBackend,
177+
allocator)
178178
return TO.tensoradd!(C, TensorMap(A), pA, conjA, α, β, backend, allocator)
179179
end
180180

src/tensors/indexmanipulations.jl

Lines changed: 85 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ end
377377
#-------------------------------------
378378
"""
379379
add_permute!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap, (p₁, p₂)::Index2Tuple,
380-
α::Number, β::Number, backend::AbstractBackend...)
380+
α::Number, β::Number, backend...)
381381

382382
Return the updated `tdst`, which is the result of adding `α * tsrc` to `tdst` after permuting
383383
the indices of `tsrc` according to `(p₁, p₂)`.
@@ -389,14 +389,14 @@ See also [`permute`](@ref), [`permute!`](@ref), [`add_braid!`](@ref), [`add_tran
389389
p::Index2Tuple,
390390
α::Number,
391391
β::Number,
392-
backend::AbstractBackend...)
392+
backend...)
393393
transformer = treepermuter(tdst, tsrc, p)
394394
return add_transform!(tdst, tsrc, p, transformer, α, β, backend...)
395395
end
396396
397397
"""
398398
add_braid!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap, (p₁, p₂)::Index2Tuple,
399-
levels::IndexTuple, α::Number, β::Number, backend::AbstractBackend...)
399+
levels::IndexTuple, α::Number, β::Number, backend...)
400400

401401
Return the updated `tdst`, which is the result of adding `α * tsrc` to `tdst` after braiding
402402
the indices of `tsrc` according to `(p₁, p₂)` and `levels`.
@@ -409,7 +409,7 @@ See also [`braid`](@ref), [`braid!`](@ref), [`add_permute!`](@ref), [`add_transp
409409
levels::IndexTuple,
410410
α::Number,
411411
β::Number,
412-
backend::AbstractBackend...)
412+
backend...)
413413
length(levels) == numind(tsrc) ||
414414
throw(ArgumentError("incorrect levels $levels for tensor map $(codomain(tsrc))$(domain(tsrc))"))
415415
@@ -422,7 +422,7 @@ end
422422
423423
"""
424424
add_transpose!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap, (p₁, p₂)::Index2Tuple,
425-
α::Number, β::Number, backend::AbstractBackend...)
425+
α::Number, β::Number, backend...)
426426

427427
Return the updated `tdst`, which is the result of adding `α * tsrc` to `tdst` after transposing
428428
the indices of `tsrc` according to `(p₁, p₂)`.
@@ -434,18 +434,54 @@ See also [`transpose`](@ref), [`transpose!`](@ref), [`add_permute!`](@ref), [`ad
434434
p::Index2Tuple,
435435
α::Number,
436436
β::Number,
437-
backend::AbstractBackend...)
437+
backend...)
438438
transformer = treetransposer(tdst, tsrc, p)
439439
return add_transform!(tdst, tsrc, p, transformer, α, β, backend...)
440440
end
441441
442+
# Implementation
443+
# --------------
444+
"""
445+
add_transform!(C, A, pA, transformer, α, β, [backend], [allocator])
446+
447+
Return the updated `C`, which is the result of adding `α * A` to `β * B`,
448+
permuting the data with `pA` while transforming the fusiontrees with `transformer`.
449+
"""
450+
function add_transform!(C::AbstractTensorMap, A::AbstractTensorMap, pA::Index2Tuple,
451+
transformer, α::Number, β::Number)
452+
return add_transform!(C, A, pA, transformer, α, β, TO.DefaultBackend())
453+
end
454+
function add_transform!(C::AbstractTensorMap, A::AbstractTensorMap, pA::Index2Tuple,
455+
transformer, α::Number, β::Number, backend)
456+
return add_transform!(C, A, pA, transformer, α, β, backend, TO.DefaultAllocator())
457+
end
458+
function add_transform!(C::AbstractTensorMap, A::AbstractTensorMap, pA::Index2Tuple,
459+
transformer, α::Number, β::Number, backend, allocator)
460+
if backend isa TO.DefaultBackend
461+
newbackend = TO.select_backend(add_transform!, C, A)
462+
return add_transform!(C, A, pA, transformer, α, β, newbackend, allocator)
463+
elseif backend isa TO.NoBackend # error for missing backend
464+
TC = typeof(C)
465+
TA = typeof(A)
466+
throw(ArgumentError("No suitable backend found for `add_transform!` and tensor types $TC and $TA"))
467+
else # error for unknown backend
468+
TC = typeof(C)
469+
TA = typeof(A)
470+
throw(ArgumentError("Unknown backend $backend for `add_transform!` and tensor types $TC and $TA"))
471+
end
472+
end
473+
function TO.select_backend(::typeof(add_transform!), C::AbstractTensorMap,
474+
A::AbstractTensorMap)
475+
return TensorKitBackend()
476+
end
477+
442478
function add_transform!(tdst::AbstractTensorMap,
443479
tsrc::AbstractTensorMap,
444480
(p₁, p₂)::Index2Tuple,
445481
transformer,
446482
α::Number,
447483
β::Number,
448-
backend::AbstractBackend...)
484+
backend::TensorKitBackend, allocator)
449485
@boundscheck begin
450486
permute(space(tsrc), (p₁, p₂)) == space(tdst) ||
451487
throw(SpaceMismatch("source = $(codomain(tsrc))$(domain(tsrc)),
@@ -455,7 +491,7 @@ function add_transform!(tdst::AbstractTensorMap,
455491
if p₁ === codomainind(tsrc) && p₂ === domainind(tsrc)
456492
add!(tdst, tsrc, α, β)
457493
else
458-
add_transform_kernel!(tdst, tsrc, (p₁, p₂), transformer, α, β, backend...)
494+
add_transform_kernel!(tdst, tsrc, (p₁, p₂), transformer, α, β, backend, allocator)
459495
end
460496
461497
return tdst
@@ -467,8 +503,9 @@ function add_transform_kernel!(tdst::TensorMap,
467503
::TrivialTreeTransformer,
468504
α::Number,
469505
β::Number,
470-
backend::AbstractBackend...)
471-
return TO.tensoradd!(tdst[], tsrc[], (p₁, p₂), false, α, β, backend...)
506+
backend::TensorKitBackend, allocator)
507+
return TO.tensoradd!(tdst[], tsrc[], (p₁, p₂), false, α, β, backend.arraybackend,
508+
allocator)
472509
end
473510
474511
function add_transform_kernel!(tdst::TensorMap,
@@ -477,19 +514,20 @@ function add_transform_kernel!(tdst::TensorMap,
477514
transformer::AbelianTreeTransformer,
478515
α::Number,
479516
β::Number,
480-
backend::AbstractBackend...)
517+
backend::TensorKitBackend, allocator)
481518
structure_dst = transformer.structure_dst.fusiontreestructure
482519
structure_src = transformer.structure_src.fusiontreestructure
483520
484-
# TODO: this could be multithreaded
485-
for (row, col, val) in zip(transformer.rows, transformer.cols, transformer.vals)
521+
tforeach(transformer.rows, transformer.cols, transformer.vals;
522+
backend.scheduler) do row, col, val
486523
sz_dst, str_dst, offset_dst = structure_dst[col]
487524
subblock_dst = StridedView(tdst.data, sz_dst, str_dst, offset_dst)
488525
489526
sz_src, str_src, offset_src = structure_src[row]
490527
subblock_src = StridedView(tsrc.data, sz_src, str_src, offset_src)
491528
492-
TO.tensoradd!(subblock_dst, subblock_src, (p₁, p₂), false, α * val, β, backend...)
529+
return TO.tensoradd!(subblock_dst, subblock_src, (p₁, p₂), false, α * val, β,
530+
backend.arraybackend, allocator)
493531
end
494532
495533
return nothing
@@ -501,15 +539,14 @@ function add_transform_kernel!(tdst::TensorMap,
501539
transformer::GenericTreeTransformer,
502540
α::Number,
503541
β::Number,
504-
backend::AbstractBackend...)
542+
backend::TensorKitBackend, allocator)
505543
structure_dst = transformer.structure_dst.fusiontreestructure
506544
structure_src = transformer.structure_src.fusiontreestructure
507545
508546
rows = rowvals(transformer.matrix)
509547
vals = nonzeros(transformer.matrix)
510548
511-
# TODO: this could be multithreaded
512-
for j in axes(transformer.matrix, 2)
549+
tforeach(axes(transformer.matrix, 2); backend.scheduler) do j
513550
sz_dst, str_dst, offset_dst = structure_dst[j]
514551
subblock_dst = StridedView(tdst.data, sz_dst, str_dst, offset_dst)
515552
nzrows = nzrange(transformer.matrix, j)
@@ -519,14 +556,14 @@ function add_transform_kernel!(tdst::TensorMap,
519556
subblock_src = StridedView(tsrc.data, sz_src, str_src, offset_src)
520557
TO.tensoradd!(subblock_dst, subblock_src, (p₁, p₂), false, α * vals[first(nzrows)],
521558
β,
522-
backend...)
559+
backend.arraybackend, allocator)
523560
524561
# treat remaining entries
525562
for i in @view(nzrows[2:end])
526563
sz_src, str_src, offset_src = structure_src[rows[i]]
527564
subblock_src = StridedView(tsrc.data, sz_src, str_src, offset_src)
528565
TO.tensoradd!(subblock_dst, subblock_src, (p₁, p₂), false, α * vals[i], One(),
529-
backend...)
566+
backend.arraybackend, allocator)
530567
end
531568
end
532569
@@ -539,79 +576,79 @@ function add_transform_kernel!(tdst::AbstractTensorMap,
539576
fusiontreetransform::Function,
540577
α::Number,
541578
β::Number,
542-
backend::AbstractBackend...)
579+
backend::TensorKitBackend, allocator)
543580
I = sectortype(spacetype(tdst))
544581
545582
if I === Trivial
546-
_add_trivial_kernel!(tdst, tsrc, (p₁, p₂), fusiontreetransform, α, β, backend...)
583+
_add_trivial_kernel!(tdst, tsrc, (p₁, p₂), fusiontreetransform, α, β,
584+
backend, allocator)
547585
elseif FusionStyle(I) isa UniqueFusion
548-
_add_abelian_kernel!(tdst, tsrc, (p₁, p₂), fusiontreetransform, α, β, backend...)
586+
_add_abelian_kernel!(tdst, tsrc, (p₁, p₂), fusiontreetransform, α, β,
587+
backend, allocator)
549588
else
550-
_add_general_kernel!(tdst, tsrc, (p₁, p₂), fusiontreetransform, α, β, backend...)
589+
_add_general_kernel!(tdst, tsrc, (p₁, p₂), fusiontreetransform, α, β,
590+
backend, allocator)
551591
end
552592
553593
return nothing
554594
end
555595
556596
# internal methods: no argument types
557-
function _add_trivial_kernel!(tdst, tsrc, p, fusiontreetransform, α, β, backend...)
558-
TO.tensoradd!(tdst[], tsrc[], p, false, α, β, backend...)
597+
function _add_trivial_kernel!(tdst, tsrc, p, fusiontreetransform, α, β,
598+
backend::TensorKitBackend, allocator)
599+
TO.tensoradd!(tdst[], tsrc[], p, false, α, β, backend.arraybackend, allocator)
559600
return nothing
560601
end
561602
562-
function _add_abelian_kernel!(tdst, tsrc, p, fusiontreetransform, α, β, backend...)
563-
if Threads.nthreads() > 1
564-
Threads.@sync for (f₁, f₂) in fusiontrees(tsrc)
565-
Threads.@spawn _add_abelian_block!(tdst, tsrc, p, fusiontreetransform,
566-
f₁, f₂, α, β, backend...)
567-
end
568-
else
569-
for (f₁, f₂) in fusiontrees(tsrc)
570-
_add_abelian_block!(tdst, tsrc, p, fusiontreetransform,
571-
f₁, f₂, α, β, backend...)
572-
end
603+
function _add_abelian_kernel!(tdst, tsrc, p, fusiontreetransform, α, β,
604+
backend::TensorKitBackend, allocator)
605+
tforeach(fusiontrees(tsrc); backend.scheduler) do (f₁, f₂)
606+
return _add_abelian_block!(tdst, tsrc, p, fusiontreetransform,
607+
f₁, f₂, α, β, backend.arraybackend, allocator)
573608
end
574609
return nothing
575610
end
576611
577-
function _add_abelian_block!(tdst, tsrc, p, fusiontreetransform, f₁, f₂, α, β, backend...)
612+
function _add_abelian_block!(tdst, tsrc, p, fusiontreetransform, f₁, f₂, α, β,
613+
backend, allocator)
578614
(f₁′, f₂′), coeff = first(fusiontreetransform(f₁, f₂))
579615
@inbounds TO.tensoradd!(tdst[f₁′, f₂′], tsrc[f₁, f₂], p, false, α * coeff, β,
580-
backend...)
616+
backend, allocator)
581617
return nothing
582618
end
583619
584-
function _add_general_kernel!(tdst, tsrc, p, fusiontreetransform, α, β, backend...)
620+
function _add_general_kernel!(tdst, tsrc, p, fusiontreetransform, α, β, backend, allocator)
585621
if iszero(β)
586622
tdst = zerovector!(tdst)
587623
elseif β != 1
588624
tdst = scale!(tdst, β)
589625
end
590626
β′ = One()
591-
if Threads.nthreads() > 1
592-
Threads.@sync for s₁ in sectors(codomain(tsrc)), s₂ in sectors(domain(tsrc))
593-
Threads.@spawn _add_nonabelian_sector!(tdst, tsrc, p, fusiontreetransform, s₁,
594-
s₂, α, β′, backend...)
595-
end
596-
else
627+
if backend.scheduler isa SerialScheduler
597628
for (f₁, f₂) in fusiontrees(tsrc)
598629
for ((f₁′, f₂′), coeff) in fusiontreetransform(f₁, f₂)
599630
@inbounds TO.tensoradd!(tdst[f₁′, f₂′], tsrc[f₁, f₂], p, false, α * coeff,
600-
β′, backend...)
631+
β′, backend.arraybackend, allocator)
601632
end
602633
end
634+
else
635+
tforeach(Iterators.product(sectors(codomain(tsrc)), sectors(domain(tsrc)))) do (s₁,
636+
s₂)
637+
return _add_nonabelian_sector!(tdts, tsrc, p, fusiontreetransform, s₁, s₂, α,
638+
β′, backend.arraybackend, allocator)
639+
end
603640
end
604641
return nothing
605642
end
606643
607644
# TODO: β argument is weird here because it has to be 1
608645
function _add_nonabelian_sector!(tdst, tsrc, p, fusiontreetransform, s₁, s₂, α, β,
609-
backend...)
646+
backend, allocator)
610647
for (f₁, f₂) in fusiontrees(tsrc)
611648
(f₁.uncoupled == s₁ && f₂.uncoupled == s₂) || continue
612649
for ((f₁′, f₂′), coeff) in fusiontreetransform(f₁, f₂)
613650
@inbounds TO.tensoradd!(tdst[f₁′, f₂′], tsrc[f₁, f₂], p, false, α * coeff, β,
614-
backend...)
651+
backend, allocator)
615652
end
616653
end
617654
return nothing

src/tensors/tensoroperations.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ TO.tensorcost(t::AbstractTensorMap, i::Int) = dim(space(t, i))
150150
# TODO: what should be the default scheduler?
151151
# TODO: should we allow a separate scheduler for "blocks" and "subblocks"
152152
@kwdef struct TensorKitBackend{B<:AbstractBackend,S<:Scheduler} <: AbstractBackend
153-
arraybackend::B = DefaultBackend()
153+
arraybackend::B = TO.DefaultBackend()
154154
scheduler::S = SerialScheduler()
155155
end
156156

0 commit comments

Comments
 (0)