@@ -189,6 +189,23 @@ function default_blasmul!(α, A::AbstractMatrix, B::AbstractMatrix, β, C::Abstr
189
189
C
190
190
end
191
191
192
+ function default_blasmul! (α, A:: AbstractVector , B:: AbstractMatrix , β, C:: AbstractMatrix )
193
+ mA, = size (A)
194
+ mB, nB = size (B)
195
+ 1 == mB || throw (DimensionMismatch (" Dimensions must match" ))
196
+ size (C) == (mA, nB) || throw (DimensionMismatch (" Dimensions must match" ))
197
+
198
+ lmul! (β, C)
199
+
200
+ (iszero (mA) || iszero (nB)) && return C
201
+
202
+ for k in colsupport (A), j in rowsupport (B)
203
+ _default_blasmul_loop! (α, A, B, β, C, k, j)
204
+ end
205
+ C
206
+ end
207
+
208
+
192
209
function _default_blasmul! (:: IndexLinear , α, A:: AbstractMatrix , B:: AbstractVector , β, C:: AbstractVector )
193
210
mA, nA = size (A)
194
211
mB = length (B)
@@ -266,6 +283,11 @@ function materialize!(M::MatMulVecAdd)
266
283
default_blasmul! (α, unalias (C,A), unalias (C,B), iszero (β) ? false : β, C)
267
284
end
268
285
286
+ function materialize! (M:: VecMulMatAdd )
287
+ α, A, B, β, C = M. α, M. A, M. B, M. β, M. C
288
+ default_blasmul! (α, unalias (C,A), unalias (C,B), iszero (β) ? false : β, C)
289
+ end
290
+
269
291
@inline _gemv! (tA, α, A, x, β, y) = BLAS. gemv! (tA, α, unalias (y,A), unalias (y,x), β, y)
270
292
@inline _gemm! (tA, tB, α, A, B, β, C) = BLAS. gemm! (tA, tB, α, unalias (C,A), unalias (C,B), β, C)
271
293
@@ -424,8 +446,7 @@ function similar(M::MulAdd{<:DualLayout,<:Any,ZerosLayout}, ::Type{T}, (x,y)) wh
424
446
trans (similar (trans (M. A), T, y))
425
447
end
426
448
427
- function similar (M:: MulAdd{<:Any,<:DualLayout,ZerosLayout} , :: Type{T} , (x,y)) where T
428
- @assert length (x) == 1
449
+ function similar (M:: MulAdd{ScalarLayout,<:DualLayout,ZerosLayout} , :: Type{T} , (x,y)) where T
429
450
trans = transtype (M. B)
430
451
trans (similar (trans (M. B), T, y))
431
452
end
@@ -434,3 +455,4 @@ const ZerosLayouts = Union{ZerosLayout,DualLayout{ZerosLayout}}
434
455
copy (M:: MulAdd{<:ZerosLayouts, <:ZerosLayouts, <:ZerosLayouts} ) = M. C
435
456
copy (M:: MulAdd{<:ZerosLayouts, <:Any, <:ZerosLayouts} ) = M. C
436
457
copy (M:: MulAdd{<:Any, <:ZerosLayouts, <:ZerosLayouts} ) = M. C
458
+
0 commit comments