Skip to content

Commit 234a973

Browse files
lkdvosJutho
authored andcommitted
Reinstate general add_transform! kernels to deal with AdjointTensorMap
1 parent 32308ed commit 234a973

File tree

2 files changed

+105
-7
lines changed

2 files changed

+105
-7
lines changed

src/tensors/indexmanipulations.jl

Lines changed: 92 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -328,20 +328,24 @@ See also [`transpose`](@ref), [`transpose!`](@ref), [`add_permute!`](@ref), [`ad
328328
return add_transform!(tdst, tsrc, p, transformer, α, β, backend...)
329329
end
330330

331-
function add_transform!(tdst::AbstractTensorMap{T,S,N₁,N₂},
331+
function add_transform!(tdst::AbstractTensorMap,
332332
tsrc::AbstractTensorMap,
333-
(p₁, p₂)::Index2Tuple{N₁,N₂},
333+
(p₁, p₂)::Index2Tuple,
334334
transformer,
335335
α::Number,
336336
β::Number,
337-
backend::AbstractBackend...) where {T,S,N₁,N₂}
337+
backend::AbstractBackend...)
338338
@boundscheck begin
339339
permute(space(tsrc), (p₁, p₂)) == space(tdst) ||
340340
throw(SpaceMismatch("source = $(codomain(tsrc))$(domain(tsrc)),
341341
dest = $(codomain(tdst))$(domain(tdst)), p₁ = $(p₁), p₂ = $(p₂)"))
342342
end
343343

344-
add_transform_kernel!(tdst, tsrc, (p₁, p₂), transformer, α, β, backend...)
344+
if p₁ === codomainind(tsrc) && p₂ === domainind(tsrc)
345+
add!(tdst, tsrc, α, β)
346+
else
347+
add_transform_kernel!(tdst, tsrc, (p₁, p₂), transformer, α, β, backend...)
348+
end
345349

346350
return tdst
347351
end
@@ -417,3 +421,87 @@ function add_transform_kernel!(tdst::TensorMap,
417421

418422
return tdst
419423
end
424+
425+
function add_transform_kernel!(tdst::AbstractTensorMap,
426+
tsrc::AbstractTensorMap,
427+
(p₁, p₂)::Index2Tuple,
428+
fusiontreetransform::Function,
429+
α::Number,
430+
β::Number,
431+
backend::AbstractBackend...)
432+
I = sectortype(spacetype(tdst))
433+
434+
if I === Trivial
435+
_add_trivial_kernel!(tdst, tsrc, (p₁, p₂), fusiontreetransform, α, β, backend...)
436+
elseif FusionStyle(I) isa UniqueFusion
437+
_add_abelian_kernel!(tdst, tsrc, (p₁, p₂), fusiontreetransform, α, β, backend...)
438+
else
439+
_add_general_kernel!(tdst, tsrc, (p₁, p₂), fusiontreetransform, α, β, backend...)
440+
end
441+
442+
return nothing
443+
end
444+
445+
# internal methods: no argument types
446+
function _add_trivial_kernel!(tdst, tsrc, p, fusiontreetransform, α, β, backend...)
447+
TO.tensoradd!(tdst[], tsrc[], p, false, α, β, backend...)
448+
return nothing
449+
end
450+
451+
function _add_abelian_kernel!(tdst, tsrc, p, fusiontreetransform, α, β, backend...)
452+
if Threads.nthreads() > 1
453+
Threads.@sync for (f₁, f₂) in fusiontrees(tsrc)
454+
Threads.@spawn _add_abelian_block!(tdst, tsrc, p, fusiontreetransform,
455+
f₁, f₂, α, β, backend...)
456+
end
457+
else
458+
for (f₁, f₂) in fusiontrees(tsrc)
459+
_add_abelian_block!(tdst, tsrc, p, fusiontreetransform,
460+
f₁, f₂, α, β, backend...)
461+
end
462+
end
463+
return nothing
464+
end
465+
466+
function _add_abelian_block!(tdst, tsrc, p, fusiontreetransform, f₁, f₂, α, β, backend...)
467+
(f₁′, f₂′), coeff = first(fusiontreetransform(f₁, f₂))
468+
@inbounds TO.tensoradd!(tdst[f₁′, f₂′], tsrc[f₁, f₂], p, false, α * coeff, β,
469+
backend...)
470+
return nothing
471+
end
472+
473+
function _add_general_kernel!(tdst, tsrc, p, fusiontreetransform, α, β, backend...)
474+
if iszero(β)
475+
tdst = zerovector!(tdst)
476+
elseif β != 1
477+
tdst = scale!(tdst, β)
478+
end
479+
β′ = One()
480+
if Threads.nthreads() > 1
481+
Threads.@sync for s₁ in sectors(codomain(tsrc)), s₂ in sectors(domain(tsrc))
482+
Threads.@spawn _add_nonabelian_sector!(tdst, tsrc, p, fusiontreetransform, s₁,
483+
s₂, α, β′, backend...)
484+
end
485+
else
486+
for (f₁, f₂) in fusiontrees(tsrc)
487+
for ((f₁′, f₂′), coeff) in fusiontreetransform(f₁, f₂)
488+
@inbounds TO.tensoradd!(tdst[f₁′, f₂′], tsrc[f₁, f₂], p, false, α * coeff,
489+
β′, backend...)
490+
end
491+
end
492+
end
493+
return nothing
494+
end
495+
496+
# TODO: β argument is weird here because it has to be 1
497+
function _add_nonabelian_sector!(tdst, tsrc, p, fusiontreetransform, s₁, s₂, α, β,
498+
backend...)
499+
for (f₁, f₂) in fusiontrees(tsrc)
500+
(f₁.uncoupled == s₁ && f₂.uncoupled == s₂) || continue
501+
for ((f₁′, f₂′), coeff) in fusiontreetransform(f₁, f₂)
502+
@inbounds TO.tensoradd!(tdst[f₁′, f₂′], tsrc[f₁, f₂], p, false, α * coeff, β,
503+
backend...)
504+
end
505+
end
506+
return nothing
507+
end

src/tensors/treetransformers.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,10 @@ end
7575
const treetransposercache = LRU{Any,Any}(; maxsize=10^5)
7676
const usetreetransposercache = Ref{Bool}(true)
7777

78-
function treetransposer(tdst::AbstractTensorMap, tsrc::AbstractTensorMap, p::Index2Tuple)
78+
function treetransposer(::AbstractTensorMap, ::AbstractTensorMap, p::Index2Tuple)
79+
return fusiontreetransform(f1, f2) = transpose(f1, f2, p...)
80+
end
81+
function treetransposer(tdst::TensorMap, tsrc::TensorMap, p::Index2Tuple)
7982
if usetreetransposercache[]
8083
key = (space(tdst), space(tsrc), p)
8184
A = treetransformertype(space(tdst), space(tsrc))
@@ -100,7 +103,11 @@ end
100103
const treebraidercache = LRU{Any,Any}(; maxsize=10^5)
101104
const usetreebraidercache = Ref{Bool}(true)
102105

103-
function treebraider(tdst::AbstractTensorMap, tsrc::AbstractTensorMap, p::Index2Tuple,
106+
function treebraider(::AbstractTensorMap, ::AbstractTensorMap, p::Index2Tuple,
107+
l::Index2Tuple)
108+
return fusiontreetransform(f1, f2) = braid(f1, f2, p..., l...)
109+
end
110+
function treebraider(tdst::TensorMap, tsrc::TensorMap, p::Index2Tuple,
104111
l::Index2Tuple)
105112
if usetreebraidercache[]
106113
key = (space(tdst), space(tsrc), p, l)
@@ -126,7 +133,10 @@ end
126133
const treepermutercache = LRU{Any,Any}(; maxsize=10^5)
127134
const usetreepermutercache = Ref{Bool}(true)
128135

129-
function treepermuter(tdst::AbstractTensorMap, tsrc::AbstractTensorMap, p::Index2Tuple)
136+
function treepermuter(::AbstractTensorMap, ::AbstractTensorMap, p::Index2Tuple)
137+
return fusiontreetransform(f1, f2) = permute(f1, f2, p...)
138+
end
139+
function treepermuter(tdst::TensorMap, tsrc::TensorMap, p::Index2Tuple)
130140
if usetreepermutercache[]
131141
key = (space(tdst), space(tsrc), p)
132142
A = treetransformertype(space(tdst), space(tsrc))

0 commit comments

Comments
 (0)