Skip to content

Commit 21d9f2c

Browse files
committed
Fix unitful 3-arg *
1 parent 189d49d commit 21d9f2c

File tree

2 files changed

+28
-10
lines changed

2 files changed

+28
-10
lines changed

src/matmul.jl

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,7 @@ function generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bo
605605
aA_11 = abs2(A[1,1])
606606
fill!(UpperTriangular(C), zero(aA_11 + aA_11))
607607
end
608-
iszero(α) && return C
608+
(iszero) || isempty(A)) && return C
609609
@inbounds if !conjugate
610610
if aat
611611
for k 1:n, j 1:m
@@ -1075,7 +1075,8 @@ function __generic_matvecmul!(::typeof(identity), C::AbstractVector, A::Abstract
10751075
elseif length(B) == 0
10761076
C[i] = zero(eltype(C))
10771077
else
1078-
C[i] = zero(A[i]*B[1] + A[i]*B[1])
1078+
ci = @stable_muladdmul MulAddMul(alpha,false)(A[i]*B[1])
1079+
C[i] = zero(ci + ci)
10791080
end
10801081
end
10811082
if !iszero(alpha)
@@ -1147,12 +1148,12 @@ function _generic_matmatmul_nonadjtrans!(C, A, B, alpha, beta)
11471148
for j in axes(C, 2)
11481149
B_1j = B[b1, j]
11491150
for i in axes(C, 1)
1150-
C_ij = A[i, a1] * B_1j
1151+
C_ij = @stable_muladdmul MulAddMul(alpha, false)(A[i, a1] * B_1j)
11511152
C[i,j] = zero(C_ij + C_ij)
11521153
end
11531154
end
11541155
end
1155-
iszero(alpha) && return C
1156+
(iszero(alpha) || isempty(A) || isempty(B)) && return C
11561157
@inbounds for n in axes(B, 2), k in axes(B, 1)
11571158
# Balpha = B[k,n] * alpha, but we skip the multiplication in case isone(alpha)
11581159
Balpha = @stable_muladdmul MulAddMul(alpha, false)(B[k,n])
@@ -1167,21 +1168,21 @@ function _generic_matmatmul_adjtrans!(C, A, B, alpha, beta)
11671168
t = _wrapperop(A)
11681169
pB = parent(B)
11691170
pA = parent(A)
1170-
if (!iszero(beta) || isempty(A) || isempty(B)) # return C*beta
1171+
if (!iszero(beta) || isempty(A) || isempty(B))
11711172
_rmul_or_fill!(C, beta)
11721173
else # iszero(beta) && A and B are non-empty
11731174
a1 = firstindex(pA, 1)
11741175
b1 = firstindex(pB, 2)
11751176
for j in axes(C, 2)
11761177
tB_1j = t(pB[j, b1])
11771178
for i in axes(C, 1)
1178-
C_ij = t(pA[a1, i]) * tB_1j
1179+
C_ij = @stable_muladdmul MulAddMul(alpha, false)(t(pA[a1, i]) * tB_1j)
11791180
C[i,j] = zero(C_ij + C_ij)
11801181
end
11811182
end
11821183
end
1183-
iszero(alpha) && return C
1184-
tmp = similar(C, axes(C, 2))
1184+
(iszero(alpha) || isempty(A) || isempty(B)) && return C
1185+
tmp = similar(C, promote_op(matprod, typeof(first(A)), typeof(first(B))), axes(C, 2))
11851186
ci = firstindex(C, 1)
11861187
ta = t(alpha)
11871188
if isone(ta)
@@ -1434,7 +1435,7 @@ mat_vec_scalar(A::StridedMaybeAdjOrTransMat, x::StridedVector, γ) = _mat_vec_sc
14341435
mat_vec_scalar(A::AdjOrTransAbsVec, x::StridedVector, γ) = (A * x) * γ
14351436

14361437
function _mat_vec_scalar(A, x, γ)
1437-
T = promote_type(eltype(A), eltype(x), typeof(γ))
1438+
T = promote_op(*, promote_op(matprod, eltype(A), eltype(x)), typeof(γ))
14381439
C = similar(A, T, axes(A,1))
14391440
mul!(C, A, x, γ, false)
14401441
end
@@ -1444,7 +1445,7 @@ mat_mat_scalar(A::StridedMaybeAdjOrTransMat, B::StridedMaybeAdjOrTransMat, γ) =
14441445
_mat_mat_scalar(A, B, γ)
14451446

14461447
function _mat_mat_scalar(A, B, γ)
1447-
T = promote_type(eltype(A), eltype(B), typeof(γ))
1448+
T = promote_op(*, promote_op(matprod, eltype(A), eltype(B)), typeof(γ))
14481449
C = similar(A, T, axes(A,1), axes(B,2))
14491450
mul!(C, A, B, γ, false)
14501451
end

test/unitful.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,4 +202,21 @@ end
202202
@test (C \ b)::Vector{<:Furlong{0}} == (D \ b)::Vector{<:Furlong{0}} == Furlong{0}.([5, -12])
203203
end
204204

205+
@testset "unitful 3-arg *" begin
206+
for n in (2, 3, 5)
207+
λ = 5
208+
A = randn(-10:10, n, n)
209+
b = randn(-10:10, n)
210+
λu = Furlong{1}(λ)
211+
Au = Furlong{1}.(A)
212+
bu = Furlong{1}.(b)
213+
@test Furlong{3}.(A*A*λ) == Au*Au*λu
214+
@test Furlong{3}.(A'*A*λ) == Au'*Au*λu
215+
@test Furlong{3}.(A*A'*λ) == Au*Au'*λu
216+
@test Furlong{3}.(A'*A'*λ) == Au'*Au'*λu
217+
@test Furlong{3}.(A*b*λ) == Au*bu*λu
218+
@test Furlong{3}.(A'*b*λ) == Au'*bu*λu
219+
end
220+
end
221+
205222
end # module TestUnitfulLinAlg

0 commit comments

Comments
 (0)