Skip to content

Commit ff760fa

Browse files
authored
Fix Vec * Mat (#175)
* Fix Vec * Mat * Update muladd.jl
1 parent 424eff9 commit ff760fa

File tree

3 files changed

+29
-3
lines changed

3 files changed

+29
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ArrayLayouts"
22
uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
33
authors = ["Sheehan Olver <[email protected]>"]
4-
version = "1.4"
4+
version = "1.4.1"
55

66
[deps]
77
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"

src/muladd.jl

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,23 @@ function default_blasmul!(α, A::AbstractMatrix, B::AbstractMatrix, β, C::Abstr
189189
C
190190
end
191191

192+
function default_blasmul!(α, A::AbstractVector, B::AbstractMatrix, β, C::AbstractMatrix)
193+
mA, = size(A)
194+
mB, nB = size(B)
195+
1 == mB || throw(DimensionMismatch("Dimensions must match"))
196+
size(C) == (mA, nB) || throw(DimensionMismatch("Dimensions must match"))
197+
198+
lmul!(β, C)
199+
200+
(iszero(mA) || iszero(nB)) && return C
201+
202+
for k in colsupport(A), j in rowsupport(B)
203+
_default_blasmul_loop!(α, A, B, β, C, k, j)
204+
end
205+
C
206+
end
207+
208+
192209
function _default_blasmul!(::IndexLinear, α, A::AbstractMatrix, B::AbstractVector, β, C::AbstractVector)
193210
mA, nA = size(A)
194211
mB = length(B)
@@ -266,6 +283,11 @@ function materialize!(M::MatMulVecAdd)
266283
default_blasmul!(α, unalias(C,A), unalias(C,B), iszero(β) ? false : β, C)
267284
end
268285

286+
function materialize!(M::VecMulMatAdd)
287+
α, A, B, β, C = M.α, M.A, M.B, M.β, M.C
288+
default_blasmul!(α, unalias(C,A), unalias(C,B), iszero(β) ? false : β, C)
289+
end
290+
269291
@inline _gemv!(tA, α, A, x, β, y) = BLAS.gemv!(tA, α, unalias(y,A), unalias(y,x), β, y)
270292
@inline _gemm!(tA, tB, α, A, B, β, C) = BLAS.gemm!(tA, tB, α, unalias(C,A), unalias(C,B), β, C)
271293

@@ -424,8 +446,7 @@ function similar(M::MulAdd{<:DualLayout,<:Any,ZerosLayout}, ::Type{T}, (x,y)) wh
424446
trans(similar(trans(M.A), T, y))
425447
end
426448

427-
function similar(M::MulAdd{<:Any,<:DualLayout,ZerosLayout}, ::Type{T}, (x,y)) where T
428-
@assert length(x) == 1
449+
function similar(M::MulAdd{ScalarLayout,<:DualLayout,ZerosLayout}, ::Type{T}, (x,y)) where T
429450
trans = transtype(M.B)
430451
trans(similar(trans(M.B), T, y))
431452
end
@@ -434,3 +455,4 @@ const ZerosLayouts = Union{ZerosLayout,DualLayout{ZerosLayout}}
434455
copy(M::MulAdd{<:ZerosLayouts, <:ZerosLayouts, <:ZerosLayouts}) = M.C
435456
copy(M::MulAdd{<:ZerosLayouts, <:Any, <:ZerosLayouts}) = M.C
436457
copy(M::MulAdd{<:Any, <:ZerosLayouts, <:ZerosLayouts}) = M.C
458+

test/test_muladd.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -731,4 +731,8 @@ Random.seed!(0)
731731
Y = randn(rng, 8, 2)
732732
@test mul(Y',X) Y'X
733733
end
734+
735+
@testset "Vec * Adj" begin
736+
@test ArrayLayouts.mul(1:5, (1:4)') == (1:5) * (1:4)'
737+
end
734738
end

0 commit comments

Comments
 (0)