19
19
MulAdd {StyleA,StyleB,StyleC} (α, A, B, β, C)
20
20
end
21
21
22
- @inline MulAdd (α, A:: AA , B:: BB , β, C:: CC ) where {AA,BB,CC} =
22
+ @inline MulAdd (α, A:: AA , B:: BB , β, C:: CC ) where {AA,BB,CC} =
23
23
MulAdd {typeof(MemoryLayout(AA)), typeof(MemoryLayout(BB)), typeof(MemoryLayout(CC))} (α, A, B, β, C)
24
24
25
25
MulAdd (A, B) = MulAdd (Mul (A, B))
26
26
function MulAdd (M:: Mul )
27
27
TV = eltype (M)
28
- MulAdd (scalarone (TV), M. A, M. B, scalarzero (TV), fillzeros (TV,axes (M) ))
28
+ MulAdd (scalarone (TV), M. A, M. B, scalarzero (TV), mulzeros (TV,M ))
29
29
end
30
30
31
31
@inline eltype (:: MulAdd{StyleA,StyleB,StyleC,T,AA,BB,CC} ) where {StyleA,StyleB,StyleC,T,AA,BB,CC} =
@@ -69,18 +69,11 @@ muladd!(α, A, B, β, C) = materialize!(MulAdd(α, A, B, β, C))
69
69
materialize (M:: MulAdd ) = copy (instantiate (M))
70
70
copy (M:: MulAdd ) = copyto! (similar (M), M)
71
71
72
- @inline function copyto! (dest:: AbstractArray{T} , M:: MulAdd ) where T
73
- M. C === dest || copyto! (dest, M. C)
74
- muladd! (M. α, M. A, M. B, M. β, dest)
75
- end
72
+ _fill_copyto! (dest, C) = copyto! (dest, C)
73
+ _fill_copyto! (dest, C:: Zeros ) = zero! (dest) # exploit special fill! overload
76
74
77
- @inline function copyto! (dest:: AbstractArray{T} , M:: MulAdd{<:Any,<:Any,ZerosLayout} ) where T
78
- α,A,B,β,C = M. α, M. A, M. B, M. β, M. C
79
- if ! isbitstype (T) # instantiate
80
- dest .= β .* view (A,:,1 ) .* Ref (B[1 ]) # get shape right
81
- end
82
- muladd! (α, A, B, β, dest)
83
- end
75
+ @inline copyto! (dest:: AbstractArray{T} , M:: MulAdd ) where T =
76
+ muladd! (M. α, unalias (dest,M. A), unalias (dest,M. B), M. β, _fill_copyto! (dest, M. C))
84
77
85
78
# Modified from LinearAlgebra._generic_matmatmul!
86
79
function tile_size (T, S, R)
@@ -226,32 +219,28 @@ function _default_blasmul!(::IndexCartesian, α, A::AbstractMatrix, B::AbstractV
226
219
C
227
220
end
228
221
229
- default_blasmul! (α, A:: AbstractMatrix , B:: AbstractVector , β, C:: AbstractVector ) =
222
+ default_blasmul! (α, A:: AbstractMatrix , B:: AbstractVector , β, C:: AbstractVector ) =
230
223
_default_blasmul! (Base. IndexStyle (typeof (A)), α, A, B, β, C)
231
224
232
225
function materialize! (M:: MatMulMatAdd )
233
226
α, A, B, β, C = M. α, M. A, M. B, M. β, M. C
234
- if C ≡ B
235
- B = copy (B)
236
- end
237
- default_blasmul! (α, A, B, iszero (β) ? false : β, C)
227
+ default_blasmul! (α, unalias (C,A), unalias (C,B), iszero (β) ? false : β, C)
238
228
end
239
229
240
230
function materialize! (M:: MatMulMatAdd{<:AbstractStridedLayout,<:AbstractStridedLayout,<:AbstractStridedLayout} )
241
- α, A, B, β, C = M. α, M. A, M. B, M. β, M. C
242
- if C ≡ B
243
- B = copy (B)
244
- end
231
+ α, Ain, Bin, β, C = M. α, M. A, M. B, M. β, M. C
232
+ A = unalias (C, Ain)
233
+ B = unalias (C, Bin)
245
234
ts = tile_size (eltype (A), eltype (B), eltype (C))
246
235
if iszero (β) # false is a "strong" zero to wipe out NaNs
247
236
if ts == 0 || ! (axes (A) isa NTuple{2 ,OneTo{Int}}) || ! (axes (B) isa NTuple{2 ,OneTo{Int}}) || ! (axes (C) isa NTuple{2 ,OneTo{Int}})
248
- default_blasmul! (α, A, B, false , C)
249
- else
237
+ default_blasmul! (α, A, B, false , C)
238
+ else
250
239
tiled_blasmul! (ts, α, A, B, false , C)
251
240
end
252
241
else
253
242
if ts == 0 || ! (axes (A) isa NTuple{2 ,OneTo{Int}}) || ! (axes (B) isa NTuple{2 ,OneTo{Int}}) || ! (axes (C) isa NTuple{2 ,OneTo{Int}})
254
- default_blasmul! (α, A, B, β, C)
243
+ default_blasmul! (α, A, B, β, C)
255
244
else
256
245
tiled_blasmul! (ts, α, A, B, β, C)
257
246
end
@@ -260,29 +249,11 @@ end
260
249
261
250
function materialize! (M:: MatMulVecAdd )
262
251
α, A, B, β, C = M. α, M. A, M. B, M. β, M. C
263
- if C ≡ B
264
- B = copy (B)
265
- end
266
- default_blasmul! (α, A, B, iszero (β) ? false : β, C)
252
+ default_blasmul! (α, unalias (C,A), unalias (C,B), iszero (β) ? false : β, C)
267
253
end
268
254
269
- # make copy to make sure always works
270
- @inline function _gemv! (tA, α, A, x, β, y)
271
- if x ≡ y
272
- BLAS. gemv! (tA, α, A, copy (x), β, y)
273
- else
274
- BLAS. gemv! (tA, α, A, x, β, y)
275
- end
276
- end
277
-
278
- # make copy to make sure always works
279
- @inline function _gemm! (tA, tB, α, A, B, β, C)
280
- if B ≡ C
281
- BLAS. gemm! (tA, tB, α, A, copy (B), β, C)
282
- else
283
- BLAS. gemm! (tA, tB, α, A, B, β, C)
284
- end
285
- end
255
+ @inline _gemv! (tA, α, A, x, β, y) = BLAS. gemv! (tA, α, unalias (y,A), unalias (y,x), β, y)
256
+ @inline _gemm! (tA, tB, α, A, B, β, C) = BLAS. gemm! (tA, tB, α, unalias (C,A), unalias (C,B), β, C)
286
257
287
258
# work around pointer issues
288
259
@inline materialize! (M:: BlasMatMulVecAdd{<:AbstractColumnMajor,<:AbstractStridedLayout,<:AbstractStridedLayout} ) =
350
321
# ##
351
322
352
323
# make copy to make sure always works
353
- @inline function _symv! (tA, α, A, x, β, y)
354
- if x ≡ y
355
- BLAS. symv! (tA, α, A, copy (x), β, y)
356
- else
357
- BLAS. symv! (tA, α, A, x, β, y)
358
- end
359
- end
360
-
361
- @inline function _hemv! (tA, α, A, x, β, y)
362
- if x ≡ y
363
- BLAS. hemv! (tA, α, A, copy (x), β, y)
364
- else
365
- BLAS. hemv! (tA, α, A, x, β, y)
366
- end
367
- end
324
+ @inline _symv! (tA, α, A, x, β, y) = BLAS. symv! (tA, α, unalias (y,A), unalias (y,x), β, y)
325
+ @inline _hemv! (tA, α, A, x, β, y) = BLAS. hemv! (tA, α, unalias (y,A), unalias (y,x), β, y)
368
326
369
327
370
328
materialize! (M:: BlasMatMulVecAdd{<:SymmetricLayout{<:AbstractColumnMajor},<:AbstractStridedLayout,<:AbstractStridedLayout} ) =
@@ -411,10 +369,28 @@ scalarone(::Type{<:AbstractArray{T}}) where T = scalarone(T)
411
369
scalarzero (:: Type{T} ) where T = zero (T)
412
370
scalarzero (:: Type{<:AbstractArray{T}} ) where T = scalarzero (T)
413
371
414
- fillzeros (:: Type{T} , ax) where T = Zeros {T} (ax)
372
+ fillzeros (:: Type{T} , ax) where T<: Number = Zeros {T} (ax)
373
+ mulzeros (:: Type{T} , M) where T<: Number = fillzeros (T, axes (M))
374
+
375
+ # initiate array-valued MulAdd
376
+ function _mulzeros! (dest:: AbstractVector{T} , A, B) where T
377
+ for k in axes (dest,1 )
378
+ dest[k] = similar (Mul (A[k,1 ],B[1 ]), eltype (T))
379
+ end
380
+ dest
381
+ end
382
+
383
+ function _mulzeros! (dest:: AbstractMatrix{T} , A, B) where T
384
+ for j in axes (dest,2 ), k in axes (dest,1 )
385
+ dest[k,j] = similar (Mul (A[k,1 ],B[1 ,j]), eltype (T))
386
+ end
387
+ dest
388
+ end
389
+
390
+ mulzeros (:: Type{T} , M) where T<: AbstractArray = _mulzeros! (similar (Array{T}, axes (M)), M. A, M. B)
415
391
416
392
# ##
417
- # Fill
393
+ # Fill
418
394
# ##
419
395
420
396
copy (M:: MulAdd{<:AbstractFillLayout,<:AbstractFillLayout,<:AbstractFillLayout} ) = M. α* M. A* M. B + M. β* M. C
0 commit comments