Skip to content

Commit 4632dcd

Browse files
committed
streamline tensorcontraction errors
1 parent 402e945 commit 4632dcd

File tree

3 files changed

+57
-3
lines changed

3 files changed

+57
-3
lines changed

src/spaces/homspace.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,30 @@ function compose(W::HomSpace{S}, V::HomSpace{S}) where {S}
208208
return HomSpace(codomain(W), domain(V))
209209
end
210210

211+
function TensorOperations.tensorcontract(
212+
A::HomSpace, pA::Index2Tuple, conjA::Bool,
213+
B::HomSpace, pB::Index2Tuple, conjB::Bool,
214+
pAB::Index2Tuple
215+
)
216+
return if conjA && conjB
217+
A′ = A'
218+
pA′ = adjointtensorindices(A, pA)
219+
B′ = B'
220+
pB′ = adjointtensorindices(B, pB)
221+
TensorOperations.tensorcontract(A′, pA′, false, B′, pB′, false, pAB)
222+
elseif conjA
223+
A′ = A'
224+
pA′ = adjointtensorindices(A, pA)
225+
TensorOperations.tensorcontract(A′, pA′, false, B, pB, false, pAB)
226+
elseif conjB
227+
B′ = B'
228+
pB′ = adjointtensorindices(B, pB)
229+
TensorOperations.tensorcontract(A, pA, false, B′, pB′, false, pAB)
230+
else
231+
return permute(compose(permute(A, pA), permute(B, pB)), pAB)
232+
end
233+
end
234+
211235
"""
212236
insertleftunit(W::HomSpace, i=numind(W) + 1; conj=false, dual=false)
213237

src/tensors/abstracttensor.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -154,15 +154,15 @@ function allind(::Type{TT}) where {TT <: Union{HomSpace, AbstractTensorMap}}
154154
end
155155
allind(t::Union{HomSpace, AbstractTensorMap}) = allind(typeof(t))
156156

157-
function adjointtensorindex(t::AbstractTensorMap, i)
157+
function adjointtensorindex(t, i)
158158
return ifelse(i <= numout(t), numin(t) + i, i - numout(t))
159159
end
160160

161-
function adjointtensorindices(t::AbstractTensorMap, indices::IndexTuple)
161+
function adjointtensorindices(t, indices::IndexTuple)
162162
return map(i -> adjointtensorindex(t, i), indices)
163163
end
164164

165-
function adjointtensorindices(t::AbstractTensorMap, p::Index2Tuple)
165+
function adjointtensorindices(t, p::Index2Tuple)
166166
return (adjointtensorindices(t, p[1]), adjointtensorindices(t, p[2]))
167167
end
168168

src/tensors/tensoroperations.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,35 @@ function TO.tensortrace!(
9090
end
9191

9292
# tensorcontract!
93+
function spacecheck_contract(
94+
C::AbstractTensorMap,
95+
A::AbstractTensorMap, pA::Index2Tuple, conjA::Bool,
96+
B::AbstractTensorMap, pB::Index2Tuple, conjB::Bool,
97+
pAB::Index2Tuple
98+
)
99+
return spacecheck_contract(space(C), space(A), pA, conjA, space(B), pB, conjB, pAB)
100+
end
101+
@noinline function spacecheck_contract(
102+
VC::TensorMapSpace,
103+
VA::TensorMapSpace, pA::Index2Tuple, conjA::Bool,
104+
VB::TensorMapSpace, pB::Index2Tuple, conjB::Bool,
105+
pAB::Index2Tuple
106+
)
107+
spacetype(VC) == spacetype(VA) == spacetype(VB) || throw(SectorMismatch("incompatible sector types"))
108+
TO.tensorcontract(VA, pA, conjA, VB, pB, conjB, pAB) == VC ||
109+
throw(
110+
SpaceMismatch(
111+
lazy"""
112+
incompatible spaces for `tensorcontract(VA, $pA, $conjA, VB, $pB, $conjB, $pAB) -> VC`
113+
VA = $VA
114+
VB = $VB
115+
VC = $VC
116+
"""
117+
)
118+
)
119+
return nothing
120+
end
121+
93122
function TO.tensorcontract!(
94123
C::AbstractTensorMap,
95124
A::AbstractTensorMap, pA::Index2Tuple, conjA::Bool,
@@ -98,6 +127,7 @@ function TO.tensorcontract!(
98127
backend, allocator
99128
)
100129
pAB′ = _canonicalize(pAB, C)
130+
@boundscheck spacecheck_contract(C, A, pA, conjA, B, pB, conjB, pAB′)
101131
if conjA && conjB
102132
A′ = A'
103133
pA′ = adjointtensorindices(A, pA)

0 commit comments

Comments
 (0)