Skip to content

Commit b46f175

Browse files
authored
Add tests and fix for SparseArray matrix multiplication (#173)
1 parent 201aafb commit b46f175

File tree

4 files changed

+77
-2
lines changed

4 files changed

+77
-2
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ julia = "1.6"
1414

1515
[extras]
1616
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
17+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1718

1819
[targets]
19-
test = ["OffsetArrays"]
20+
test = ["OffsetArrays", "Random"]

src/implementations/SparseArrays.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ function operate!(
213213
B::_SparseMat,
214214
α::Vararg{Union{T,Scaling},N},
215215
) where {T,N}
216+
rhs_constants = prod(α)
216217
_dim_check(ret, A, B)
217218
rowvalA = SparseArrays.rowvals(A)
218219
nzvalA = SparseArrays.nonzeros(A)
@@ -234,7 +235,7 @@ function operate!(
234235
ret.colptr[i] = ip0 = ip
235236
k0 = ip - 1
236237
for jp in SparseArrays.nzrange(B, i)
237-
nzB = nzvalB[jp]
238+
nzB = nzvalB[jp] * rhs_constants
238239
j = rowvalB[jp]
239240
for kp in SparseArrays.nzrange(A, j)
240241
k = rowvalA[kp]

test/SparseArrays.jl

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright (c) 2019 MutableArithmetics.jl contributors
2+
#
3+
# This Source Code Form is subject to the terms of the Mozilla Public License,
4+
# v.2.0. If a copy of the MPL was not distributed with this file, You can obtain
5+
# one at http://mozilla.org/MPL/2.0/.
6+
7+
module TestInterfaceSparseArrays
8+
9+
using Test
10+
11+
import MutableArithmetics
12+
import Random
13+
import SparseArrays
14+
15+
const MA = MutableArithmetics
16+
17+
function runtests()
18+
for name in names(@__MODULE__; all = true)
19+
if startswith("$(name)", "test_")
20+
@testset "$(name)" begin
21+
getfield(@__MODULE__, name)()
22+
end
23+
end
24+
end
25+
return
26+
end
27+
28+
function test_spmatmul()
29+
Random.seed!(1234)
30+
for m in [1, 2, 3, 5, 11]
31+
for n in [1, 2, 3, 5, 11]
32+
A = SparseArrays.sprand(Float64, m, n, 0.5)
33+
B = SparseArrays.sprand(Float64, n, m, 0.5)
34+
ret = SparseArrays.spzeros(Float64, m, m)
35+
MA.operate!(MA.add_mul, ret, A, B)
36+
@test ret A * B
37+
ret = SparseArrays.spzeros(Float64, m, m)
38+
MA.operate!(MA.add_mul, ret, A, A')
39+
@test ret A * A'
40+
ret = SparseArrays.spzeros(Float64, m, m)
41+
MA.operate!(MA.add_mul, ret, A, B, 2.0)
42+
@test ret A * B * 2.0
43+
ret = SparseArrays.spzeros(Float64, m, m)
44+
MA.operate!(MA.add_mul, ret, A, B, 2.0, 1.5)
45+
@test ret A * B * 2.0 * 1.5
46+
end
47+
end
48+
return
49+
end
50+
51+
function test_spmatmul_prefer_sort()
52+
Random.seed!(1234)
53+
m = n = 100
54+
p = 0.01
55+
A = SparseArrays.sprand(Float64, m, n, p)
56+
B = SparseArrays.sprand(Float64, n, m, p)
57+
ret = SparseArrays.spzeros(Float64, m, m)
58+
MA.operate!(MA.add_mul, ret, A, B)
59+
@test ret A * B
60+
ret = SparseArrays.spzeros(Float64, m, m)
61+
MA.operate!(MA.add_mul, ret, A, B, 2.0)
62+
@test ret A * B * 2.0
63+
ret = SparseArrays.spzeros(Float64, m, m)
64+
MA.operate!(MA.add_mul, ret, A, B, 2.0, 1.5)
65+
@test ret A * B * 2.0 * 1.5
66+
return
67+
end
68+
69+
end # module
70+
71+
TestInterfaceSparseArrays.runtests()

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ include("matmul.jl")
2727
include("dispatch.jl")
2828
include("rewrite.jl")
2929

30+
include("SparseArrays.jl")
31+
3032
# It is easy to introduce macro scoping issues into MutableArithmetics,
3133
# particularly ones that rely on `MA` or `MutableArithmetics` being present in
3234
# the current scope. To work around that, include the "hygiene" script in a

0 commit comments

Comments
 (0)