@@ -136,48 +136,14 @@ function stridedtensorcontract!(C::StridedView,
136136 B:: StridedView , pB:: Index2Tuple ,
137137 pAB:: Index2Tuple ,
138138 α:: Number , β:: Number ,
139- :: StridedBLAS , allocator= DefaultAllocator ())
139+ backend :: StridedBLAS , allocator= DefaultAllocator ())
140140 argcheck_tensorcontract (C, A, pA, B, pB, pAB)
141141 dimcheck_tensorcontract (C, A, pA, B, pB, pAB)
142142
143143 (Base. mightalias (C, A) || Base. mightalias (C, B)) &&
144144 throw (ArgumentError (" output tensor must not be aliased with input tensor" ))
145145
146- rpA = reverse (pA)
147- rpB = reverse (pB)
148- indCinoBA = let N₁ = numout (pA), N₂ = numin (pB)
149- map (n -> ifelse (n > N₁, n - N₁, n + N₂), linearize (pAB))
150- end
151- tpAB = trivialpermutation (pAB)
152- rpAB = (TupleTools. getindices (indCinoBA, tpAB[1 ]),
153- TupleTools. getindices (indCinoBA, tpAB[2 ]))
154- if contract_memcost (C, A, pA, B, pB, pAB) <= contract_memcost (C, B, rpB, A, rpA, rpAB)
155- return blas_contract! (C, A, pA, B, pB, pAB, α, β, allocator)
156- else
157- return blas_contract! (C, B, rpB, A, rpA, rpAB, α, β, allocator)
158- end
159- return C
160- end
161-
162- # reduce overhead for the case where it is just matrix multiplication
163- function stridedtensorcontract! (C:: StridedView{T,2} ,
164- A:: StridedView{T,2} , pA:: Index2Tuple{1,1} ,
165- B:: StridedView{T,2} , pB:: Index2Tuple{1,1} ,
166- pAB:: Index2Tuple{1,1} , α:: Number , β:: Number ,
167- :: StridedBLAS ) where {T}
168- argcheck_tensorcontract (C, A, pA, B, pB, pAB)
169- dimcheck_tensorcontract (C, A, pA, B, pB, pAB)
170-
171- (Base. mightalias (C, A) || Base. mightalias (C, B)) &&
172- throw (ArgumentError (" output tensor must not be aliased with input tensor" ))
173-
174- A′ = pA == ((1 ,), (2 ,)) ? A : permutedims (A, (pA[1 ][1 ], pA[2 ][1 ]))
175- B′ = pB == ((1 ,), (2 ,)) ? B : permutedims (B, (pB[1 ][1 ], pB[2 ][1 ]))
176- if pAB == ((1 ,), (2 ,))
177- mul! (C, A′, B′, α, β)
178- elseif pAB == ((2 ,), (1 ,))
179- mul! (C, transpose (B′), transpose (A′), α, β)
180- end
146+ blas_contract! (C, A, pA, B, pB, pAB, α, β, backend, allocator)
181147 return C
182148end
183149
@@ -209,131 +175,3 @@ function stridedtensorcontract!(C::StridedView,
209175 Strided. _mapreducedim! (op1, + , op2, tsize, (CS, AS, BS))
210176 return C
211177end
212-
213- # -------------------------------------------------------------------------------------------
214- # StridedViewBLAS contraction implementation
215- # -------------------------------------------------------------------------------------------
# Contract `A` and `B` into `C` through `_unsafe_blas_contract!`, scaling with `α` and `β`.
#
# Inputs that `makeblascontractable` cannot hand over as they are get replaced by
# temporaries obtained through `allocator`; those temporaries are released again with
# `tensorfree!` before returning. When `C` does not pass `isblasdestination`, the product
# is first written to a temporary and then combined into `C` with `stridedtensoradd!`.
#
# Returns `C`.
function blas_contract!(C, A, pA, B, pB, pAB, α, β, allocator)
    TC = eltype(C)

    # `keepA` / `keepB` are true when the original array is used without a copy.
    Awork, pA, keepA = makeblascontractable(A, pA, TC, allocator)
    Bwork, pB, keepB = makeblascontractable(B, pB, TC, allocator)

    ipAB = oindABinC(pAB, pA, pB)
    if isblasdestination(C, ipAB)
        # `C` can receive the product directly, including the `α` / `β` scaling.
        _unsafe_blas_contract!(C, Awork, pA, Bwork, pB, ipAB, α, β)
    else
        # Compute the plain product into a temporary (factors one and zero), then let
        # `stridedtensoradd!` apply `pAB`, `α` and `β` when combining it into `C`.
        Cwork = SV(tensoralloc_add(TC, C, ipAB, false, Val(true), allocator))
        _unsafe_blas_contract!(Cwork, Awork, pA, Bwork, pB, trivialpermutation(ipAB),
                               one(TC), zero(TC))
        stridedtensoradd!(C, Cwork, pAB, α, β, StridedNative(), allocator)
        tensorfree!(Cwork.parent, allocator)
    end

    # Release only the buffers that were allocated here.
    if !keepA
        tensorfree!(Awork.parent, allocator)
    end
    if !keepB
        tensorfree!(Bwork.parent, allocator)
    end
    return C
end
238-
# Perform the contraction as a single matrix product `mul!(Cmat, Amat, Bmat, α, β)`.
#
# Each operand is permuted according to its linearized index tuple and reshaped with
# `sreshape` into a matrix: `A` to (open × contracted), `B` to (contracted × open) and
# `C` to (open of A × open of B). No checks are performed here; the caller is presumably
# responsible for making sure these reshapes are valid (hence "unsafe") — TODO confirm.
#
# Returns `C`.
function _unsafe_blas_contract!(C::StridedView{T},
                                A::StridedView{T}, pA,
                                B::StridedView{T}, pB,
                                pAB, α, β) where {T<:BlasFloat}
    szA, szB = size(A), size(B)

    # Matrix dimensions: products of the sizes in each index group.
    nopenA = prod(TupleTools.getindices(szA, pA[1]))
    ncontrA = prod(TupleTools.getindices(szA, pA[2]))
    ncontrB = prod(TupleTools.getindices(szB, pB[1]))
    nopenB = prod(TupleTools.getindices(szB, pB[2]))

    Cmat = sreshape(permutedims(C, linearize(pAB)), (nopenA, nopenB))
    Amat = sreshape(permutedims(A, linearize(pA)), (nopenA, ncontrA))
    Bmat = sreshape(permutedims(B, linearize(pB)), (ncontrB, nopenB))
    mul!(Cmat, Amat, Bmat, α, β)

    return C
end
257-
# Return `(A′, pA′, unchanged)` where `A′` is an operand for `blas_contract!`.
#
# If `A` passes `isblascontractable(A, pA)` and already has element type `TC`, it is
# returned untouched together with `pA` and `true`. Otherwise a buffer is obtained from
# `tensoralloc_add`, wrapped in a view carrying the same `op` as `A`, filled from `A` via
# `stridedtensoradd!`, and returned with `trivialpermutation(pA)` and `false`; the caller
# then owns the buffer (`blas_contract!` frees it through `.parent`).
@inline function makeblascontractable(A, pA, TC, allocator)
    usable = isblascontractable(A, pA) && eltype(A) == TC
    usable && return A, pA, true

    buffer = tensoralloc_add(TC, A, pA, false, Val(true), allocator)
    Acopy = SV(buffer, size(buffer), strides(buffer), 0, A.op)
    Acopy = stridedtensoradd!(Acopy, A, pA, One(), Zero(), StridedNative(), allocator)
    return Acopy, trivialpermutation(pA), false
end
271-
# Decide whether `A`, with its indices split into the two groups of `p`, can be used
# as a matrix operand without copying.
#
# Requirements checked here:
#   * the element type is a `LinearAlgebra.BlasFloat`;
#   * both index groups can be fused according to `_canfuse`;
#   * if `A.op == conj`, the second group must have fused stride 1; otherwise it is
#     enough that either group has fused stride 1.
function isblascontractable(A::StridedView, p::Index2Tuple)
    eltype(A) <: LinearAlgebra.BlasFloat || return false

    dims, steps = size(A), strides(A)
    fusable1, _, step1 = _canfuse(TupleTools.getindices(dims, p[1]),
                                  TupleTools.getindices(steps, p[1]))
    fusable2, _, step2 = _canfuse(TupleTools.getindices(dims, p[2]),
                                  TupleTools.getindices(steps, p[2]))
    (fusable1 && fusable2) || return false

    return A.op == conj ? step2 == 1 : (step1 == 1 || step2 == 1)
end
291-
# Decide whether `A`, with its indices split into the two groups of `p`, can be written
# to directly as the output matrix.
#
# Requirements checked here:
#   * the element type is a `LinearAlgebra.BlasFloat` and `A.op == identity`;
#   * the first index group can be fused (per `_canfuse`) with fused stride 1;
#   * the second index group can be fused.
function isblasdestination(A::StridedView, p::Index2Tuple)
    eltype(A) <: LinearAlgebra.BlasFloat || return false
    A.op == identity || return false

    dims, steps = size(A), strides(A)

    fusable1, _, step1 = _canfuse(TupleTools.getindices(dims, p[1]),
                                  TupleTools.getindices(steps, p[1]))
    (fusable1 && step1 == 1) || return false

    fusable2, _, _ = _canfuse(TupleTools.getindices(dims, p[2]),
                              TupleTools.getindices(steps, p[2]))
    return fusable2
end
308-
# Check whether a group of dimensions (sizes `dims`, strides `strides`) can be merged
# into one strided dimension. Returns `(fusable, total_length, stride)`.
#
#   * no dimensions: fusable, length 1, stride 1;
#   * one dimension: fusable, with its own length and stride;
#   * otherwise the leading dimension is combined with the result for the remaining ones.
_canfuse(::Dims{0}, ::Dims{0}) = (true, 1, 1)
_canfuse(dims::Dims{1}, strides::Dims{1}) = (true, first(dims), first(strides))
@inline function _canfuse(dims::Dims{N}, strides::Dims{N}) where {N}
    len, step = first(dims), first(strides)

    # A leading dimension of length zero short-circuits: reported as fusable, stride 1.
    len == 0 && return true, 0, 1

    restfusable, restlen, reststep = _canfuse(Base.tail(dims), Base.tail(strides))

    # A leading dimension of length one contributes nothing; defer to the remainder.
    len == 1 && return restfusable, restlen, reststep

    # The leading dimension merges with the remainder when the remainder starts exactly
    # where the leading dimension ends, or when the remainder has total length one.
    mergeable = restfusable && (reststep == len * step || restlen == 1)
    total = len * restlen
    mergeable || return false, total, step
    return true, total, (total == 0 || total == 1) ? 1 : step
end
326-
# Split the inverse of the linearized permutation `pAB` into two tuples: the entries
# selected by `trivialpermutation(pA[1])`, and the entries selected by
# `trivialpermutation(pB[2])` shifted by `numout(pA)`.
# NOTE(review): presumably these are the positions in `C` of the open indices of `A` and
# of `B` respectively — confirm against the definitions of the helpers used here.
function oindABinC(pAB, pA, pB)
    positions = invperm(linearize(pAB))
    fromA = TupleTools.getindices(positions, trivialpermutation(pA[1]))
    fromB = TupleTools.getindices(positions, numout(pA) .+ trivialpermutation(pB[2]))
    return (fromA, fromB)
end
333-
# Count the elements that would have to go through a temporary for this contraction:
# `length(A)` / `length(B)` are counted when the operand fails `isblascontractable` or its
# element type differs from that of `C`, and `length(C)` is counted when `C` fails
# `isblasdestination` for the index layout given by `oindABinC`.
function contract_memcost(C, A, pA, B, pB, pAB)
    ipAB = oindABinC(pAB, pA, pB)
    TC = eltype(C)

    needscopyA = !isblascontractable(A, pA) || eltype(A) !== TC
    needscopyB = !isblascontractable(B, pB) || eltype(B) !== TC
    needscopyC = !isblasdestination(C, ipAB)

    return length(A) * needscopyA + length(B) * needscopyB + length(C) * needscopyC
end
0 commit comments