@@ -49,6 +49,13 @@ function check_mul_axes(A, B, C...)
49
49
check_mul_axes (B, C... )
50
50
end
51
51
52
+ # we need to special case AbstractQ as it allows non-compatiple multiplication
53
+ function check_mul_axes (A:: AbstractQ , B, C... )
54
+ axes (A. factors, 1 ) == axes (B, 1 ) || axes (A. factors, 2 ) == axes (B, 1 ) ||
55
+ throw (DimensionMismatch (" First axis of B, $(axes (B,1 )) must match either axes of A, $(axes (A)) " ))
56
+ check_mul_axes (B, C... )
57
+ end
58
+
52
59
53
60
function instantiate (M:: MulAdd )
54
61
@boundscheck check_mul_axes (M. α, M. A, M. B)
@@ -373,4 +380,85 @@ copy(M::MulAdd{<:AbstractFillLayout,<:AbstractFillLayout,<:AbstractFillLayout})
373
380
copy (M:: MulAdd{<:Any,<:DiagonalLayout{<:AbstractFillLayout}} ) = (M. α * getindex_value (M. B. diag)) .* M. A .+ M. β .* M. C
374
381
copy (M:: MulAdd{<:Any,<:DiagonalLayout{<:AbstractFillLayout},ZerosLayout} ) = (M. α * getindex_value (M. B. diag)) .* M. A
375
382
376
- BroadcastStyle (:: Type{<:MulAdd} ) = ApplyBroadcastStyle ()
383
+ BroadcastStyle (:: Type{<:MulAdd} ) = ApplyBroadcastStyle ()
384
+
385
+ scalarone (:: Type{T} ) where T = one (T)
386
+ scalarone (:: Type{<:AbstractArray{T}} ) where T = scalarone (T)
387
+ scalarzero (:: Type{T} ) where T = zero (T)
388
+ scalarzero (:: Type{<:AbstractArray{T}} ) where T = scalarzero (T)
389
+
390
+ fillzeros (:: Type{T} , ax) where T = Zeros {T} (ax)
391
+
392
+ function mul! (dest:: AbstractArray{W} , A:: AbstractArray{T} , b:: AbstractArray{V} ) where {T,V,W}
393
+ TVW = promote_type (W, _mul_eltype (T,V))
394
+ muladd! (scalarone (TVW), A, b, scalarzero (TVW), dest)
395
+ end
396
+
397
+ function MulAdd (A:: AbstractArray{T} , B:: AbstractVector{V} ) where {T,V}
398
+ TV = _mul_eltype (eltype (A), eltype (B))
399
+ MulAdd (scalarone (TV), A, B, scalarzero (TV), fillzeros (TV,(axes (A,1 ))))
400
+ end
401
+
402
+ function MulAdd (A:: AbstractArray{T} , B:: AbstractMatrix{V} ) where {T,V}
403
+ TV = _mul_eltype (eltype (A), eltype (B))
404
+ MulAdd (scalarone (TV), A, B, scalarzero (TV), fillzeros (TV,(axes (A,1 ),axes (B,2 ))))
405
+ end
406
+
407
+ mul (A:: AbstractArray , B:: AbstractArray ) = materialize (MulAdd (A,B))
408
+
409
+ macro lazymul (Typ)
410
+ ret = quote
411
+ LinearAlgebra. mul! (dest:: AbstractVector , A:: $Typ , b:: AbstractVector ) =
412
+ ArrayLayouts. mul! (dest,A,b)
413
+
414
+ LinearAlgebra. mul! (dest:: AbstractMatrix , A:: $Typ , b:: AbstractMatrix ) =
415
+ ArrayLayouts. mul! (dest,A,b)
416
+ LinearAlgebra. mul! (dest:: AbstractMatrix , A:: $Typ , b:: $Typ ) =
417
+ ArrayLayouts. mul! (dest,A,b)
418
+
419
+ Base.:* (A:: $Typ , B:: $Typ ) = ArrayLayouts. mul (A,B)
420
+ Base.:* (A:: $Typ , B:: AbstractMatrix ) = ArrayLayouts. mul (A,B)
421
+ Base.:* (A:: $Typ , B:: AbstractVector ) = ArrayLayouts. mul (A,B)
422
+ Base.:* (A:: AbstractMatrix , B:: $Typ ) = ArrayLayouts. mul (A,B)
423
+ Base.:* (A:: LinearAlgebra.AdjointAbsVec , B:: $Typ ) = ArrayLayouts. mul (A,B)
424
+ Base.:* (A:: LinearAlgebra.TransposeAbsVec , B:: $Typ ) = ArrayLayouts. mul (A,B)
425
+
426
+ Base.:* (A:: LinearAlgebra.AbstractQ , B:: $Typ ) = ArrayLayouts. lmul (A,B)
427
+ Base.:* (A:: $Typ , B:: LinearAlgebra.AbstractQ ) = ArrayLayouts. rmul (A,B)
428
+ end
429
+ for Struc in (:AbstractTriangular , :Diagonal )
430
+ ret = quote
431
+ $ ret
432
+
433
+ Base.:* (A:: LinearAlgebra. $ Struc, B:: $Typ ) = ArrayLayouts. mul (A,B)
434
+ Base.:* (A:: $Typ , B:: LinearAlgebra. $ Struc) = ArrayLayouts. mul (A,B)
435
+ end
436
+ end
437
+ for Mod in (:Adjoint , :Transpose , :Symmetric , :Hermitian )
438
+ ret = quote
439
+ $ ret
440
+
441
+ LinearAlgebra. mul! (dest:: AbstractMatrix , A:: $Typ , b:: $Mod{<:Any,<:AbstractMatrix} ) =
442
+ ArrayLayouts. mul! (dest,A,b)
443
+
444
+ LinearAlgebra. mul! (dest:: AbstractVector , A:: $Mod{<:Any,<:$Typ} , b:: AbstractVector ) =
445
+ ArrayLayouts. mul! (dest,A,b)
446
+
447
+ Base.:* (A:: $Mod{<:Any,<:$Typ} , B:: $Mod{<:Any,<:$Typ} ) = ArrayLayouts. mul (A,B)
448
+ Base.:* (A:: $Mod{<:Any,<:$Typ} , B:: AbstractMatrix ) = ArrayLayouts. mul (A,B)
449
+ Base.:* (A:: AbstractMatrix , B:: $Mod{<:Any,<:$Typ} ) = ArrayLayouts. mul (A,B)
450
+ Base.:* (A:: $Mod{<:Any,<:$Typ} , B:: AbstractVector ) = ArrayLayouts. mul (A,B)
451
+
452
+ Base.:* (A:: $Mod{<:Any,<:$Typ} , B:: $Typ ) = ArrayLayouts. mul (A,B)
453
+ Base.:* (A:: $Typ , B:: $Mod{<:Any,<:$Typ} ) = ArrayLayouts. mul (A,B)
454
+
455
+ Base.:* (A:: $Mod{<:Any,<:$Typ} , B:: Diagonal ) = ArrayLayouts. mul (A,B)
456
+ Base.:* (A:: Diagonal , B:: $Mod{<:Any,<:$Typ} ) = ArrayLayouts. mul (A,B)
457
+
458
+ Base.:* (A:: LinearAlgebra.AbstractTriangular , B:: $Mod{<:Any,<:$Typ} ) = ArrayLayouts. mul (A,B)
459
+ Base.:* (A:: $Mod{<:Any,<:$Typ} , B:: LinearAlgebra.AbstractTriangular ) = ArrayLayouts. mul (A,B)
460
+ end
461
+ end
462
+
463
+ esc (ret)
464
+ end
0 commit comments