Skip to content

Commit 38f8798

Browse files
authored
rework TensorOperations implementation to use backend and allocator (#311)
* rework TensorOperations to use backend and allocator * rework to twist smallest object * fix logic mistake
1 parent cc5e33b commit 38f8798

File tree

1 file changed

+107
-65
lines changed

1 file changed

+107
-65
lines changed

src/tensors/tensoroperations.jl

Lines changed: 107 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -253,100 +253,142 @@ the indices of `A` and `B` according to `(oindA, cindA)` and `(cindB, oindB)` re
253253
"""
254254
function contract!(
255255
C::AbstractTensorMap,
256-
A::AbstractTensorMap, (oindA, cindA)::Index2Tuple,
257-
B::AbstractTensorMap, (cindB, oindB)::Index2Tuple,
258-
(p₁, p₂)::Index2Tuple,
259-
α::Number, β::Number,
256+
A::AbstractTensorMap, pA::Index2Tuple,
257+
B::AbstractTensorMap, pB::Index2Tuple,
258+
pAB::Index2Tuple, α::Number, β::Number,
260259
backend, allocator
261260
)
262-
length(cindA) == length(cindB) ||
261+
length(pA[2]) == length(pB[1]) ||
263262
throw(IndexError("number of contracted indices does not match"))
264-
N₁, N₂ = length(oindA), length(oindB)
265-
266-
# find optimal contraction scheme
267-
hsp = has_shared_permute
268-
ipAB = TupleTools.invperm((p₁..., p₂...))
269-
oindAinC = TupleTools.getindices(ipAB, ntuple(n -> n, N₁))
270-
oindBinC = TupleTools.getindices(ipAB, ntuple(n -> n + N₁, N₂))
263+
N₁, N₂ = length(pA[1]), length(pB[2])
271264

272-
qA = TupleTools.sortperm(cindA)
273-
cindA′ = TupleTools.getindices(cindA, qA)
274-
cindB′ = TupleTools.getindices(cindB, qA)
265+
# find optimal contraction scheme by checking the following options:
266+
# - sorting the contracted inds of A or B to avoid permutations
267+
# - contracting B with A instead to avoid permutations
275268

276-
qB = TupleTools.sortperm(cindB)
277-
cindA′′ = TupleTools.getindices(cindA, qB)
278-
cindB′′ = TupleTools.getindices(cindB, qB)
269+
qA = TupleTools.sortperm(pA[2])
270+
pA′ = Base.setindex(pA, TupleTools.getindices(pA[2], qA), 2)
271+
pB′ = Base.setindex(pB, TupleTools.getindices(pB[1], qA), 1)
279272

280-
dA, dB, dC = dim(A), dim(B), dim(C)
273+
qB = TupleTools.sortperm(pB[1])
274+
pA″ = Base.setindex(pA, TupleTools.getindices(pA[2], qB), 2)
275+
pB″ = Base.setindex(pB, TupleTools.getindices(pB[1], qB), 1)
281276

282277
# keep order A and B, check possibilities for cind
283-
memcost1 = memcost2 = dC * (!hsp(C, (oindAinC, oindBinC)))
284-
memcost1 += dA * (!hsp(A, (oindA, cindA′))) + dB * (!hsp(B, (cindB′, oindB)))
285-
memcost2 += dA * (!hsp(A, (oindA, cindA′′))) + dB * (!hsp(B, (cindB′′, oindB)))
278+
memcost1 = TO.contract_memcost(C, A, pA′, B, pB′, pAB)
279+
memcost2 = TO.contract_memcost(C, A, pA″, B, pB″, pAB)
286280

287281
# reverse order A and B, check possibilities for cind
288-
memcost3 = memcost4 = dC * (!hsp(C, (oindBinC, oindAinC)))
289-
memcost3 += dB * (!hsp(B, (oindB, cindB′))) + dA * (!hsp(A, (cindA′, oindA)))
290-
memcost4 += dB * (!hsp(B, (oindB, cindB′′))) + dA * (!hsp(A, (cindA′′, oindA)))
282+
pAB′ = (
283+
map(n -> ifelse(n > N₁, n - N₁, n + N₂), pAB[1]),
284+
map(n -> ifelse(n > N₁, n - N₁, n + N₂), pAB[2]),
285+
)
286+
memcost3 = TO.contract_memcost(C, B, reverse(pB′), A, reverse(pA′), pAB′)
287+
memcost4 = TO.contract_memcost(C, B, reverse(pB″), A, reverse(pA″), pAB′)
291288

292289
return if min(memcost1, memcost2) <= min(memcost3, memcost4)
293290
if memcost1 <= memcost2
294-
return _contract!(α, A, B, β, C, oindA, cindA′, oindB, cindB′, p₁, p₂)
291+
return blas_contract!(C, A, pA′, B, pB′, pAB, α, β, backend, allocator)
295292
else
296-
return _contract!(α, A, B, β, C, oindA, cindA′′, oindB, cindB′′, p₁, p₂)
293+
return blas_contract!(C, A, pA″, B, pB″, pAB, α, β, backend, allocator)
297294
end
298295
else
299-
p1′ = map(n -> ifelse(n > N₁, n - N₁, n + N₂), p₁)
300-
p2′ = map(n -> ifelse(n > N₁, n - N₁, n + N₂), p₂)
301296
if memcost3 <= memcost4
302-
return _contract!(α, B, A, β, C, oindB, cindB′, oindA, cindA′, p1′, p2′)
297+
return blas_contract!(C, B, reverse(pB′), A, reverse(pA′), pAB′, α, β, backend, allocator)
303298
else
304-
return _contract!(α, B, A, β, C, oindB, cindB′′, oindA, cindA′′, p1′, p2′)
299+
return blas_contract!(C, B, reverse(pB″), A, reverse(pA″), pAB′, α, β, backend, allocator)
305300
end
306301
end
307302
end
308303

309-
# TODO: also transform _contract! into new interface, and add backend support
310-
function _contract!(
311-
α, A::AbstractTensorMap, B::AbstractTensorMap,
312-
β, C::AbstractTensorMap,
313-
oindA::IndexTuple, cindA::IndexTuple,
314-
oindB::IndexTuple, cindB::IndexTuple,
315-
p₁::IndexTuple, p₂::IndexTuple
304+
function TO.contract_memcost(
305+
C::AbstractTensorMap,
306+
A::AbstractTensorMap, pA::Index2Tuple,
307+
B::AbstractTensorMap, pB::Index2Tuple,
308+
pAB::Index2Tuple
309+
)
310+
ipAB = TO.oindABinC(pAB, pA, pB)
311+
return dim(A) * (!TO.isblascontractable(A, pA) || eltype(A) !== eltype(C)) +
312+
dim(B) * (!TO.isblascontractable(B, pB) || eltype(B) !== eltype(C)) +
313+
dim(C) * !TO.isblasdestination(C, ipAB)
314+
end
315+
316+
function TO.isblascontractable(A::AbstractTensorMap, pA::Index2Tuple)
317+
return eltype(A) <: LinearAlgebra.BlasFloat && has_shared_permute(A, pA)
318+
end
319+
function TO.isblasdestination(A::AbstractTensorMap, ipAB::Index2Tuple)
320+
return eltype(A) <: LinearAlgebra.BlasFloat && has_shared_permute(A, ipAB)
321+
end
322+
323+
function blas_contract!(
324+
C::AbstractTensorMap,
325+
A::AbstractTensorMap, pA::Index2Tuple,
326+
B::AbstractTensorMap, pB::Index2Tuple,
327+
pAB::Index2Tuple, α, β,
328+
backend, allocator
316329
)
317-
if !(BraidingStyle(sectortype(C)) isa SymmetricBraiding)
330+
bstyle = BraidingStyle(sectortype(C))
331+
bstyle isa SymmetricBraiding ||
318332
throw(SectorMismatch("only tensors with symmetric braiding rules can be contracted; try `@planar` instead"))
319-
end
320-
N₁, N₂ = length(oindA), length(oindB)
321-
copyA = false
322-
if BraidingStyle(sectortype(A)) isa Fermionic
323-
for i in cindA
324-
if !isdual(space(A, i))
325-
copyA = true
326-
end
333+
TC = eltype(C)
334+
335+
# check which tensors have to be permuted/copied
336+
copyA = !(TO.isblascontractable(A, pA) && eltype(A) === TC)
337+
copyB = !(TO.isblascontractable(B, pB) && eltype(B) === TC)
338+
339+
if bstyle isa Fermionic && any(isdual ∘ Base.Fix1(space, B), pB[1])
340+
# twist smallest object if neither or both already have to be permuted
341+
# otherwise twist the one that already is copied
342+
if !(copyA ⊻ copyB)
343+
twistA = dim(A) < dim(B)
344+
else
345+
twistA = copyA
327346
end
347+
twistB = !twistA
348+
copyA |= twistA
349+
copyB |= twistB
350+
else
351+
twistA = false
352+
twistB = false
328353
end
329-
A′ = permute(A, (oindA, cindA); copy = copyA)
330-
B′ = permute(B, (cindB, oindB))
331-
if BraidingStyle(sectortype(A)) isa Fermionic
332-
for i in domainind(A′)
333-
if !isdual(space(A′, i))
334-
A′ = twist!(A′, i)
335-
end
336-
end
337-
# A′ = twist!(A′, filter(i -> !isdual(space(A′, i)), domainind(A′)))
338-
# commented version leads to boxing of `A′` and type instabilities in the result
354+
355+
# Bring A in the correct form for BLAS contraction
356+
if copyA
357+
Anew = TO.tensoralloc_add(TC, A, pA, false, Val(true), allocator)
358+
Anew = TO.tensoradd!(Anew, A, pA, false, One(), Zero(), backend, allocator)
359+
twistA && twist!(Anew, filter(!isdual ∘ Base.Fix1(space, Anew), domainind(Anew)))
360+
else
361+
Anew = permute(A, pA)
339362
end
340-
ipAB = TupleTools.invperm((p₁..., p₂...))
341-
oindAinC = TupleTools.getindices(ipAB, ntuple(n -> n, N₁))
342-
oindBinC = TupleTools.getindices(ipAB, ntuple(n -> n + N₁, N₂))
343-
if has_shared_permute(C, (oindAinC, oindBinC))
344-
C′ = permute(C, (oindAinC, oindBinC))
345-
mul!(C′, A′, B′, α, β)
363+
pAnew = (codomainind(Anew), domainind(Anew))
364+
365+
# Bring B in the correct form for BLAS contraction
366+
if copyB
367+
Bnew = TO.tensoralloc_add(TC, B, pB, false, Val(true), allocator)
368+
Bnew = TO.tensoradd!(Bnew, B, pB, false, One(), Zero(), backend, allocator)
369+
twistB && twist!(Bnew, filter(isdual ∘ Base.Fix1(space, Bnew), codomainind(Bnew)))
346370
else
347-
C′ = A′ * B′
348-
add_permute!(C, C′, (p₁, p₂), α, β)
371+
Bnew = permute(B, pB)
349372
end
373+
pBnew = (codomainind(Bnew), domainind(Bnew))
374+
375+
# Bring C in the correct form for BLAS contraction
376+
ipAB = TO.oindABinC(pAB, pAnew, pBnew)
377+
copyC = !TO.isblasdestination(C, ipAB)
378+
379+
if copyC
380+
Cnew = TO.tensoralloc_add(TC, C, ipAB, false, Val(true), allocator)
381+
mul!(Cnew, Anew, Bnew)
382+
TO.tensoradd!(C, Cnew, pAB, false, α, β, backend, allocator)
383+
TO.tensorfree!(Cnew, allocator)
384+
else
385+
Cnew = permute(C, ipAB)
386+
mul!(Cnew, Anew, Bnew, α, β)
387+
end
388+
389+
copyA && TO.tensorfree!(Anew, allocator)
390+
copyB && TO.tensorfree!(Bnew, allocator)
391+
350392
return C
351393
end
352394

0 commit comments

Comments
 (0)