Skip to content

Commit 1522007

Browse files
committed
Make matmul work with zero-less eltypes
1 parent b599095 commit 1522007

File tree

2 files changed

+78
-16
lines changed

2 files changed

+78
-16
lines changed

src/matmul.jl

Lines changed: 53 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -599,11 +599,21 @@ function generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bo
599599
throw(DimensionMismatch(lazy"output matrix has size: $(size(C)), but should have size $((mA, mA))"))
600600
end
601601

602-
_rmul_or_fill!(C, β)
602+
if (!iszero(β) || isempty(A)) # return C*beta
603+
_rmul_or_fill!(C, β)
604+
else # iszero(β) && A and B are non-empty
605+
a1 = firstindex(A, 1)
606+
a2 = firstindex(A, 2)
607+
for j in axes(C, 2), i in axes(C, 1)
608+
z1 = zero(A[i, a2]*A[a1, j] + A[i, a2]*A[a1, j])
609+
C[i,j] = convert(promote_type(typeof(z1), eltype(C)), z1)
610+
end
611+
end
612+
iszero(α) && return C
603613
@inbounds if !conjugate
604614
if aat
605615
for k ∈ 1:n, j ∈ 1:m
606-
αA_jk = A[j, k] * α
616+
αA_jk = @stable_muladdmul MulAddMul(α, false)(A[j, k])
607617
for i ∈ 1:j
608618
C[i, j] += A[i, k] * αA_jk
609619
end
@@ -614,17 +624,17 @@ function generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bo
614624
for k ∈ 2:m
615625
temp += A[k, i] * A[k, j]
616626
end
617-
C[i, j] += temp * α
627+
C[i, j] += @stable_muladdmul MulAddMul(α, false)(temp)
618628
end
619629
end
620630
else
621631
if aat
622632
for k ∈ 1:n, j ∈ 1:m
623-
αA_jk_bar = conj(A[j, k]) * α
633+
αA_jk_bar = @stable_muladdmul MulAddMul(α, false)(conj(A[j, k]))
624634
for i ∈ 1:j-1
625635
C[i, j] += A[i, k] * αA_jk_bar
626636
end
627-
C[j, j] += abs2(A[j, k]) * α
637+
C[j, j] += @stable_muladdmul MulAddMul(α, false)(abs2(A[j, k]))
628638
end
629639
else
630640
for j ∈ 1:n
@@ -633,13 +643,13 @@ function generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bo
633643
for k ∈ 2:m
634644
temp += conj(A[k, i]) * A[k, j]
635645
end
636-
C[i, j] += temp * α
646+
C[i, j] += @stable_muladdmul MulAddMul(α, false)(temp)
637647
end
638648
temp = abs2(A[1, j])
639649
for k ∈ 2:m
640650
temp += abs2(A[k, j])
641651
end
642-
C[j, j] += temp * α
652+
C[j, j] += @stable_muladdmul MulAddMul(α, false)(temp)
643653
end
644654
end
645655
end
@@ -1132,8 +1142,18 @@ __generic_matmatmul!(C, A, B, alpha, beta, ::Val{true}) = _generic_matmatmul_non
11321142
__generic_matmatmul!(C, A, B, alpha, beta, ::Val{false}) = _generic_matmatmul_generic!(C, A, B, alpha, beta)
11331143

