Skip to content
Merged
17 changes: 14 additions & 3 deletions src/tensors/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,21 @@ function Base.zero(d::DiagonalTensorMap)
return DiagonalTensorMap(zero.(d.data), d.domain)
end

function compose(d1::DiagonalTensorMap, d2::DiagonalTensorMap)
d1.domain == d2.domain || throw(SpaceMismatch())
return DiagonalTensorMap(d1.data .* d2.data, d1.domain)
function compose_dest(d1::DiagonalTensorMap, d2::DiagonalTensorMap)
A = promote_type(storagetype(d1), storagetype(d2))
S = spacetype(d1)
T = scalartype(A)
return DiagonalTensorMap{T,S,A}(undef, d1.domain)
end

function LinearAlgebra.mul!(dC::DiagonalTensorMap,
dA::DiagonalTensorMap,
dB::DiagonalTensorMap)
dC.domain == dA.domain == dB.domain || throw(SpaceMismatch())
dC.data .= dA.data .* dB.data
return dC
end

Base.inv(d::DiagonalTensorMap) = DiagonalTensorMap(inv.(d.data), d.domain)
function Base.:\(d1::DiagonalTensorMap, d2::DiagonalTensorMap)
d1.domain == d2.domain || throw(SpaceMismatch())
Expand Down