Skip to content

Commit fd9cb9f

Browse files
committed
Add TensorOperations specializations
1 parent b9d02fa commit fd9cb9f

File tree

2 files changed

+31
-7
lines changed

2 files changed

+31
-7
lines changed

src/tensors/diagonal.jl

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,34 @@ function VectorInterface.add(ty::DiagonalTensorMap, tx::DiagonalTensorMap,
185185
return add!(scale!(zerovector(ty, T), ty, β), tx, α) # zerovector instead of similar preserves diagonal structure
186186
end
187187

188+
# TensorOperations
189+
# ----------------
190+
function TO.tensoradd_type(TC, A::DiagonalTensorMap, ::Index2Tuple{1,1}, ::Bool)
191+
M = similarstoragetype(A, TC)
192+
return DiagonalTensorMap{TC,spacetype(A),M}
193+
end
194+
195+
function TO.tensorcontract_type(TC, A::DiagonalTensorMap, ::Index2Tuple{1,1}, ::Bool,
196+
B::DiagonalTensorMap, ::Index2Tuple{1,1}, ::Bool,
197+
::Index2Tuple{1,1})
198+
M = similarstoragetype(A, TC)
199+
M == similarstoragetype(B, TC) ||
200+
throw(ArgumentError("incompatible storage types:\n$(M)$(similarstoragetype(B, TC))"))
201+
spacetype(A) == spacetype(B) || throw(SpaceMismatch("incompatible space types"))
202+
return DiagonalTensorMap{TC,spacetype(A),M}
203+
end
204+
205+
function TO.tensoralloc(::Type{DiagonalTensorMap{T,S,M}},
206+
structure::TensorMapSpace{S,1,1},
207+
istemp::Val,
208+
allocator=TO.DefaultAllocator()) where {T,S,M}
209+
domain(structure) == codomain(structure) || throw(ArgumentError("domain ≠ codomain"))
210+
V = only(domain(structure))
211+
dim = reduceddim(V)
212+
data = TO.tensoralloc(M, dim, istemp, allocator)
213+
return DiagonalTensorMap{T,S,M}(data, V)
214+
end
215+
188216
# Linear Algebra and factorizations
189217
# ---------------------------------
190218
function one!(d::DiagonalTensorMap)
@@ -198,13 +226,6 @@ function Base.zero(d::DiagonalTensorMap)
198226
return DiagonalTensorMap(zero.(d.data), d.domain)
199227
end
200228

201-
function compose_dest(d1::DiagonalTensorMap, d2::DiagonalTensorMap)
202-
A = promote_type(storagetype(d1), storagetype(d2))
203-
S = spacetype(d1)
204-
T = scalartype(A)
205-
return DiagonalTensorMap{T,S,A}(undef, d1.domain)
206-
end
207-
208229
function LinearAlgebra.mul!(dC::DiagonalTensorMap,
209230
dA::DiagonalTensorMap,
210231
dB::DiagonalTensorMap,

test/diagonal.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,9 @@ diagspacelist = ((ℂ^4)', ℂ[Z2Irrep](0 => 2, 1 => 3),
130130
if BraidingStyle(I) isa SymmetricBraiding
131131
@tensor C[a b c; d] := A[a b c; e] * d[e, d]
132132
@test C A * d
133+
@tensor D[a; b] := d[a, c] * d[c, b]
134+
@test D d * d
135+
@test D isa DiagonalTensorMap
133136
end
134137
@planar E1[-1 -2 -3; -4 -5] := B[-1 -2 -3; 1 -5] * d[1; -4]
135138
@planar E2[-1 -2 -3; -4 -5] := B[-1 -2 -3; 1 -5] * t[1; -4]

0 commit comments

Comments
 (0)