Skip to content

Commit 375d194

Browse files
committed
a bit more functionality and tests
1 parent 435a2a3 commit 375d194

File tree

3 files changed

+47
-0
lines changed

3 files changed

+47
-0
lines changed

src/tensors/diagtensor.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,33 @@ function Base.zero(d::DiagonalTensorMap)
198198
return DiagonalTensorMap(zero.(d.data), d.domain)
199199
end
200200

201+
function compose(d1::DiagonalTensorMap, d2::DiagonalTensorMap)
202+
d1.domain == d2.domain || throw(SpaceMismatch())
203+
return DiagonalTensorMap(d1.data .* d2.data, d1.domain)
204+
end
205+
Base.inv(d::DiagonalTensorMap) = DiagonalTensorMap(inv.(d.data), d.domain)
206+
function Base.:\(d1::DiagonalTensorMap, d2::DiagonalTensorMap)
207+
d1.domain == d2.domain || throw(SpaceMismatch())
208+
return DiagonalTensorMap(d1.data .\ d2.data, d1.domain)
209+
end
210+
function Base.:/(d1::DiagonalTensorMap, d2::DiagonalTensorMap)
211+
d1.domain == d2.domain || throw(SpaceMismatch())
212+
return DiagonalTensorMap(d1.data ./ d2.data, d1.domain)
213+
end
214+
function LinearAlgebra.pinv(d::DiagonalTensorMap; kwargs...)
215+
T = eltype(d.data)
216+
atol = get(kwargs, :atol, zero(real(T)))
217+
if iszero(atol)
218+
rtol = get(kwargs, :rtol, zero(real(T)))
219+
else
220+
rtol = sqrt(eps(real(float(oneunit(T))))) * length(d.data)
221+
end
222+
pdata = let tol = max(atol, rtol * maximum(abs, d.data))
223+
map(x -> abs(x) < tol ? zero(x) : pinv(x), d.data)
224+
end
225+
return DiagonalTensorMap(pdata, d.domain)
226+
end
227+
201228
function eig!(d::DiagonalTensorMap)
202229
return d, one(d)
203230
end

src/tensors/linalg.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,7 @@ end
345345
function LinearAlgebra.pinv(t::AbstractTensorMap; kwargs...)
346346
T = float(scalartype(t))
347347
tpinv = similar(t, T, domain(t) codomain(t))
348+
# TODO: fix so that `rtol` used total tensor norm instead of per block
348349
for (c, b) in blocks(t)
349350
copy!(block(tpinv, c), pinv(b; kwargs...))
350351
end

test/diagtensors.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,4 +103,23 @@ diagspacelist = ((ℂ^4)', ℂ[Z2Irrep](0 => 2, 1 => 3),
103103
@test convert(TensorMap, t5) == permute(convert(TensorMap, t), (((), (2, 1))))
104104
end
105105
end
106+
@timedtestset "Trace, Multiplication and inverse" begin
107+
t1 = DiagonalTensorMap(rand(Float64, reduceddim(V)), V)
108+
t2 = DiagonalTensorMap(rand(ComplexF64, reduceddim(V)), V)
109+
@test tr(TensorMap(t1)) == @constinferred tr(t1)
110+
@test tr(TensorMap(t2)) == @constinferred tr(t2)
111+
@test TensorMap(@constinferred t1 * t2) TensorMap(t1) * TensorMap(t2)
112+
@test TensorMap(@constinferred t1 \ t2) TensorMap(t1) \ TensorMap(t2)
113+
@test TensorMap(@constinferred t1 / t2) TensorMap(t1) / TensorMap(t2)
114+
@test TensorMap(@constinferred inv(t1)) inv(TensorMap(t1))
115+
@test TensorMap(@constinferred pinv(t1)) pinv(TensorMap(t1))
116+
@test all(Base.Fix2(isa, DiagonalTensorMap),
117+
(t1 * t2, t1 \ t2, t1 / t2, inv(t1), pinv(t1)))
118+
119+
u = randn(Float64, V * V' * V, V)
120+
@test u * t1 ≈ u * TensorMap(t1)
121+
@test u / t1 ≈ u / TensorMap(t1)
122+
@test t1 * u' TensorMap(t1) * u'
123+
@test t1 \ u' TensorMap(t1) \ u'
124+
end
106125
end

0 commit comments

Comments
 (0)