@@ -253,100 +253,142 @@ the indices of `A` and `B` according to `(oindA, cindA)` and `(cindB, oindB)` re
253253"""
function contract!(
        C::AbstractTensorMap,
        A::AbstractTensorMap, pA::Index2Tuple,
        B::AbstractTensorMap, pB::Index2Tuple,
        pAB::Index2Tuple, α::Number, β::Number,
        backend, allocator
    )
    # `pA = (open, contracted)` and `pB = (contracted, open)`.
    oindA, cindA = pA
    cindB, oindB = pB
    if length(cindA) != length(cindB)
        throw(IndexError(" number of contracted indices does not match"))
    end
    N₁, N₂ = length(oindA), length(oindB)

    # The contracted indices may be traversed in any common order, so try the two
    # orders that sort them on the A side or on the B side; a sorted order may
    # let the corresponding tensor be used without an explicit permutation.
    byA = TupleTools.sortperm(cindA)
    pA_byA = (oindA, TupleTools.getindices(cindA, byA))
    pB_byA = (TupleTools.getindices(cindB, byA), oindB)

    byB = TupleTools.sortperm(cindB)
    pA_byB = (oindA, TupleTools.getindices(cindA, byB))
    pB_byB = (TupleTools.getindices(cindB, byB), oindB)

    # The same candidates with the roles of A and B exchanged (B is the left
    # factor). Output labels are renumbered accordingly: the N₁ open indices of A
    # move behind the N₂ open indices of B.
    relabel(n) = ifelse(n > N₁, n - N₁, n + N₂)
    pAB_swapped = (map(relabel, pAB[1]), map(relabel, pAB[2]))
    qA_byA = (pA_byA[2], pA_byA[1])
    qB_byA = (pB_byA[2], pB_byA[1])
    qA_byB = (pA_byB[2], pA_byB[1])
    qB_byB = (pB_byB[2], pB_byB[1])

    # Memory cost of every candidate: A·B first, then B·A.
    cost_AB_byA = TO.contract_memcost(C, A, pA_byA, B, pB_byA, pAB)
    cost_AB_byB = TO.contract_memcost(C, A, pA_byB, B, pB_byB, pAB)
    cost_BA_byA = TO.contract_memcost(C, B, qB_byA, A, qA_byA, pAB_swapped)
    cost_BA_byB = TO.contract_memcost(C, B, qB_byB, A, qA_byB, pAB_swapped)

    # Ties favour the original operand order, and within an order the A-sorted
    # candidate.
    if min(cost_AB_byA, cost_AB_byB) <= min(cost_BA_byA, cost_BA_byB)
        pA_sel, pB_sel = cost_AB_byA <= cost_AB_byB ? (pA_byA, pB_byA) : (pA_byB, pB_byB)
        return blas_contract!(C, A, pA_sel, B, pB_sel, pAB, α, β, backend, allocator)
    end
    qB_sel, qA_sel = cost_BA_byA <= cost_BA_byB ? (qB_byA, qA_byA) : (qB_byB, qA_byB)
    return blas_contract!(C, B, qB_sel, A, qA_sel, pAB_swapped, α, β, backend, allocator)
end
308303
309- # TODO : also transform _contract! into new interface, and add backend support
310- function _contract! (
311- α, A:: AbstractTensorMap , B:: AbstractTensorMap ,
312- β, C:: AbstractTensorMap ,
313- oindA:: IndexTuple , cindA:: IndexTuple ,
314- oindB:: IndexTuple , cindB:: IndexTuple ,
315- p₁:: IndexTuple , p₂:: IndexTuple
# Memory-cost estimate for contracting `A` (permuted by `pA`) with `B` (permuted by
# `pB`) into `C`. Each tensor contributes its `dim` when it cannot be used directly:
# `A` and `B` when they are not BLAS-contractable for the requested permutation or
# their `eltype` differs from that of `C`, and `C` when it is not a valid BLAS
# destination for the output index layout.
function TO.contract_memcost(
        C::AbstractTensorMap,
        A::AbstractTensorMap, pA::Index2Tuple,
        B::AbstractTensorMap, pB::Index2Tuple,
        pAB::Index2Tuple
    )
    ipAB = TO.oindABinC(pAB, pA, pB)
    T = eltype(C)
    copyA = !(TO.isblascontractable(A, pA) && eltype(A) === T)
    copyB = !(TO.isblascontractable(B, pB) && eltype(B) === T)
    copyC = !TO.isblasdestination(C, ipAB)
    return dim(A) * copyA + dim(B) * copyB + dim(C) * copyC
end
315+
# `true` when `A` has a BLAS floating-point `eltype` and `has_shared_permute(A, pA)`
# holds for the permutation `pA`; `false` otherwise.
function TO.isblascontractable(A::AbstractTensorMap, pA::Index2Tuple)
    eltype(A) <: LinearAlgebra.BlasFloat || return false
    return has_shared_permute(A, pA)
end
# `true` when `A` has a BLAS floating-point `eltype` and `has_shared_permute(A, ipAB)`
# holds for the index layout `ipAB`; `false` otherwise.
function TO.isblasdestination(A::AbstractTensorMap, ipAB::Index2Tuple)
    eltype(A) <: LinearAlgebra.BlasFloat || return false
    return has_shared_permute(A, ipAB)
end
322+
"""
    blas_contract!(C, A, pA, B, pB, pAB, α, β, backend, allocator)