11341144
function _generic_matmatmul_nonadjtrans!(C, A, B, alpha, beta)
1135-
_rmul_or_fill!(C, beta)
1136-
(iszero(alpha) || isempty(A) || isempty(B)) && return C
1145+
# _rmul_or_fill!(C, beta) spelled out more carefully to allow for zero-less eltypes
1146+
if (!iszero(beta) || isempty(A) || isempty(B)) # return C*beta
1147+
_rmul_or_fill!(C, beta)
1148+
else # iszero(beta) && A and B are non-empty
1149+
a1 = firstindex(A, 2)
1150+
b1 = firstindex(B, 1)
1151+
for j in axes(C, 2), i in axes(C, 1)
1152+
z1 = zero(A[i, a1]*B[b1, j] + A[i, a1]*B[b1, j])
1153+
C[i,j] = convert(promote_type(typeof(z1), eltype(C)), z1)
1154+
end
1155+
end
1156+
iszero(alpha) && return C
11371157
@inbounds for n in axes(B, 2), k in axes(B, 1)
11381158
# Balpha = B[k,n] * alpha, but we skip the multiplication in case isone(alpha)
11391159
Balpha = @stable_muladdmul MulAddMul(alpha, false)(B[k,n])
@@ -1145,20 +1165,37 @@ function _generic_matmatmul_nonadjtrans!(C, A, B, alpha, beta)
11451165
C
11461166
end
11471167
function _generic_matmatmul_adjtrans!(C, A, B, alpha, beta)
1148-
_rmul_or_fill!(C, beta)
1149-
(iszero(alpha) || isempty(A) || isempty(B)) && return C
1168+
if (!iszero(beta) || isempty(A) || isempty(B)) # return C*beta
1169+
_rmul_or_fill!(C, beta)
1170+
else # iszero(beta) && A and B are non-empty
1171+
a1 = firstindex(A, 2)
1172+
b1 = firstindex(B, 1)
1173+
for j in axes(C, 2), i in axes(C, 1)
1174+
z1 = zero(A[i, a1]*B[b1, j] + A[i, a1]*B[b1, j])
1175+
C[i,j] = convert(promote_type(typeof(z1), eltype(C)), z1)
1176+
end
1177+
end
1178+
iszero(alpha) && return C
11501179
t = _wrapperop(A)
11511180
pB = parent(B)
11521181
pA = parent(A)
11531182
tmp = similar(C, axes(C, 2))
11541183
ci = firstindex(C, 1)
11551184
ta = t(alpha)
1156-
for i in axes(A, 1)
1157-
mul!(tmp, pB, view(pA, :, i))
1158-
@views C[ci,:] .+= t.(ta .* tmp)
1159-
ci += 1
1185+
if isone(ta)
1186+
for i in axes(A, 1)
1187+
mul!(tmp, pB, view(pA, :, i))
1188+
@views C[ci,:] .+= t.(tmp)
1189+
ci += 1
1190+
end
1191+
else
1192+
for i in axes(A, 1)
1193+
mul!(tmp, pB, view(pA, :, i))
1194+
@views C[ci,:] .+= t.(ta .* tmp)
1195+
ci += 1
1196+
end
11601197
end
1161-
C
1198+
return C
11621199
end
11631200
function _generic_matmatmul_generic!(C, A, B, alpha, beta)
11641201
if iszero(alpha) || isempty(A) || isempty(B)

test/matmul.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1241,4 +1241,29 @@ end
12411241
@test C1 ≈ C2
12421242
end
12431243

1244+
@testset "matmul with zero-less types" begin
1245+
struct Mod <: Real
1246+
val::Int
1247+
modulo::Int
1248+
Mod(x::Int, y::Int) = new(x % y, y)
1249+
end
1250+
1251+
Base.:+(x::Mod, y::Mod) = Mod(x.val + y.val, x.modulo)
1252+
Base.:*(x::Mod, y::Mod) = Mod(x.val * y.val, x.modulo)
1253+
Base.zero(x::Mod) = Mod(0, x.modulo)
1254+
1255+
m = Mod.(rand(0:19, 5, 0), 20)
1256+
@test_throws MethodError m * copy(m')
1257+
for n in (2, 3, 5)
1258+
A = rand(0:19, n, n)
1259+
M = Mod.(A, 20)
1260+
@test M * M == Mod.(A * A, 20)
1261+
@test M' * M == Mod.(A' * A, 20)
1262+
@test M * M' == Mod.(A * A', 20)
1263+
@test M' * M' == Mod.(A' * A', 20)
1264+
@test M * M[:, 1] == Mod.(A * A[:, 1], 20)
1265+
@test M' * M[:, 1] == Mod.(A' * A[:, 1], 20)
1266+
end
1267+
end
1268+
12441269
end # module TestMatmul

0 commit comments

Comments
 (0)