@@ -94,17 +94,17 @@ function TO.tensorcontract!(C::AbstractTensorMap,
9494 pA′ = adjointtensorindices (A, pA)
9595 B′ = B'
9696 pB′ = adjointtensorindices (B, pB)
97- contract ! (C, A′, pA′, B′, pB′, pAB′, α, β, backend, allocator)
97+ TO . blas_contract ! (C, A′, pA′, B′, pB′, pAB′, α, β, backend, allocator)
9898 elseif conjA
9999 A′ = A'
100100 pA′ = adjointtensorindices (A, pA)
101- contract ! (C, A′, pA′, B, pB, pAB′, α, β, backend, allocator)
101+ TO . blas_contract ! (C, A′, pA′, B, pB, pAB′, α, β, backend, allocator)
102102 elseif conjB
103103 B′ = B'
104104 pB′ = adjointtensorindices (B, pB)
105- contract ! (C, A, pA, B′, pB′, pAB′, α, β, backend, allocator)
105+ TO . blas_contract ! (C, A, pA, B′, pB′, pAB′, α, β, backend, allocator)
106106 else
107- contract ! (C, A, pA, B, pB, pAB′, α, β, backend, allocator)
107+ TO . blas_contract ! (C, A, pA, B, pB, pAB′, α, β, backend, allocator)
108108 end
109109 return C
110110end
@@ -154,6 +154,19 @@ TO.tensorcost(t::AbstractTensorMap, i::Int) = dim(space(t, i))
154154 scheduler:: S = SerialScheduler ()
155155end
156156
# Backend selection for `TO.tensoradd!` on `AbstractTensorMap` arguments: always a
# default-constructed `TensorKitBackend`.
TO.select_backend(::typeof(TO.tensoradd!), C::AbstractTensorMap, A::AbstractTensorMap) =
    TensorKitBackend()
# Backend selection for `TO.tensortrace!` on `AbstractTensorMap` arguments: always a
# default-constructed `TensorKitBackend`.
TO.select_backend(::typeof(TO.tensortrace!), C::AbstractTensorMap,
                  A::AbstractTensorMap) = TensorKitBackend()
# Backend selection for `TO.tensorcontract!` on `AbstractTensorMap` arguments: always a
# default-constructed `TensorKitBackend`.
TO.select_backend(::typeof(TO.tensorcontract!), C::AbstractTensorMap,
                  A::AbstractTensorMap, B::AbstractTensorMap) = TensorKitBackend()
169+
157170# Trace implementation
158171# ----------------------
159172"""
@@ -232,114 +245,105 @@ end
232245# TODO : contraction with either A or B a rank (1, 1) tensor does not require to
233246# permute the fusion tree and should therefore be special cased. This will speed
234247# up MPS algorithms
235- """
236- contract!(C::AbstractTensorMap,
237- A::AbstractTensorMap, (oindA, cindA)::Index2Tuple,
238- B::AbstractTensorMap, (cindB, oindB)::Index2Tuple,
239- (p₁, p₂)::Index2Tuple,
240- α::Number, β::Number,
241- backend, allocator)
242-
243- Return the updated `C`, which is the result of adding `α * A * B` to `C` after permuting
244- the indices of `A` and `B` according to `(oindA, cindA)` and `(cindB, oindB)` respectively.
245- """
246- function contract! (C:: AbstractTensorMap ,
247- A:: AbstractTensorMap , (oindA, cindA):: Index2Tuple ,
248- B:: AbstractTensorMap , (cindB, oindB):: Index2Tuple ,
249- (p₁, p₂):: Index2Tuple ,
250- α:: Number , β:: Number ,
251- backend, allocator)
252- length (cindA) == length (cindB) ||
253- throw (IndexError (" number of contracted indices does not match" ))
254- N₁, N₂ = length (oindA), length (oindB)
255-
256- # find optimal contraction scheme
257- hsp = has_shared_permute
258- ipAB = TupleTools. invperm ((p₁... , p₂... ))
259- oindAinC = TupleTools. getindices (ipAB, ntuple (n -> n, N₁))
260- oindBinC = TupleTools. getindices (ipAB, ntuple (n -> n + N₁, N₂))
261-
262- qA = TupleTools. sortperm (cindA)
263- cindA′ = TupleTools. getindices (cindA, qA)
264- cindB′ = TupleTools. getindices (cindB, qA)
265-
266- qB = TupleTools. sortperm (cindB)
267- cindA′′ = TupleTools. getindices (cindA, qB)
268- cindB′′ = TupleTools. getindices (cindB, qB)
269-
270- dA, dB, dC = dim (A), dim (B), dim (C)
271-
272- # keep order A en B, check possibilities for cind
273- memcost1 = memcost2 = dC * (! hsp (C, (oindAinC, oindBinC)))
274- memcost1 += dA * (! hsp (A, (oindA, cindA′))) +
275- dB * (! hsp (B, (cindB′, oindB)))
276- memcost2 += dA * (! hsp (A, (oindA, cindA′′))) +
277- dB * (! hsp (B, (cindB′′, oindB)))
278-
279- # reverse order A en B, check possibilities for cind
280- memcost3 = memcost4 = dC * (! hsp (C, (oindBinC, oindAinC)))
281- memcost3 += dB * (! hsp (B, (oindB, cindB′))) +
282- dA * (! hsp (A, (cindA′, oindA)))
283- memcost4 += dB * (! hsp (B, (oindB, cindB′′))) +
284- dA * (! hsp (A, (cindA′′, oindA)))
285-
286- if min (memcost1, memcost2) <= min (memcost3, memcost4)
287- if memcost1 <= memcost2
288- return _contract! (α, A, B, β, C, oindA, cindA′, oindB, cindB′, p₁, p₂)
248+
249+ # this is a copy of the TensorOperations implementation, adding two ways to
250+ # permute the contracted indices for 4 total possible implementations
"""
    TO.blas_contract!(C::AbstractTensorMap, A::AbstractTensorMap, pA,
                      B::AbstractTensorMap, pB, pAB, α, β, backend, allocator)

Pick the cheapest of four equivalent contraction layouts and forward it to
`_blas_contract!`, returning whatever that call returns.

The candidates combine two orderings of the contracted indices (the ordering that sorts
`pA[2]`, or the one that sorts `pB[1]`) with two operand orders (`A`, `B` or `B`, `A`).
Every candidate is scored with `TO.contract_memcost`. On ties the `A`, `B` order wins, and
within one operand order the ordering that sorts `pA[2]` wins.
"""
function TO.blas_contract!(C::AbstractTensorMap, A::AbstractTensorMap, pA,
                           B::AbstractTensorMap, pB, pAB, α, β, backend, allocator)
    # output permutation to use when `B` takes the place of `A` and vice versa
    pBA = let nA = TO.numout(pA), nB = TO.numin(pB)
        shifted = map(n -> ifelse(n > nA, n - nA, n + nB), TO.linearize(pAB))
        trivial = TO.trivialpermutation(pAB)
        (TupleTools.getindices(shifted, trivial[1]),
         TupleTools.getindices(shifted, trivial[2]))
    end

    # candidate 1: contracted indices reordered so that those of `A` are sorted
    permA = TupleTools.sortperm(pA[2])
    pA₁ = (pA[1], TupleTools.getindices(pA[2], permA))
    pB₁ = (TupleTools.getindices(pB[1], permA), pB[2])

    # candidate 2: contracted indices reordered so that those of `B` are sorted
    permB = TupleTools.sortperm(pB[1])
    pA₂ = (pA[1], TupleTools.getindices(pA[2], permB))
    pB₂ = (TupleTools.getindices(pB[1], permB), pB[2])

    cost_AB₁ = TO.contract_memcost(C, A, pA₁, B, pB₁, pAB)
    cost_AB₂ = TO.contract_memcost(C, A, pA₂, B, pB₂, pAB)
    cost_BA₁ = TO.contract_memcost(C, B, reverse(pB₁), A, reverse(pA₁), pBA)
    cost_BA₂ = TO.contract_memcost(C, B, reverse(pB₂), A, reverse(pA₂), pBA)

    if min(cost_AB₁, cost_AB₂) ≤ min(cost_BA₁, cost_BA₂)
        if cost_AB₁ ≤ cost_AB₂
            return _blas_contract!(C, A, pA₁, B, pB₁, pAB, α, β, backend, allocator)
        end
        return _blas_contract!(C, A, pA₂, B, pB₂, pAB, α, β, backend, allocator)
    end
    if cost_BA₁ ≤ cost_BA₂
        return _blas_contract!(C, B, reverse(pB₁), A, reverse(pA₁), pBA, α, β, backend,
                               allocator)
    end
    return _blas_contract!(C, B, reverse(pB₂), A, reverse(pA₂), pBA, α, β, backend,
                           allocator)
end
302291
303- # TODO : also transform _contract! into new interface, and add backend support
304- function _contract! (α, A:: AbstractTensorMap , B:: AbstractTensorMap ,
305- β, C:: AbstractTensorMap ,
306- oindA:: IndexTuple , cindA:: IndexTuple ,
307- oindB:: IndexTuple , cindB:: IndexTuple ,
308- p₁:: IndexTuple , p₂:: IndexTuple )
309- if ! (BraidingStyle (sectortype (C)) isa SymmetricBraiding)
310- throw (SectorMismatch (" only tensors with symmetric braiding rules can be contracted; try `@planar` instead" ))
311- end
312- N₁, N₂ = length (oindA), length (oindB)
313- copyA = false
314- if BraidingStyle (sectortype (A)) isa Fermionic
315- for i in cindA
316- if ! isdual (space (A, i))
317- copyA = true
318- end
319- end
"""
    TO.contract_memcost(C::AbstractTensorMap, A::AbstractTensorMap, pA,
                        B::AbstractTensorMap, pB, pAB)

Return the summed `dim` of every tensor among `A`, `B` and `C` that fails its layout check:
`isblascontractable` for `A` with `pA` and for `B` with `pB`, and `isblasdestination` for
`C` with `TO.oindABinC(pAB, pA, pB)`. A tensor that passes its check contributes zero.
"""
function TO.contract_memcost(C::AbstractTensorMap, A::AbstractTensorMap, pA,
                             B::AbstractTensorMap, pB, pAB)
    pC = TO.oindABinC(pAB, pA, pB)
    # multiplying by a `Bool` keeps the term when the check fails and zeroes it otherwise
    costA = dim(A) * !isblascontractable(A, pA)
    costB = dim(B) * !isblascontractable(B, pB)
    costC = dim(C) * !isblasdestination(C, pC)
    return costA + costB + costC
end
298+
299+ # TODO : deliberately not importing private TO functions from here on out. Should we?
"""
    _blas_contract!(C, A, pA, B, pB, pAB, α, β, backend, allocator)

Contract `A` and `B` into `C` via `mul!` for the index layout chosen by the caller, and
return `C`.

`A` and `B` are first passed through `makeblascontractable` (with the twist requested on
the first operand only). If `C` already has the layout `TO.oindABinC(pAB, pA, pB)`, the
product is accumulated into it directly with `α` and `β`; otherwise the product is written
to a temporary and added into `C` with permutation `pAB` through `TO.tensoradd!`. Every
temporary created here is released with `TO.tensorfree!`.

# Throws
- `SectorMismatch` if the braiding style of `sectortype(C)` is not a `SymmetricBraiding`.
  This guard was present in the previous `_contract!` implementation.
"""
function _blas_contract!(C, A, pA, B, pB, pAB, α, β, backend, allocator)
    if !(BraidingStyle(sectortype(C)) isa SymmetricBraiding)
        throw(SectorMismatch("only tensors with symmetric braiding rules can be " *
                             "contracted; try `@planar` instead"))
    end
    TC = eltype(C)

    # flagX == true means X_ is the caller's tensor itself and must not be freed
    A_, pA, flagA = makeblascontractable(A, pA, TC, backend, allocator, true)
    B_, pB, flagB = makeblascontractable(B, pB, TC, backend, allocator, false)

    ipAB = TO.oindABinC(pAB, pA, pB)
    flagC = isblasdestination(C, ipAB)
    if flagC
        mul!(C, A_, B_, α, β, backend)
    else
        C_ = TO.tensoralloc_add(TC, C, ipAB, false, Val(true), allocator)
        mul!(C_, A_, B_, One(), Zero(), backend)
        TO.tensoradd!(C, C_, pAB, false, α, β, backend, allocator)
        TO.tensorfree!(C_, allocator)
    end
    flagA || TO.tensorfree!(A_, allocator)
    flagB || TO.tensorfree!(B_, allocator)
    return C
end
320+
# `true` when `pA` is exactly the (codomain indices, domain indices) split of `A`.
function isblascontractable(A, pA)
    return pA[1] == codomainind(A) && pA[2] == domainind(A)
end
# `true` when `p` is exactly the (codomain indices, domain indices) split of `A`.
isblasdestination(A::AbstractTensorMap, p::Index2Tuple) =
    p[1] == codomainind(A) && p[2] == domainind(A)
325+
"""
    makeblascontractable(A::AbstractTensorMap, pA, TC, backend, allocator,
                         dotwist::Bool=false)

Return `(Anew, pAnew, flagA)`, where `Anew` can be handed to `mul!` with index layout
`pAnew`.

If `A` already has scalar type `TC`, already has layout `pA` (see `isblascontractable`)
and needs no twist, then `A` and `pA` are returned unchanged with `flagA == true`.
Otherwise a new tensor is allocated with `TO.tensoralloc_add`, filled with the permuted
`A` through `TO.tensoradd!`, and returned together with `TO.trivialpermutation(pA)` and
`flagA == false`. The caller is then responsible for releasing it.

With `dotwist == true` and a `Fermionic` braiding style, every domain index of the new
tensor whose space is not dual is twisted in place. A copy is forced only when at least
one contracted index (`pA[2]`) of `A` has a non-dual space. This is the same criterion
the previous `_contract!` implementation used before copying.
"""
@inline function makeblascontractable(A::AbstractTensorMap, pA, TC, backend, allocator,
                                      dotwist::Bool=false)
    fermionic = dotwist && (BraidingStyle(sectortype(A)) isa Fermionic)
    # twisting mutates the tensor, so `A` itself may only be reused when nothing
    # would be twisted
    needstwist = fermionic && any(i -> !isdual(space(A, i)), pA[2])
    flagA = (scalartype(A) === TC && isblascontractable(A, pA) && !needstwist)
    flagA && return A, pA, true

    A_ = TO.tensoralloc_add(TC, A, pA, false, Val(true), allocator)
    Anew = TO.tensoradd!(A_, A, pA, false, One(), Zero(), backend, allocator)
    if fermionic
        # explicit loop on purpose: the original author noted that the `filter`-based
        # form `twist!(Anew, filter(...))` seemed type-unstable
        for i in domainind(Anew)
            if !isdual(space(Anew, i))
                twist!(Anew, i)
            end
        end
    end
    return Anew, TO.trivialpermutation(pA), false
end
344348
345349# Scalar implementation
0 commit comments