Skip to content

Commit f7eabe5

Browse files
authored
Fix converting empty tensors to array (#180)
* Fix conversion of empty tensor * Add tests for empty tensor conversion * Add `zero(::ElementarySpace)` * Fix failing tests * Add explicit test for issue #178 * Fix wrong type * `eltype` should determine precision * Add entry to docs * small fix * small fix attempt II Fixes #178
1 parent 8b38973 commit f7eabe5

File tree

10 files changed

+43
-18
lines changed

10 files changed

+43
-18
lines changed

docs/src/lib/spaces.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ dual
9090
conj
9191
flip
9292
93+
zero(::ElementarySpace)
9394
oneunit
9495
supremum
9596
infimum

src/spaces/cartesianspace.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ sectors(V::CartesianSpace) = OneOrNoneIterator(dim(V) != 0, Trivial())
4848
sectortype(::Type{CartesianSpace}) = Trivial
4949

5050
Base.oneunit(::Type{CartesianSpace}) = CartesianSpace(1)
51+
Base.zero(::Type{CartesianSpace}) = CartesianSpace(0)
5152
(V₁::CartesianSpace, V₂::CartesianSpace) = CartesianSpace(V₁.d + V₂.d)
5253
fuse(V₁::CartesianSpace, V₂::CartesianSpace) = CartesianSpace(V₁.d * V₂.d)
5354
flip(V::CartesianSpace) = V

src/spaces/complexspace.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ sectortype(::Type{ComplexSpace}) = Trivial
4949
Base.conj(V::ComplexSpace) = ComplexSpace(dim(V), !isdual(V))
5050

5151
Base.oneunit(::Type{ComplexSpace}) = ComplexSpace(1)
52+
Base.zero(::Type{ComplexSpace}) = ComplexSpace(0)
5253
function (V₁::ComplexSpace, V₂::ComplexSpace)
5354
return isdual(V₁) == isdual(V₂) ?
5455
ComplexSpace(dim(V₁) + dim(V₂), isdual(V₁)) :

src/spaces/generalspace.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ sectortype(::Type{<:GeneralSpace}) = Trivial
3535
field(::Type{GeneralSpace{𝔽}}) where {𝔽} = 𝔽
3636
InnerProductStyle(::Type{<:GeneralSpace}) = NoInnerProduct()
3737

38+
Base.oneunit(::Type{GeneralSpace{𝔽}}) where {𝔽} = GeneralSpace{𝔽}(1, false, false)
39+
Base.zero(::Type{GeneralSpace{𝔽}}) where {𝔽} = GeneralSpace{𝔽}(0, false, false)
40+
3841
dual(V::GeneralSpace{𝔽}) where {𝔽} = GeneralSpace{𝔽}(dim(V), !isdual(V), isconj(V))
3942
Base.conj(V::GeneralSpace{𝔽}) where {𝔽} = GeneralSpace{𝔽}(dim(V), isdual(V), !isconj(V))
4043

src/spaces/gradedspace.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ function Base.axes(V::GradedSpace{I}, c::I) where {I<:Sector}
132132
end
133133

134134
Base.oneunit(S::Type{<:GradedSpace{I}}) where {I<:Sector} = S(one(I) => 1)
135+
Base.zero(S::Type{<:GradedSpace{I}}) where {I<:Sector} = S(one(I) => 0)
135136

136137
# TODO: the following methods can probably be implemented more efficiently for
137138
# `FiniteGradedSpace`, but we don't expect them to be used often in hot loops, so

src/spaces/vectorspaces.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,14 @@ that this is different from `one(V::S)`, which returns the empty product space
128128
"""
129129
Base.oneunit(V::ElementarySpace) = oneunit(typeof(V))
130130

131+
"""
132+
zero(V::S) where {S<:ElementarySpace} -> S
133+
134+
Return the corresponding vector space of type `S` that represents the zero-dimensional or empty space.
135+
This is, with a slight abuse of notation, the zero element of the direct sum of vector spaces.
136+
"""
137+
Base.zero(V::ElementarySpace) = zero(typeof(V))
138+
131139
"""
132140
⊕(V₁::S, V₂::S, V₃::S...) where {S<:ElementarySpace} -> S
133141
oplus(V₁::S, V₂::S, V₃::S...) where {S<:ElementarySpace} -> S

src/tensors/abstracttensor.jl

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -497,21 +497,13 @@ function Base.convert(::Type{Array}, t::AbstractTensorMap)
497497
else
498498
cod = codomain(t)
499499
dom = domain(t)
500-
local A
500+
T = sectorscalartype(I) <: Complex ? complex(scalartype(t)) :
501+
sectorscalartype(I) <: Integer ? scalartype(t) : float(scalartype(t))
502+
A = zeros(T, dims(cod)..., dims(dom)...)
501503
for (f₁, f₂) in fusiontrees(t)
502504
F = convert(Array, (f₁, f₂))
503-
if !(@isdefined A)
504-
if eltype(F) <: Complex
505-
T = complex(float(scalartype(t)))
506-
elseif eltype(F) <: Integer
507-
T = scalartype(t)
508-
else
509-
T = float(scalartype(t))
510-
end
511-
A = fill(zero(T), (dims(cod)..., dims(dom)...))
512-
end
513505
Aslice = StridedView(A)[axes(cod, f₁.uncoupled)..., axes(dom, f₂.uncoupled)...]
514-
axpy!(1, StridedView(_kron(convert(Array, t[f₁, f₂]), F)), Aslice)
506+
add!(Aslice, StridedView(_kron(convert(Array, t[f₁, f₂]), F)))
515507
end
516508
return A
517509
end

test/bugfixes.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,11 @@
2222
@test w == v
2323
@test scalartype(w) == Float64
2424
end
25+
26+
# https://github.com/Jutho/TensorKit.jl/issues/178
27+
@testset "Issue #178" begin
28+
t = rand(U1Space(1 => 1) U1Space(1 => 1)')
29+
a = convert(Array, t)
30+
@test a == zeros(size(a))
31+
end
2532
end

test/spaces.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,14 @@ println("------------------------------------")
6666
@test length(sectors(V)) == 1
6767
@test @constinferred(TensorKit.hassector(V, Trivial()))
6868
@test @constinferred(dim(V)) == d == @constinferred(dim(V, Trivial()))
69-
@test dim(@constinferred(typeof(V)())) == 0
70-
@test (sectors(typeof(V)())...,) == ()
69+
@test dim(@constinferred(zero(V))) == 0
70+
@test (sectors(zero(V))...,) == ()
7171
@test @constinferred(TensorKit.axes(V)) == Base.OneTo(d)
7272
@test^d == ℝ[](d) == CartesianSpace(d) == typeof(V)(d)
7373
W = @constinferred^1
7474
@test @constinferred(oneunit(V)) == W == oneunit(typeof(V))
75+
@test @constinferred(zero(V)) ==^0 == zero(typeof(V))
76+
@test @constinferred((V, zero(V))) == V
7577
@test @constinferred((V, V)) ==^(2d)
7678
@test @constinferred((V, oneunit(V))) ==^(d + 1)
7779
@test @constinferred((V, V, V, V)) ==^(4d)
@@ -111,12 +113,14 @@ println("------------------------------------")
111113
@test length(sectors(V)) == 1
112114
@test @constinferred(TensorKit.hassector(V, Trivial()))
113115
@test @constinferred(dim(V)) == d == @constinferred(dim(V, Trivial()))
114-
@test dim(@constinferred(typeof(V)())) == 0
115-
@test (sectors(typeof(V)())...,) == ()
116+
@test dim(@constinferred(zero(V))) == 0
117+
@test (sectors(zero(V))...,) == ()
116118
@test @constinferred(TensorKit.axes(V)) == Base.OneTo(d)
117119
@test^d == Vect[Trivial](d) == Vect[](Trivial() => d) == ℂ[](d) == typeof(V)(d)
118120
W = @constinferred^1
119121
@test @constinferred(oneunit(V)) == W == oneunit(typeof(V))
122+
@test @constinferred(zero(V)) ==^0 == zero(typeof(V))
123+
@test @constinferred((V, zero(V))) == V
120124
@test @constinferred((V, V)) ==^(2d)
121125
@test_throws SpaceMismatch ((V, V'))
122126
# promote_except = ErrorException("promotion of types $(typeof(ℝ^d)) and " *
@@ -200,11 +204,12 @@ println("------------------------------------")
200204
@test eval(Meta.parse(sprint(show, V))) == V
201205
@test eval(Meta.parse(sprint(show, typeof(V)))) == typeof(V)
202206
# space with no sectors
203-
@test dim(@constinferred(typeof(V)())) == 0
207+
@test dim(@constinferred(zero(V))) == 0
204208
# space with a single sector
205209
W = @constinferred GradedSpace(one(I) => 1)
206210
@test W == GradedSpace(one(I) => 1, randsector(I) => 0)
207211
@test @constinferred(oneunit(V)) == W == oneunit(typeof(V))
212+
@test @constinferred(zero(V)) == GradedSpace(one(I) => 0)
208213
# randsector never returns trivial sector, so this cannot error
209214
@test_throws ArgumentError GradedSpace(one(I) => 1, randsector(I) => 0, one(I) => 3)
210215
@test eval(Meta.parse(sprint(show, W))) == W
@@ -226,6 +231,7 @@ println("------------------------------------")
226231
if hasfusiontensor(I)
227232
@test @constinferred(TensorKit.axes(V)) == Base.OneTo(dim(V))
228233
end
234+
@test @constinferred((V, zero(V))) == V
229235
@test @constinferred((V, V)) == Vect[I](c => 2dim(V, c) for c in sectors(V))
230236
@test @constinferred((V, V, V, V)) == Vect[I](c => 4dim(V, c) for c in sectors(V))
231237
@test @constinferred((V, oneunit(V))) ==

test/tensors.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,11 @@ for V in spacelist
126126
@test t === @constinferred TensorMap(t.data, W)
127127
end
128128
end
129+
for T in (Int, Float32, ComplexF64)
130+
t = randn(T, V1 V2 zero(V1))
131+
a = convert(Array, t)
132+
@test norm(a) == 0
133+
end
129134
end
130135
end
131136
@timedtestset "Basic linear algebra" begin
@@ -466,7 +471,7 @@ for V in spacelist
466471
end
467472
end
468473
@testset "empty tensor" begin
469-
t = randn(T, V1 V2, typeof(V1)())
474+
t = randn(T, V1 V2, zero(V1))
470475
@testset "leftorth with $alg" for alg in
471476
(TensorKit.QR(), TensorKit.QRpos(),
472477
TensorKit.QL(), TensorKit.QLpos(),

0 commit comments

Comments
 (0)