Skip to content

Commit a2a6913

Browse files
committed
rework to twist smallest object
1 parent 6e238c5 commit a2a6913

File tree

1 file changed

+33
-17
lines changed

1 file changed

+33
-17
lines changed

src/tensors/tensoroperations.jl

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -327,51 +327,67 @@ function blas_contract!(
327327
pAB::Index2Tuple, α, β,
328328
backend, allocator
329329
)
330-
I = sectortype(C)
331-
BraidingStyle(I) isa SymmetricBraiding ||
330+
bstyle = BraidingStyle(sectortype(C))
331+
bstyle isa SymmetricBraiding ||
332332
throw(SectorMismatch("only tensors with symmetric braiding rules can be contracted; try `@planar` instead"))
333333
TC = eltype(C)
334334

335+
# check which tensors have to be permuted/copied
336+
copyA = !(TO.isblascontractable(A, pA) && eltype(A) === TC)
337+
copyB = !(TO.isblascontractable(B, pB) && eltype(B) === TC)
338+
339+
if bstyle isa Fermionic && any(isdual ∘ Base.Fix1(space, B), pB[1])
340+
# twist smallest object if neither or both already have to be permuted
341+
# otherwise twist the one that already is copied
342+
if copyA copyB
343+
twistA = dim(A) < dim(B)
344+
else
345+
twistA = copyA
346+
end
347+
twistB = !twistA
348+
copyA |= twistA
349+
copyB |= twistB
350+
else
351+
twistA = false
352+
twistB = false
353+
end
354+
335355
# Bring A in the correct form for BLAS contraction
336-
flagA = TO.isblascontractable(A, pA) && eltype(A) === TC &&
337-
!(BraidingStyle(I) isa Fermionic && any(i -> isdual(space(A, i)), pA[2]))
338-
if !flagA
356+
if copyA
339357
Anew = TO.tensoralloc_add(TC, A, pA, false, Val(true), allocator)
340358
Anew = TO.tensoradd!(Anew, A, pA, false, One(), Zero(), backend, allocator)
341-
for i in domainind(Anew)
342-
isdual(space(Anew, i)) || twist!(Anew, i)
343-
end
359+
twistA && twist!(Anew, filter(!isdual ∘ Base.Fix1(space, Anew), domainind(Anew)))
344360
else
345361
Anew = permute(A, pA)
346362
end
347363
pAnew = (codomainind(Anew), domainind(Anew))
348364

349365
# Bring B in the correct form for BLAS contraction
350-
flagB = TO.isblascontractable(B, pB) && eltype(B) === TC
351-
if !flagB
366+
if copyB
352367
Bnew = TO.tensoralloc_add(TC, B, pB, false, Val(true), allocator)
353368
Bnew = TO.tensoradd!(Bnew, B, pB, false, One(), Zero(), backend, allocator)
369+
twistB && twist!(Bnew, filter(isdual ∘ Base.Fix1(space, Bnew), codomainind(Bnew)))
354370
else
355371
Bnew = permute(B, pB)
356372
end
357373
pBnew = (codomainind(Bnew), domainind(Bnew))
358374

359375
# Bring C in the correct form for BLAS contraction
360376
ipAB = TO.oindABinC(pAB, pAnew, pBnew)
361-
flagC = TO.isblasdestination(C, ipAB)
377+
copyC = !TO.isblasdestination(C, ipAB)
362378

363-
if flagC
364-
Cnew = permute(C, ipAB)
365-
mul!(Cnew, Anew, Bnew, α, β)
366-
else
379+
if copyC
367380
Cnew = TO.tensoralloc_add(TC, C, ipAB, false, Val(true), allocator)
368381
mul!(Cnew, Anew, Bnew)
369382
TO.tensoradd!(C, Cnew, pAB, false, α, β, backend, allocator)
370383
TO.tensorfree!(Cnew, allocator)
384+
else
385+
Cnew = permute(C, ipAB)
386+
mul!(Cnew, Anew, Bnew, α, β)
371387
end
372388

373-
flagA || TO.tensorfree!(Anew, allocator)
374-
flagB || TO.tensorfree!(Bnew, allocator)
389+
copyA && TO.tensorfree!(Anew, allocator)
390+
copyB && TO.tensorfree!(Bnew, allocator)
375391

376392
return C
377393
end

0 commit comments

Comments
 (0)