Skip to content

Commit 3a06898

Browse files
authored
Error message improvements (#309)
1 parent 3d65453 commit 3a06898

File tree

6 files changed

+160
-57
lines changed

6 files changed

+160
-57
lines changed

docs/src/lib/spaces.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ the resuling `HomSpace` after applying certain tensor operations.
123123

124124
```@docs
125125
flip(W::HomSpace{S}, I) where {S}
126-
TensorKit.permute(::HomSpace{S}, ::Index2Tuple{N₁,N₂}) where {S,N₁,N₂}
126+
TensorKit.permute(::HomSpace, ::Index2Tuple)
127127
TensorKit.select(::HomSpace{S}, ::Index2Tuple{N₁,N₂}) where {S,N₁,N₂}
128128
TensorKit.compose(::HomSpace{S}, ::HomSpace{S}) where {S}
129129
insertleftunit(::HomSpace, ::Val{i}) where {i}

src/TensorKit.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,11 @@ struct SpaceMismatch{S <: Union{Nothing, AbstractString}} <: TensorException
178178
message::S
179179
end
180180
SpaceMismatch() = SpaceMismatch{Nothing}(nothing)
181-
Base.showerror(io::IO, ::SpaceMismatch{Nothing}) = print(io, "SpaceMismatch()")
182-
Base.showerror(io::IO, e::SpaceMismatch) = print(io, "SpaceMismatch(\"", e.message, "\")")
181+
function Base.showerror(io::IO, err::SpaceMismatch)
182+
print(io, "SpaceMismatch: ")
183+
isnothing(err.message) || print(io, err.message)
184+
return nothing
185+
end
183186

184187
# Exception type for all errors related to invalid tensor index specification.
185188
struct IndexError{S <: Union{Nothing, AbstractString}} <: TensorException

src/spaces/homspace.jl

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,15 @@ end
3939

4040
spacetype(::Type{<:HomSpace{S}}) where {S} = S
4141

42-
numout(W::HomSpace) = length(codomain(W))
43-
numin(W::HomSpace) = length(domain(W))
44-
numind(W::HomSpace) = numin(W) + numout(W)
45-
4642
const TensorSpace{S <: ElementarySpace} = Union{S, ProductSpace{S}}
4743
const TensorMapSpace{S <: ElementarySpace, N₁, N₂} = HomSpace{
4844
S, ProductSpace{S, N₁},
4945
ProductSpace{S, N₂},
5046
}
5147

48+
numout(::Type{TensorMapSpace{S, N₁, N₂}}) where {S, N₁, N₂} = N₁
49+
numin(::Type{TensorMapSpace{S, N₁, N₂}}) where {S, N₁, N₂} = N₂
50+
5251
function Base.getindex(W::TensorMapSpace{<:IndexSpace, N₁, N₂}, i) where {N₁, N₂}
5352
return i <= N₁ ? codomain(W)[i] : dual(domain(W)[i - N₁])
5453
end
@@ -137,18 +136,33 @@ fusiontrees(W::HomSpace) = fusionblockstructure(W).fusiontreelist
137136
# Operations on HomSpaces
138137
# -----------------------
139138
"""
140-
permute(W::HomSpace, (p₁, p₂)::Index2Tuple{N₁,N₂})
139+
permute(W::HomSpace, (p₁, p₂)::Index2Tuple)
141140
142141
Return the `HomSpace` obtained by permuting the indices of the domain and codomain of `W`
143142
according to the permutation `p₁` and `p₂` respectively.
144143
"""
145-
function permute(W::HomSpace{S}, (p₁, p₂)::Index2Tuple{N₁, N₂}) where {S, N₁, N₂}
144+
function permute(W::HomSpace, (p₁, p₂)::Index2Tuple)
146145
p = (p₁..., p₂...)
147146
TupleTools.isperm(p) && length(p) == numind(W) ||
148147
throw(ArgumentError("$((p₁, p₂)) is not a valid permutation for $(W)"))
149148
return select(W, (p₁, p₂))
150149
end
151150

151+
_transpose_indices(W::HomSpace) = (reverse(domainind(W)), reverse(codomainind(W)))
152+
153+
function LinearAlgebra.transpose(W::HomSpace, (p₁, p₂)::Index2Tuple = _transpose_indices(W))
154+
p = linearizepermutation(p₁, p₂, numout(W), numin(W))
155+
iscyclicpermutation(p) || throw(ArgumentError(lazy"$((p₁, p₂)) is not a cyclic permutation for $W"))
156+
return select(W, (p₁, p₂))
157+
end
158+
159+
function braid(W::HomSpace, (p₁, p₂)::Index2Tuple, levels::IndexTuple)
160+
p = (p₁..., p₂...)
161+
TupleTools.isperm(p) && length(p) == numind(W) == length(levels) ||
162+
throw(ArgumentError("$((p₁, p₂)), $levels is not a valid braiding for $(W)"))
163+
return select(W, (p₁, p₂))
164+
end
165+
152166
"""
153167
select(W::HomSpace, (p₁, p₂)::Index2Tuple{N₁,N₂})
154168
@@ -188,6 +202,30 @@ function compose(W::HomSpace{S}, V::HomSpace{S}) where {S}
188202
return HomSpace(codomain(W), domain(V))
189203
end
190204

205+
function TensorOperations.tensorcontract(
206+
A::HomSpace, pA::Index2Tuple, conjA::Bool,
207+
B::HomSpace, pB::Index2Tuple, conjB::Bool,
208+
pAB::Index2Tuple
209+
)
210+
return if conjA && conjB
211+
A′ = A'
212+
pA′ = adjointtensorindices(A, pA)
213+
B′ = B'
214+
pB′ = adjointtensorindices(B, pB)
215+
TensorOperations.tensorcontract(A′, pA′, false, B′, pB′, false, pAB)
216+
elseif conjA
217+
A′ = A'
218+
pA′ = adjointtensorindices(A, pA)
219+
TensorOperations.tensorcontract(A′, pA′, false, B, pB, false, pAB)
220+
elseif conjB
221+
B′ = B'
222+
pB′ = adjointtensorindices(B, pB)
223+
TensorOperations.tensorcontract(A, pA, false, B′, pB′, false, pAB)
224+
else
225+
return permute(compose(permute(A, pA), permute(B, pB)), pAB)
226+
end
227+
end
228+
191229
"""
192230
insertleftunit(W::HomSpace, i=numind(W) + 1; conj=false, dual=false)
193231

src/tensors/abstracttensor.jl

Lines changed: 37 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -89,80 +89,84 @@ domain(t::AbstractTensorMap) = domain(space(t))
8989
domain(t::AbstractTensorMap, i) = domain(t)[i]
9090
source(t::AbstractTensorMap) = domain(t) # categorical terminology
9191

92-
"""
93-
numout(::Union{TT,Type{TT}}) where {TT<:AbstractTensorMap} -> Int
92+
@doc """
93+
numout(x) -> Int
94+
numout(T::Type) -> Int
9495
95-
Return the number of output spaces of a tensor. This is equivalent to the number of spaces in the codomain of that tensor.
96+
Return the length of the codomain, i.e. the number of output spaces.
97+
By default, this is implemented in the type domain.
9698
9799
See also [`numin`](@ref) and [`numind`](@ref).
98-
"""
100+
""" numout
101+
102+
numout(x) = numout(typeof(x))
103+
numout(T::Type) = throw(MethodError(numout, T)) # avoid infinite recursion
99104
numout(::Type{<:AbstractTensorMap{T, S, N₁}}) where {T, S, N₁} = N₁
100105

101-
"""
102-
numin(::Union{TT,Type{TT}}) where {TT<:AbstractTensorMap} -> Int
106+
@doc """
107+
numin(x) -> Int
108+
numin(T::Type) -> Int
103109
104-
Return the number of input spaces of a tensor. This is equivalent to the number of spaces in the domain of that tensor.
110+
Return the length of the domain, i.e. the number of input spaces.
111+
By default, this is implemented in the type domain.
105112
106113
See also [`numout`](@ref) and [`numind`](@ref).
107-
"""
114+
""" numin
115+
116+
numin(x) = numin(typeof(x))
117+
numin(T::Type) = throw(MethodError(numin, T)) # avoid infinite recursion
108118
numin(::Type{<:AbstractTensorMap{T, S, N₁, N₂}}) where {T, S, N₁, N₂} = N₂
109119

110120
"""
111-
numind(::Union{T,Type{T}}) where {T<:AbstractTensorMap} -> Int
121+
numind(x) -> Int
122+
numind(T::Type) -> Int
123+
order(x) = numind(x)
112124
113-
Return the total number of input and output spaces of a tensor. This is equivalent to the
114-
total number of spaces in the domain and codomain of that tensor.
125+
Return the total number of input and output spaces, i.e. `numin(x) + numout(x)`.
126+
Alternatively, the alias `order` can also be used.
115127
116128
See also [`numout`](@ref) and [`numin`](@ref).
117129
"""
118-
numind(::Type{TT}) where {TT <: AbstractTensorMap} = numin(TT) + numout(TT)
130+
numind(x) = numin(x) + numout(x)
131+
119132
const order = numind
120133

121134
"""
122-
codomainind(::Union{TT,Type{TT}}) where {TT<:AbstractTensorMap} -> Tuple{Int}
135+
codomainind(x) -> Tuple{Int}
123136
124-
Return all indices of the codomain of a tensor.
137+
Return all indices of the codomain.
125138
126139
See also [`domainind`](@ref) and [`allind`](@ref).
127140
"""
128-
function codomainind(::Type{TT}) where {TT <: AbstractTensorMap}
129-
return ntuple(identity, numout(TT))
130-
end
131-
codomainind(t::AbstractTensorMap) = codomainind(typeof(t))
141+
codomainind(x) = ntuple(identity, numout(x))
132142

133143
"""
134-
domainind(::Union{TT,Type{TT}}) where {TT<:AbstractTensorMap} -> Tuple{Int}
144+
domainind(x) -> Tuple{Int}
135145
136-
Return all indices of the domain of a tensor.
146+
Return all indices of the domain.
137147
138148
See also [`codomainind`](@ref) and [`allind`](@ref).
139149
"""
140-
function domainind(::Type{TT}) where {TT <: AbstractTensorMap}
141-
return ntuple(n -> numout(TT) + n, numin(TT))
142-
end
143-
domainind(t::AbstractTensorMap) = domainind(typeof(t))
150+
domainind(x) = ntuple(n -> numout(x) + n, numin(x))
144151

145152
"""
146-
allind(::Union{TT,Type{TT}}) where {TT<:AbstractTensorMap} -> Tuple{Int}
153+
allind(x) -> Tuple{Int}
147154
148-
Return all indices of a tensor, i.e. the indices of its domain and codomain.
155+
Return all indices, i.e. the indices of both domain and codomain.
149156
150157
See also [`codomainind`](@ref) and [`domainind`](@ref).
151158
"""
152-
function allind(::Type{TT}) where {TT <: AbstractTensorMap}
153-
return ntuple(identity, numind(TT))
154-
end
155-
allind(t::AbstractTensorMap) = allind(typeof(t))
159+
allind(x) = ntuple(identity, numind(x))
156160

157-
function adjointtensorindex(t::AbstractTensorMap, i)
161+
function adjointtensorindex(t, i)
158162
return ifelse(i <= numout(t), numin(t) + i, i - numout(t))
159163
end
160164

161-
function adjointtensorindices(t::AbstractTensorMap, indices::IndexTuple)
165+
function adjointtensorindices(t, indices::IndexTuple)
162166
return map(i -> adjointtensorindex(t, i), indices)
163167
end
164168

165-
function adjointtensorindices(t::AbstractTensorMap, p::Index2Tuple)
169+
function adjointtensorindices(t, p::Index2Tuple)
166170
return (adjointtensorindices(t, p[1]), adjointtensorindices(t, p[2]))
167171
end
168172

src/tensors/indexmanipulations.jl

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,8 @@ the permutation `(p₁..., reverse(p₂)...)` should constitute a cyclic permuta
175175
176176
See [`transpose`](@ref) for creating a new tensor and [`add_transpose!`](@ref) for a more general version.
177177
"""
178-
function LinearAlgebra.transpose!(
179-
tdst::AbstractTensorMap, tsrc::AbstractTensorMap, (p₁, p₂)::Index2Tuple = _transpose_indices(t)
178+
@propagate_inbounds function LinearAlgebra.transpose!(
179+
tdst::AbstractTensorMap, tsrc::AbstractTensorMap, (p₁, p₂)::Index2Tuple = _transpose_indices(tsrc)
180180
)
181181
return add_transpose!(tdst, tsrc, (p₁, p₂), One(), Zero())
182182
end
@@ -229,7 +229,7 @@ case of a transposition that only changes the number of in- and outgoing indices
229229
230230
See [`repartition`](@ref) for creating a new tensor.
231231
"""
232-
function repartition!(tdst::AbstractTensorMap{S}, tsrc::AbstractTensorMap{S}) where {S}
232+
@propagate_inbounds function repartition!(tdst::AbstractTensorMap{S}, tsrc::AbstractTensorMap{S}) where {S}
233233
numind(tsrc) == numind(tdst) ||
234234
throw(ArgumentError("tsrc and tdst should have an equal amount of indices"))
235235
all_inds = (codomainind(tsrc)..., reverse(domainind(tsrc))...)
@@ -410,6 +410,38 @@ end
410410
#-------------------------------------
411411
# Full implementations based on `add`
412412
#-------------------------------------
413+
spacecheck_transform(f, tdst::AbstractTensorMap, tsrc::AbstractTensorMap, args...) =
414+
spacecheck_transform(f, space(tdst), space(tsrc), args...)
415+
@noinline function spacecheck_transform(f, Vdst::TensorMapSpace, Vsrc::TensorMapSpace, p::Index2Tuple)
416+
spacetype(Vdst) == spacetype(Vsrc) || throw(SectorMismatch("incompatible sector types"))
417+
f(Vsrc, p) == Vdst ||
418+
throw(
419+
SpaceMismatch(
420+
lazy"""
421+
incompatible spaces for `$f(Vsrc, $p) -> Vdst`
422+
Vsrc = $Vsrc
423+
Vdst = $Vdst
424+
"""
425+
)
426+
)
427+
return nothing
428+
end
429+
@noinline function spacecheck_transform(::typeof(braid), Vdst::TensorMapSpace, Vsrc::TensorMapSpace, p::Index2Tuple, levels::IndexTuple)
430+
spacetype(Vdst) == spacetype(Vsrc) || throw(SectorMismatch("incompatible sector types"))
431+
braid(Vsrc, p, levels) == Vdst ||
432+
throw(
433+
SpaceMismatch(
434+
lazy"""
435+
incompatible spaces for `braid(Vsrc, $p, $levels) -> Vdst`
436+
Vsrc = $Vsrc
437+
Vdst = $Vdst
438+
"""
439+
)
440+
)
441+
return nothing
442+
end
443+
444+
413445
"""
414446
add_permute!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap, (p₁, p₂)::Index2Tuple,
415447
α::Number, β::Number, backend::AbstractBackend...)
@@ -423,8 +455,9 @@ See also [`permute`](@ref), [`permute!`](@ref), [`add_braid!`](@ref), [`add_tran
423455
tdst::AbstractTensorMap, tsrc::AbstractTensorMap, p::Index2Tuple,
424456
α::Number, β::Number, backend::AbstractBackend...
425457
)
458+
@boundscheck spacecheck_transform(permute, tdst, tsrc, p)
426459
transformer = treepermuter(tdst, tsrc, p)
427-
return add_transform!(tdst, tsrc, p, transformer, α, β, backend...)
460+
return @inbounds add_transform!(tdst, tsrc, p, transformer, α, β, backend...)
428461
end
429462

