diff --git a/src/matmul.jl b/src/matmul.jl index 2b2f7d81..d618bcfe 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -1072,11 +1072,13 @@ function __generic_matvecmul!(::typeof(identity), C::AbstractVector, A::Abstract C[i] = zero(A[i]*B[1] + A[i]*B[1]) end end - for k = eachindex(B) - aoffs = (k-1)*Astride - b = @stable_muladdmul MulAddMul(alpha,false)(B[k]) - for i = eachindex(C) - C[i] += A[aoffs + i] * b + if !iszero(alpha) + for k = eachindex(B) + aoffs = (k-1)*Astride + b = @stable_muladdmul MulAddMul(alpha,false)(B[k]) + for i = eachindex(C) + C[i] += A[aoffs + i] * b + end end end end diff --git a/test/matmul.jl b/test/matmul.jl index c2ce312f..1fcf2009 100644 --- a/test/matmul.jl +++ b/test/matmul.jl @@ -980,11 +980,23 @@ Base.:*(x::Float64, a::A32092) = x * a.x end @testset "strong zero" begin - @testset for α in Any[false, 0.0, 0], n in 1:4 - C = ones(n, n) - A = fill!(zeros(n, n), NaN) - B = ones(n, n) + @testset for α in Any[false, 0.0, 0], n in 1:4, T in (Float16, Float64) + C = ones(T, n) + A = fill(T(NaN), n, n) + B = ones(T, n) @test mul!(copy(C), A, B, α, 1.0) == C + C = ones(T, n, n) + B = ones(T, n, n) + @test mul!(copy(C), A, B, α, 1.0) == C + end + @testset for α in Any[false, 0.0, 0], β in Any[false, 0.0, 0], n in 1:4, T in (Float16, Float64) + C = fill(T(NaN), n) + A = fill(T(NaN), n, n) + B = fill(T(NaN), n) + @test iszero(mul!(copy(C), A, B, α, β)) + C = fill(T(NaN), n, n) + B = fill(T(NaN), n, n) + @test iszero(mul!(copy(C), A, B, α, β)) end end