@@ -327,51 +327,67 @@ function blas_contract!(
327327 pAB:: Index2Tuple , α, β,
328328 backend, allocator
329329 )
330- I = sectortype (C)
331- BraidingStyle (I) isa SymmetricBraiding ||
330+ bstyle = BraidingStyle ( sectortype (C) )
331+ bstyle isa SymmetricBraiding ||
332332 throw (SectorMismatch (" only tensors with symmetric braiding rules can be contracted; try `@planar` instead" ))
333333 TC = eltype (C)
334334
335+ # check which tensors have to be permuted/copied
336+ copyA = ! (TO. isblascontractable (A, pA) && eltype (A) === TC)
337+ copyB = ! (TO. isblascontractable (B, pB) && eltype (B) === TC)
338+
339+ if bstyle isa Fermionic && any (isdual ∘ Base. Fix1 (space, B), pB[1 ])
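340+ # twist the smallest object if neither or both already have to be permuted;
341+ # otherwise twist the one that is already copied (NOTE(review): the `copyA ⊻ copyB` test below looks inverted relative to this description - confirm)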
342+ if copyA ⊻ copyB
343+ twistA = dim (A) < dim (B)
344+ else
345+ twistA = copyA
346+ end
347+ twistB = ! twistA
348+ copyA |= twistA
349+ copyB |= twistB
350+ else
351+ twistA = false
352+ twistB = false
353+ end
354+
335355 # Bring A in the correct form for BLAS contraction
336- flagA = TO. isblascontractable (A, pA) && eltype (A) === TC &&
337- ! (BraidingStyle (I) isa Fermionic && any (i -> isdual (space (A, i)), pA[2 ]))
338- if ! flagA
356+ if copyA
339357 Anew = TO. tensoralloc_add (TC, A, pA, false , Val (true ), allocator)
340358 Anew = TO. tensoradd! (Anew, A, pA, false , One (), Zero (), backend, allocator)
341- for i in domainind (Anew)
342- isdual (space (Anew, i)) || twist! (Anew, i)
343- end
359+ twistA && twist! (Anew, filter (! isdual ∘ Base. Fix1 (space, Anew), domainind (Anew)))
344360 else
345361 Anew = permute (A, pA)
346362 end
347363 pAnew = (codomainind (Anew), domainind (Anew))
348364
349365 # Bring B in the correct form for BLAS contraction
350- flagB = TO. isblascontractable (B, pB) && eltype (B) === TC
351- if ! flagB
366+ if copyB
352367 Bnew = TO. tensoralloc_add (TC, B, pB, false , Val (true ), allocator)
353368 Bnew = TO. tensoradd! (Bnew, B, pB, false , One (), Zero (), backend, allocator)
369+ twistB && twist! (Bnew, filter (isdual ∘ Base. Fix1 (space, Bnew), codomainind (Bnew)))
354370 else
355371 Bnew = permute (B, pB)
356372 end
357373 pBnew = (codomainind (Bnew), domainind (Bnew))
358374
359375 # Bring C in the correct form for BLAS contraction
360376 ipAB = TO. oindABinC (pAB, pAnew, pBnew)
361- flagC = TO. isblasdestination (C, ipAB)
377+ copyC = ! TO. isblasdestination (C, ipAB)
362378
363- if flagC
364- Cnew = permute (C, ipAB)
365- mul! (Cnew, Anew, Bnew, α, β)
366- else
379+ if copyC
367380 Cnew = TO. tensoralloc_add (TC, C, ipAB, false , Val (true ), allocator)
368381 mul! (Cnew, Anew, Bnew)
369382 TO. tensoradd! (C, Cnew, pAB, false , α, β, backend, allocator)
370383 TO. tensorfree! (Cnew, allocator)
384+ else
385+ Cnew = permute (C, ipAB)
386+ mul! (Cnew, Anew, Bnew, α, β)
371387 end
372388
373- flagA || TO. tensorfree! (Anew, allocator)
374- flagB || TO. tensorfree! (Bnew, allocator)
389+ copyA && TO. tensorfree! (Anew, allocator)
390+ copyB && TO. tensorfree! (Bnew, allocator)
375391
376392 return C
377393end
0 commit comments