430463
"""
@@ -440,14 +473,12 @@ See also [`braid`](@ref), [`braid!`](@ref), [`add_permute!`](@ref), [`add_transp
440473
tdst::AbstractTensorMap, tsrc::AbstractTensorMap, p::Index2Tuple, levels::IndexTuple,
441474
α::Number, β::Number, backend::AbstractBackend...
442475
)
443-
length(levels) == numind(tsrc) ||
444-
throw(ArgumentError("incorrect levels $levels for tensor map $(codomain(tsrc))$(domain(tsrc))"))
445-
476+
@boundscheck spacecheck_transform(braid, tdst, tsrc, p, levels)
446477
levels1 = TupleTools.getindices(levels, codomainind(tsrc))
447478
levels2 = TupleTools.getindices(levels, domainind(tsrc))
448479
# TODO: arg order for tensormaps is different than for fusiontrees
449480
transformer = treebraider(tdst, tsrc, p, (levels1, levels2))
450-
return add_transform!(tdst, tsrc, p, transformer, α, β, backend...)
481+
return @inbounds add_transform!(tdst, tsrc, p, transformer, α, β, backend...)
451482
end
452483

453484
"""
@@ -463,19 +494,16 @@ See also [`transpose`](@ref), [`transpose!`](@ref), [`add_permute!`](@ref), [`ad
463494
tdst::AbstractTensorMap, tsrc::AbstractTensorMap, p::Index2Tuple,
464495
α::Number, β::Number, backend::AbstractBackend...
465496
)
497+
@boundscheck spacecheck_transform(transpose, tdst, tsrc, p)
466498
transformer = treetransposer(tdst, tsrc, p)
467-
return add_transform!(tdst, tsrc, p, transformer, α, β, backend...)
499+
return @inbounds add_transform!(tdst, tsrc, p, transformer, α, β, backend...)
468500
end
469501

470-
function add_transform!(
502+
@propagate_inbounds function add_transform!(
471503
tdst::AbstractTensorMap, tsrc::AbstractTensorMap, p::Index2Tuple, transformer,
472504
α::Number, β::Number, backend::AbstractBackend...
473505
)
474-
@boundscheck begin
475-
permute(space(tsrc), p) == space(tdst) ||
476-
throw(SpaceMismatch("source = $(codomain(tsrc))$(domain(tsrc)),
477-
dest = $(codomain(tdst))$(domain(tdst)), p₁ = $(p[1]), p₂ = $(p[2])"))
478-
end
506+
@boundscheck spacecheck_transform(permute, tdst, tsrc, p)
479507

480508
if p[1] === codomainind(tsrc) && p[2] === domainind(tsrc)
481509
add!(tdst, tsrc, α, β)

src/tensors/tensoroperations.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,35 @@ function TO.tensortrace!(
9090
end
9191

9292
# tensorcontract!
93+
function spacecheck_contract(
94+
C::AbstractTensorMap,
95+
A::AbstractTensorMap, pA::Index2Tuple, conjA::Bool,
96+
B::AbstractTensorMap, pB::Index2Tuple, conjB::Bool,
97+
pAB::Index2Tuple
98+
)
99+
return spacecheck_contract(space(C), space(A), pA, conjA, space(B), pB, conjB, pAB)
100+
end
101+
@noinline function spacecheck_contract(
102+
VC::TensorMapSpace,
103+
VA::TensorMapSpace, pA::Index2Tuple, conjA::Bool,
104+
VB::TensorMapSpace, pB::Index2Tuple, conjB::Bool,
105+
pAB::Index2Tuple
106+
)
107+
spacetype(VC) == spacetype(VA) == spacetype(VB) || throw(SectorMismatch("incompatible sector types"))
108+
TO.tensorcontract(VA, pA, conjA, VB, pB, conjB, pAB) == VC ||
109+
throw(
110+
SpaceMismatch(
111+
lazy"""
112+
incompatible spaces for `tensorcontract(VA, $pA, $conjA, VB, $pB, $conjB, $pAB) -> VC`
113+
VA = $VA
114+
VB = $VB
115+
VC = $VC
116+
"""
117+
)
118+
)
119+
return nothing
120+
end
121+
93122
function TO.tensorcontract!(
94123
C::AbstractTensorMap,
95124
A::AbstractTensorMap, pA::Index2Tuple, conjA::Bool,
@@ -98,6 +127,7 @@ function TO.tensorcontract!(
98127
backend, allocator
99128
)
100129
pAB′ = _canonicalize(pAB, C)
130+
@boundscheck spacecheck_contract(C, A, pA, conjA, B, pB, conjB, pAB′)
101131
if conjA && conjB
102132
A′ = A'
103133
pA′ = adjointtensorindices(A, pA)

0 commit comments

Comments
 (0)