Skip to content

Commit ba8ba47

Browse files
authored
Add absorb for putting (part of) the contents of one tensor in another (#283)
* Add method for embedding tensors * Add small amount of tests * Fix issue with kwarg interpreted as iterator * rename `absorb` and add out-of-place version
1 parent 980fe55 commit ba8ba47

File tree

5 files changed

+69
-13
lines changed

5 files changed

+69
-13
lines changed

src/TensorKit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ export leftorth, rightorth, leftnull, rightnull,
7676
isposdef, isposdef!, ishermitian, sylvester, rank, cond
7777
export braid, braid!, permute, permute!, transpose, transpose!, twist, twist!, repartition,
7878
repartition!
79-
export catdomain, catcodomain
79+
export catdomain, catcodomain, absorb, absorb!
8080

8181
export OrthogonalFactorizationAlgorithm, QR, QRpos, QL, QLpos, LQ, LQpos, RQ, RQpos,
8282
SVD, SDD, Polar

src/spaces/gradedspace.jl

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -168,23 +168,19 @@ function fuse(V₁::GradedSpace{I}, V₂::GradedSpace{I}) where {I<:Sector}
168168
end
169169

170170
function infimum(V₁::GradedSpace{I}, V₂::GradedSpace{I}) where {I<:Sector}
171-
if V₁.dual == V₂.dual
172-
typeof(V₁)(c => min(dim(V₁, c), dim(V₂, c))
173-
for c in
174-
union(sectors(V₁), sectors(V₂)), dual in V₁.dual)
175-
else
171+
Visdual = isdual(V₁)
172+
Visdual == isdual(V₂) ||
176173
throw(SpaceMismatch("Infimum of space and dual space does not exist"))
177-
end
174+
return typeof(V₁)((Visdual ? dual(c) : c) => min(dim(V₁, c), dim(V₂, c))
175+
for c in intersect(sectors(V₁), sectors(V₂)); dual=Visdual)
178176
end
179177

180178
function supremum(V₁::GradedSpace{I}, V₂::GradedSpace{I}) where {I<:Sector}
181-
if V₁.dual == V₂.dual
182-
typeof(V₁)(c => max(dim(V₁, c), dim(V₂, c))
183-
for c in
184-
union(sectors(V₁), sectors(V₂)), dual in V₁.dual)
185-
else
179+
Visdual = isdual(V₁)
180+
Visdual == isdual(V₂) ||
186181
throw(SpaceMismatch("Supremum of space and dual space does not exist"))
187-
end
182+
return typeof(V₁)((Visdual ? dual(c) : c) => max(dim(V₁, c), dim(V₂, c))
183+
for c in union(sectors(V₁), sectors(V₂)); dual=Visdual)
188184
end
189185

190186
function Base.show(io::IO, V::GradedSpace{I}) where {I<:Sector}

src/spaces/homspace.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,13 @@ function dim(W::HomSpace)
125125
return d
126126
end
127127

128+
"""
129+
fusiontrees(W::HomSpace)
130+
131+
Return the fusiontrees corresponding to all valid fusion channels of a given `HomSpace`.
132+
"""
133+
fusiontrees(W::HomSpace) = fusionblockstructure(W).fusiontreelist
134+
128135
# Operations on HomSpaces
129136
# -----------------------
130137
"""

src/tensors/linalg.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,38 @@ function catcodomain(t1::TT, t2::TT) where {S,N₂,TT<:AbstractTensorMap{<:Any,S
512512
return t
513513
end
514514

515+
"""
516+
absorb(tdst::AbstractTensorMap, tsrc::AbstractTensorMap)
517+
absorb!(tdst::AbstactTensorMap, tsrc::AbstractTensorMap)
518+
519+
Absorb the contents of `tsrc` into `tdst`, which may have different sizes of data.
520+
This is equivalent to the following operation on dense arrays, but also works for symmetric
521+
tensors. Note also that this only overwrites the regions that are shared, and will do
522+
nothing on the ones that are not, so it is up to the user to properly initialize the
523+
destination.
524+
525+
```julia
526+
sub_axes = map((x, y) -> 1:min(x, y), size(tdst), size(tsrc))
527+
tdst[sub_axes...] .= tsrc[sub_axes...]
528+
```
529+
"""
530+
absorb(tdst::AbstractTensorMap, tsrc::AbstractTensorMap) = absorb!(copy(tdst), tsrc)
531+
function absorb!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap)
532+
numin(tdst) == numin(tsrc) && numout(tdst) == numout(tsrc) ||
533+
throw(DimensionError("Incompatible number of indices for source and destination"))
534+
S = spacetype(tdst)
535+
S == spacetype(tsrc) || throw(SpaceMismatch("incompatible spacetypes"))
536+
dom = mapreduce(infimum, , domain(tdst), domain(tsrc); init=one(S))
537+
cod = mapreduce(infimum, , codomain(tdst), codomain(tsrc); init=one(S))
538+
for (f1, f2) in fusiontrees(cod dom)
539+
@inbounds data_dst = tdst[f1, f2]
540+
@inbounds data_src = tsrc[f1, f2]
541+
sub_axes = map(Base.OneTo min, size(data_dst), size(data_src))
542+
data_dst[sub_axes...] .= data_src[sub_axes...]
543+
end
544+
return tdst
545+
end
546+
515547
# tensor product of tensors
516548
"""
517549
⊗(t1::AbstractTensorMap, t2::AbstractTensorMap, ...) -> TensorMap

test/tensors.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -739,6 +739,27 @@ for V in spacelist
739739
@test t t′
740740
end
741741
end
742+
@timedtestset "Tensor absorpsion" begin
743+
# absorbing small into large
744+
t1 = zeros(V1 V1, V2 V3)
745+
t2 = rand(V1, V2 V3)
746+
t3 = @constinferred absorb(t1, t2)
747+
@test norm(t3) norm(t2)
748+
@test norm(t1) == 0
749+
t4 = @constinferred absorb!(t1, t2)
750+
@test t1 === t4
751+
@test t3 t4
752+
753+
# absorbing large into small
754+
t1 = rand(V1 V1, V2 V3)
755+
t2 = zeros(V1, V2 V3)
756+
t3 = @constinferred absorb(t2, t1)
757+
@test norm(t3) < norm(t1)
758+
@test norm(t2) == 0
759+
t4 = @constinferred absorb!(t2, t1)
760+
@test t2 === t4
761+
@test t3 t4
762+
end
742763
end
743764
TensorKit.empty_globalcaches!()
744765
end

0 commit comments

Comments
 (0)