Skip to content

Commit aa1e6d1

Browse files
committed
Add backend/allocator support in TensorOperations
1 parent 2da9b7a commit aa1e6d1

File tree

3 files changed

+105
-101
lines changed

3 files changed

+105
-101
lines changed

docs/src/lib/tensors.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,6 @@ TensorKit.add_transpose!
200200
```@docs
201201
compose(::AbstractTensorMap, ::AbstractTensorMap)
202202
trace_permute!
203-
contract!
204203
⊗(::AbstractTensorMap, ::AbstractTensorMap)
205204
```
206205

src/planar/planaroperations.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,8 @@ function planarcontract!(C::AbstractTensorMap,
142142
α::Number, β::Number,
143143
backend, allocator)
144144
if BraidingStyle(sectortype(C)) == Bosonic()
145-
return contract!(C, A, pA, B, pB, pAB, α, β, backend, allocator)
145+
return TO.tensorcontract!(C, A, pA, false, B, pB, false, pAB,
146+
α, β, backend, allocator)
146147
end
147148

148149
codA, domA = codomainind(A), domainind(A)

src/tensors/tensoroperations.jl

Lines changed: 103 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -94,17 +94,17 @@ function TO.tensorcontract!(C::AbstractTensorMap,
9494
pA′ = adjointtensorindices(A, pA)
9595
B′ = B'
9696
pB′ = adjointtensorindices(B, pB)
97-
contract!(C, A′, pA′, B′, pB′, pAB′, α, β, backend, allocator)
97+
TO.blas_contract!(C, A′, pA′, B′, pB′, pAB′, α, β, backend, allocator)
9898
elseif conjA
9999
A′ = A'
100100
pA′ = adjointtensorindices(A, pA)
101-
contract!(C, A′, pA′, B, pB, pAB′, α, β, backend, allocator)
101+
TO.blas_contract!(C, A′, pA′, B, pB, pAB′, α, β, backend, allocator)
102102
elseif conjB
103103
B′ = B'
104104
pB′ = adjointtensorindices(B, pB)
105-
contract!(C, A, pA, B′, pB′, pAB′, α, β, backend, allocator)
105+
TO.blas_contract!(C, A, pA, B′, pB′, pAB′, α, β, backend, allocator)
106106
else
107-
contract!(C, A, pA, B, pB, pAB′, α, β, backend, allocator)
107+
TO.blas_contract!(C, A, pA, B, pB, pAB′, α, β, backend, allocator)
108108
end
109109
return C
110110
end
@@ -154,6 +154,19 @@ TO.tensorcost(t::AbstractTensorMap, i::Int) = dim(space(t, i))
154154
scheduler::S = SerialScheduler()
155155
end
156156

157+
function TO.select_backend(::typeof(TO.tensoradd!), C::AbstractTensorMap,
158+
A::AbstractTensorMap)
159+
return TensorKitBackend()
160+
end
161+
function TO.select_backend(::typeof(TO.tensortrace!), C::AbstractTensorMap,
162+
A::AbstractTensorMap)
163+
return TensorKitBackend()
164+
end
165+
function TO.select_backend(::typeof(TO.tensorcontract!), C::AbstractTensorMap,
166+
A::AbstractTensorMap, B::AbstractTensorMap)
167+
return TensorKitBackend()
168+
end
169+
157170
# Trace implementation
158171
#----------------------
159172
"""
@@ -232,114 +245,105 @@ end
232245
# TODO: contraction with either A or B a rank (1, 1) tensor does not require to
233246
# permute the fusion tree and should therefore be special cased. This will speed
234247
# up MPS algorithms
235-
"""
236-
contract!(C::AbstractTensorMap,
237-
A::AbstractTensorMap, (oindA, cindA)::Index2Tuple,
238-
B::AbstractTensorMap, (cindB, oindB)::Index2Tuple,
239-
(p₁, p₂)::Index2Tuple,
240-
α::Number, β::Number,
241-
backend, allocator)
242-
243-
Return the updated `C`, which is the result of adding `α * A * B` to `C` after permuting
244-
the indices of `A` and `B` according to `(oindA, cindA)` and `(cindB, oindB)` respectively.
245-
"""
246-
function contract!(C::AbstractTensorMap,
247-
A::AbstractTensorMap, (oindA, cindA)::Index2Tuple,
248-
B::AbstractTensorMap, (cindB, oindB)::Index2Tuple,
249-
(p₁, p₂)::Index2Tuple,
250-
α::Number, β::Number,
251-
backend, allocator)
252-
length(cindA) == length(cindB) ||
253-
throw(IndexError("number of contracted indices does not match"))
254-
N₁, N₂ = length(oindA), length(oindB)
255-
256-
# find optimal contraction scheme
257-
hsp = has_shared_permute
258-
ipAB = TupleTools.invperm((p₁..., p₂...))
259-
oindAinC = TupleTools.getindices(ipAB, ntuple(n -> n, N₁))
260-
oindBinC = TupleTools.getindices(ipAB, ntuple(n -> n + N₁, N₂))
261-
262-
qA = TupleTools.sortperm(cindA)
263-
cindA′ = TupleTools.getindices(cindA, qA)
264-
cindB′ = TupleTools.getindices(cindB, qA)
265-
266-
qB = TupleTools.sortperm(cindB)
267-
cindA′′ = TupleTools.getindices(cindA, qB)
268-
cindB′′ = TupleTools.getindices(cindB, qB)
269-
270-
dA, dB, dC = dim(A), dim(B), dim(C)
271-
272-
# keep order A en B, check possibilities for cind
273-
memcost1 = memcost2 = dC * (!hsp(C, (oindAinC, oindBinC)))
274-
memcost1 += dA * (!hsp(A, (oindA, cindA′))) +
275-
dB * (!hsp(B, (cindB′, oindB)))
276-
memcost2 += dA * (!hsp(A, (oindA, cindA′′))) +
277-
dB * (!hsp(B, (cindB′′, oindB)))
278-
279-
# reverse order A en B, check possibilities for cind
280-
memcost3 = memcost4 = dC * (!hsp(C, (oindBinC, oindAinC)))
281-
memcost3 += dB * (!hsp(B, (oindB, cindB′))) +
282-
dA * (!hsp(A, (cindA′, oindA)))
283-
memcost4 += dB * (!hsp(B, (oindB, cindB′′))) +
284-
dA * (!hsp(A, (cindA′′, oindA)))
285-
286-
if min(memcost1, memcost2) <= min(memcost3, memcost4)
287-
if memcost1 <= memcost2
288-
return _contract!(α, A, B, β, C, oindA, cindA′, oindB, cindB′, p₁, p₂)
248+
249+
# this is a copy of the TensorOperations implementation, adding two ways to
250+
# permute the contracted indices for 4 total possible implementations
251+
function TO.blas_contract!(C::AbstractTensorMap, A::AbstractTensorMap, pA,
252+
B::AbstractTensorMap, pB, pAB, α, β, backend, allocator)
253+
# index permutations for reverse contraction
254+
indCinoBA = let N₁ = TO.numout(pA), N₂ = TO.numin(pB)
255+
map(n -> ifelse(n > N₁, n - N₁, n + N₂), TO.linearize(pAB))
256+
end
257+
tpAB = TO.trivialpermutation(pAB)
258+
pBA = (TupleTools.getindices(indCinoBA, tpAB[1]),
259+
TupleTools.getindices(indCinoBA, tpAB[2]))
260+
261+
# permutations of contracted indices
262+
qA = TupleTools.sortperm(pA[2])
263+
pA′ = (pA[1], TupleTools.getindices(pA[2], qA))
264+
pB′ = (TupleTools.getindices(pB[1], qA), pB[2])
265+
266+
qB = TupleTools.sortperm(pB[1])
267+
pA″ = (pA[1], TupleTools.getindices(pA[2], qB))
268+
pB″ = (TupleTools.getindices(pB[1], qB), pB[2])
269+
270+
memcost1 = TO.contract_memcost(C, A, pA′, B, pB′, pAB)
271+
memcost2 = TO.contract_memcost(C, A, pA″, B, pB″, pAB)
272+
memcost3 = TO.contract_memcost(C, B, reverse(pB′), A, reverse(pA′), pBA)
273+
memcost4 = TO.contract_memcost(C, B, reverse(pB″), A, reverse(pA″), pBA)
274+
275+
return if min(memcost1, memcost2) min(memcost3, memcost4)
276+
if memcost1 memcost2
277+
_blas_contract!(C, A, pA′, B, pB′, pAB, α, β, backend, allocator)
289278
else
290-
return _contract!, A, B, β, C, oindA, cindA′′, oindB, cindB′′, p₁, p₂)
279+
_blas_contract!(C, A, pA″, B, pB″, pAB, α, β, backend, allocator)
291280
end
292281
else
293-
p1′ = map(n -> ifelse(n > N₁, n - N₁, n + N₂), p₁)
294-
p2′ = map(n -> ifelse(n > N₁, n - N₁, n + N₂), p₂)
295-
if memcost3 <= memcost4
296-
return _contract!(α, B, A, β, C, oindB, cindB′, oindA, cindA′, p1′, p2′)
282+
if memcost3 memcost4
283+
_blas_contract!(C, B, reverse(pB′), A, reverse(pA′), pBA, α, β, backend,
284+
allocator)
297285
else
298-
return _contract!(α, B, A, β, C, oindB, cindB′′, oindA, cindA′′, p1′, p2′)
286+
_blas_contract!(C, B, reverse(pB″), A, reverse(pA″), pBA, α, β, backend,
287+
allocator)
299288
end
300289
end
301290
end
302291

303-
# TODO: also transform _contract! into new interface, and add backend support
304-
function _contract!(α, A::AbstractTensorMap, B::AbstractTensorMap,
305-
β, C::AbstractTensorMap,
306-
oindA::IndexTuple, cindA::IndexTuple,
307-
oindB::IndexTuple, cindB::IndexTuple,
308-
p₁::IndexTuple, p₂::IndexTuple)
309-
if !(BraidingStyle(sectortype(C)) isa SymmetricBraiding)
310-
throw(SectorMismatch("only tensors with symmetric braiding rules can be contracted; try `@planar` instead"))
311-
end
312-
N₁, N₂ = length(oindA), length(oindB)
313-
copyA = false
314-
if BraidingStyle(sectortype(A)) isa Fermionic
315-
for i in cindA
316-
if !isdual(space(A, i))
317-
copyA = true
318-
end
319-
end
292+
function TO.contract_memcost(C::AbstractTensorMap, A::AbstractTensorMap, pA,
293+
B::AbstractTensorMap, pB, pAB)
294+
ipAB = TO.oindABinC(pAB, pA, pB)
295+
return dim(A) * !isblascontractable(A, pA) + dim(B) * !isblascontractable(B, pB) +
296+
dim(C) * !isblasdestination(C, ipAB)
297+
end
298+
299+
# TODO: delibarately not importing private TO functions from here on out. Should we?
300+
function _blas_contract!(C, A, pA, B, pB, pAB, α, β, backend, allocator)
301+
TC = eltype(C)
302+
303+
A_, pA, flagA = makeblascontractable(A, pA, TC, backend, allocator, true)
304+
B_, pB, flagB = makeblascontractable(B, pB, TC, backend, allocator, false)
305+
306+
ipAB = TO.oindABinC(pAB, pA, pB)
307+
flagC = isblasdestination(C, ipAB)
308+
if flagC
309+
mul!(C, A_, B_, α, β, backend)
310+
else
311+
C_ = TO.tensoralloc_add(TC, C, ipAB, false, Val(true), allocator)
312+
mul!(C_, A_, B_, One(), Zero(), backend)
313+
TO.tensoradd!(C, C_, pAB, false, α, β, backend, allocator)
314+
TO.tensorfree!(C_, allocator)
320315
end
321-
A′ = permute(A, (oindA, cindA); copy=copyA)
322-
B′ = permute(B, (cindB, oindB))
323-
if BraidingStyle(sectortype(A)) isa Fermionic
324-
for i in domainind(A′)
325-
if !isdual(space(A′, i))
326-
A′ = twist!(A′, i)
316+
flagA || TO.tensorfree!(A_, allocator)
317+
flagB || TO.tensorfree!(B_, allocator)
318+
return C
319+
end
320+
321+
isblascontractable(A, pA) = (pA[1] == codomainind(A) && pA[2] == domainind(A))
322+
function isblasdestination(A::AbstractTensorMap, p::Index2Tuple)
323+
return (p[1] == codomainind(A) && p[2] == domainind(A))
324+
end
325+
326+
@inline function makeblascontractable(A::AbstractTensorMap, pA, TC, backend, allocator,
327+
dotwist::Bool=false)
328+
flagA = (scalartype(A) === TC && isblascontractable(A, pA) && !dotwist)
329+
if !flagA
330+
A_ = TO.tensoralloc_add(TC, A, pA, false, Val(true), allocator)
331+
Anew = TO.tensoradd!(A_, A, pA, false, One(), Zero(), backend, allocator)
332+
if dotwist && (BraidingStyle(sectortype(A)) isa Fermionic)
333+
for i in domainind(Anew)
334+
if !isdual(space(Anew, i))
335+
twist!(Anew, i)
336+
end
327337
end
338+
# TODO: this seems type-unstable:
339+
# Anew = twist!(Anew, filter(i -> !isdual(space(Anew, i)), domainind(Anew)))
328340
end
329-
# A′ = twist!(A′, filter(i -> !isdual(space(A′, i)), domainind(A′)))
330-
# commented version leads to boxing of `A′` and type instabilities in the result
331-
end
332-
ipAB = TupleTools.invperm((p₁..., p₂...))
333-
oindAinC = TupleTools.getindices(ipAB, ntuple(n -> n, N₁))
334-
oindBinC = TupleTools.getindices(ipAB, ntuple(n -> n + N₁, N₂))
335-
if has_shared_permute(C, (oindAinC, oindBinC))
336-
C′ = permute(C, (oindAinC, oindBinC))
337-
mul!(C′, A′, B′, α, β)
341+
pAnew = TO.trivialpermutation(pA)
338342
else
339-
C′ = A* B′
340-
add_permute!(C, C′, (p₁, p₂), α, β)
343+
Anew = A
344+
pAnew = pA
341345
end
342-
return C
346+
return Anew, pAnew, flagA
343347
end
344348

345349
# Scalar implementation

0 commit comments

Comments
 (0)