Contract `A` (permuted by `pA`) with `B` (permuted by `pB`) via `mul!` and combine the
result into `C` with scaling factors `α` and `β`, using `pAB` for the output layout.

Operands that are not directly BLAS-contractable, or whose `eltype` differs from that
of `C`, are first copied into temporaries obtained from `allocator`; the same is done
for the destination when `C` is not a valid BLAS destination. All temporaries are
released with `TO.tensorfree!` before returning. For fermionic braiding, exactly one of
the two operands is twisted on its contracted indices when any contracted index of `B`
is dual.

Returns `C`. Throws a `SectorMismatch` when the braiding style of `sectortype(C)` is not
a `SymmetricBraiding`.
"""
function blas_contract!(
        C::AbstractTensorMap,
        A::AbstractTensorMap, pA::Index2Tuple,
        B::AbstractTensorMap, pB::Index2Tuple,
        pAB::Index2Tuple, α, β,
        backend, allocator
    )
    braiding = BraidingStyle(sectortype(C))
    if !(braiding isa SymmetricBraiding)
        throw(SectorMismatch(" only tensors with symmetric braiding rules can be contracted; try `@planar` instead"))
    end
    T = eltype(C)

    # An operand must be copied when its layout or element type rules out direct use.
    layout_copy_A = !TO.isblascontractable(A, pA) || eltype(A) !== T
    layout_copy_B = !TO.isblascontractable(B, pB) || eltype(B) !== T

    # Fermionic case: a twist is needed when some contracted index of B is dual.
    # Twist the operand that is copied anyway; when neither or both are copied,
    # twist the smaller one. Twisting is done in place, so it forces a copy.
    needs_twist = braiding isa Fermionic && any(i -> isdual(space(B, i)), pB[1])
    twist_A = needs_twist &&
        (xor(layout_copy_A, layout_copy_B) ? layout_copy_A : dim(A) < dim(B))
    twist_B = needs_twist && !twist_A
    copy_A = layout_copy_A || twist_A
    copy_B = layout_copy_B || twist_B

    # Left factor: a temporary permuted copy, or `permute(A, pA)` when no copy is needed.
    if copy_A
        Abuf = TO.tensoralloc_add(T, A, pA, false, Val(true), allocator)
        Amat = TO.tensoradd!(Abuf, A, pA, false, One(), Zero(), backend, allocator)
        if twist_A
            twist!(Amat, filter(!isdual ∘ Base.Fix1(space, Amat), domainind(Amat)))
        end
    else
        Amat = permute(A, pA)
    end

    # Right factor, prepared the same way; its twist acts on dual codomain indices.
    if copy_B
        Bbuf = TO.tensoralloc_add(T, B, pB, false, Val(true), allocator)
        Bmat = TO.tensoradd!(Bbuf, B, pB, false, One(), Zero(), backend, allocator)
        if twist_B
            twist!(Bmat, filter(isdual ∘ Base.Fix1(space, Bmat), codomainind(Bmat)))
        end
    else
        Bmat = permute(B, pB)
    end

    # Layout of the raw product, expressed in terms of the indices of C.
    pAmat = (codomainind(Amat), domainind(Amat))
    pBmat = (codomainind(Bmat), domainind(Bmat))
    ipAB = TO.oindABinC(pAB, pAmat, pBmat)

    if TO.isblasdestination(C, ipAB)
        # Multiply straight into the permuted destination.
        mul!(permute(C, ipAB), Amat, Bmat, α, β)
    else
        # Multiply into a temporary, then add it into C with the final permutation.
        Ctmp = TO.tensoralloc_add(T, C, ipAB, false, Val(true), allocator)
        mul!(Ctmp, Amat, Bmat)
        TO.tensoradd!(C, Ctmp, pAB, false, α, β, backend, allocator)
        TO.tensorfree!(Ctmp, allocator)
    end

    # Release operand temporaries (only those that were actually allocated).
    if copy_A
        TO.tensorfree!(Amat, allocator)
    end
    if copy_B
        TO.tensorfree!(Bmat, allocator)
    end

    return C
end
352394
0 commit comments