Skip to content

Commit 6beb32c

Browse files
authored
Make matmul work with zero-less eltypes (#1488)
1 parent b599095 commit 6beb32c

File tree

2 files changed

+80
-16
lines changed

2 files changed

+80
-16
lines changed

src/matmul.jl

Lines changed: 55 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -599,11 +599,17 @@ function generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bo
599599
throw(DimensionMismatch(lazy"output matrix has size: $(size(C)), but should have size $((mA, mA))"))
600600
end
601601

602-
_rmul_or_fill!(C, β)
602+
if (!iszero(β) || isempty(A)) # return C*beta
603+
_rmul_or_fill!(C, β)
604+
else # iszero(β) && A is non-empty
605+
aA_11 = abs2(A[1,1])
606+
fill!(UpperTriangular(C), zero(aA_11 + aA_11))
607+
end
608+
iszero(α) && return C
603609
@inbounds if !conjugate
604610
if aat
605611
for k 1:n, j 1:m
606-
αA_jk = A[j, k] * α
612+
αA_jk = @stable_muladdmul MulAddMul(α, false)(A[j, k])
607613
for i 1:j
608614
C[i, j] += A[i, k] * αA_jk
609615
end
@@ -614,17 +620,17 @@ function generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bo
614620
for k 2:m
615621
temp += A[k, i] * A[k, j]
616622
end
617-
C[i, j] += temp * α
623+
C[i, j] += @stable_muladdmul MulAddMul(α, false)(temp)
618624
end
619625
end
620626
else
621627
if aat
622628
for k 1:n, j 1:m
623-
αA_jk_bar = conj(A[j, k]) * α
629+
αA_jk_bar = @stable_muladdmul MulAddMul(α, false)(conj(A[j, k]))
624630
for i 1:j-1
625631
C[i, j] += A[i, k] * αA_jk_bar
626632
end
627-
C[j, j] += abs2(A[j, k]) * α
633+
C[j, j] += @stable_muladdmul MulAddMul(α, false)(abs2(A[j, k]))
628634
end
629635
else
630636
for j 1:n
@@ -633,13 +639,13 @@ function generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bo
633639
for k 2:m
634640
temp += conj(A[k, i]) * A[k, j]
635641
end
636-
C[i, j] += temp * α
642+
C[i, j] += @stable_muladdmul MulAddMul(α, false)(temp)
637643
end
638644
temp = abs2(A[1, j])
639645
for k 2:m
640646
temp += abs2(A[k, j])
641647
end
642-
C[j, j] += temp * α
648+
C[j, j] += @stable_muladdmul MulAddMul(α, false)(temp)
643649
end
644650
end
645651
end
@@ -1132,8 +1138,21 @@ __generic_matmatmul!(C, A, B, alpha, beta, ::Val{true}) = _generic_matmatmul_non
11321138
__generic_matmatmul!(C, A, B, alpha, beta, ::Val{false}) = _generic_matmatmul_generic!(C, A, B, alpha, beta)
11331139

11341140
function _generic_matmatmul_nonadjtrans!(C, A, B, alpha, beta)
1135-
_rmul_or_fill!(C, beta)
1136-
(iszero(alpha) || isempty(A) || isempty(B)) && return C
1141+
# _rmul_or_fill!(C, beta) spelled out more carefully to allow for zero-less eltypes
1142+
if (!iszero(beta) || isempty(A) || isempty(B)) # return C*beta
1143+
_rmul_or_fill!(C, beta)
1144+
else # iszero(beta) && A and B are non-empty
1145+
a1 = firstindex(A, 2)
1146+
b1 = firstindex(B, 1)
1147+
for j in axes(C, 2)
1148+
B_1j = B[b1, j]
1149+
for i in axes(C, 1)
1150+
C_ij = A[i, a1] * B_1j
1151+
C[i,j] = zero(C_ij + C_ij)
1152+
end
1153+
end
1154+
end
1155+
iszero(alpha) && return C
11371156
@inbounds for n in axes(B, 2), k in axes(B, 1)
11381157
# Balpha = B[k,n] * alpha, but we skip the multiplication in case isone(alpha)
11391158
Balpha = @stable_muladdmul MulAddMul(alpha, false)(B[k,n])
@@ -1145,20 +1164,40 @@ function _generic_matmatmul_nonadjtrans!(C, A, B, alpha, beta)
11451164
C
11461165
end
11471166
function _generic_matmatmul_adjtrans!(C, A, B, alpha, beta)
1148-
_rmul_or_fill!(C, beta)
1149-
(iszero(alpha) || isempty(A) || isempty(B)) && return C
11501167
t = _wrapperop(A)
11511168
pB = parent(B)
11521169
pA = parent(A)
1170+
if (!iszero(beta) || isempty(A) || isempty(B)) # return C*beta
1171+
_rmul_or_fill!(C, beta)
1172+
else # iszero(beta) && A and B are non-empty
1173+
a1 = firstindex(pA, 1)
1174+
b1 = firstindex(pB, 2)
1175+
for j in axes(C, 2)
1176+
tB_1j = t(pB[j, b1])
1177+
for i in axes(C, 1)
1178+
C_ij = t(pA[a1, i]) * tB_1j
1179+
C[i,j] = zero(C_ij + C_ij)
1180+
end
1181+
end
1182+
end
1183+
iszero(alpha) && return C
11531184
tmp = similar(C, axes(C, 2))
11541185
ci = firstindex(C, 1)
11551186
ta = t(alpha)
1156-
for i in axes(A, 1)
1157-
mul!(tmp, pB, view(pA, :, i))
1158-
@views C[ci,:] .+= t.(ta .* tmp)
1159-
ci += 1
1187+
if isone(ta)
1188+
for i in axes(A, 1)
1189+
mul!(tmp, pB, view(pA, :, i))
1190+
@views C[ci,:] .+= t.(tmp)
1191+
ci += 1
1192+
end
1193+
else
1194+
for i in axes(A, 1)
1195+
mul!(tmp, pB, view(pA, :, i))
1196+
@views C[ci,:] .+= t.(ta .* tmp)
1197+
ci += 1
1198+
end
11601199
end
1161-
C
1200+
return C
11621201
end
11631202
function _generic_matmatmul_generic!(C, A, B, alpha, beta)
11641203
if iszero(alpha) || isempty(A) || isempty(B)

test/matmul.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1241,4 +1241,29 @@ end
12411241
@test C1 C2
12421242
end
12431243

1244+
@testset "matmul with zero-less types" begin
1245+
struct Mod <: Real
1246+
val::Int
1247+
modulo::Int
1248+
Mod(x::Int, y::Int) = new(x % y, y)
1249+
end
1250+
1251+
Base.:+(x::Mod, y::Mod) = Mod(x.val + y.val, x.modulo)
1252+
Base.:*(x::Mod, y::Mod) = Mod(x.val * y.val, x.modulo)
1253+
Base.zero(x::Mod) = Mod(0, x.modulo)
1254+
1255+
m = Mod.(rand(0:19, 5, 0), 20)
1256+
@test_throws MethodError m * copy(m')
1257+
for n in (2, 3, 5)
1258+
A = rand(0:19, n, n)
1259+
M = Mod.(A, 20)
1260+
@test M * M == Mod.(A * A, 20)
1261+
@test M' * M == Mod.(A' * A, 20)
1262+
@test M * M' == Mod.(A * A', 20)
1263+
@test M' * M' == Mod.(A' * A', 20)
1264+
@test M * M[:, 1] == Mod.(A * A[:, 1], 20)
1265+
@test M' * M[:, 1] == Mod.(A' * A[:, 1], 20)
1266+
end
1267+
end
1268+
12441269
end # module TestMatmul

0 commit comments

Comments
 (0)