Skip to content

Commit b59a50f

Browse files
claude[bot]github-actions[bot]claude
authored
Add challenge 75: Sparse Matrix-Dense Matrix Multiplication (SpMM) (Medium) (#198)
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com> Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent b978380 commit b59a50f

File tree

8 files changed

+377
-0
lines changed

8 files changed

+377
-0
lines changed
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
<p>
Implement a GPU program that multiplies a sparse matrix <code>A</code> of dimensions <code>M</code> &times; <code>N</code>
by a dense matrix <code>B</code> of dimensions <code>N</code> &times; <code>K</code>, producing a dense output matrix
<code>C</code> of dimensions <code>M</code> &times; <code>K</code>.
All matrices are stored in row-major order using 32-bit floats.
The matrix <code>A</code> is approximately 60&ndash;70% sparse (i.e., 60&ndash;70% of elements are zero),
and <code>nnz</code> gives the number of non-zero elements in <code>A</code>.
</p>

<p>
Mathematically, the operation is defined as:
\[
C_{ij} = \sum_{k=0}^{N-1} A_{ik} \cdot B_{kj} \quad \text{for} \quad i = 0, \ldots, M-1,\; j = 0, \ldots, K-1
\]
</p>

<h2>Implementation Requirements</h2>
<ul>
<li>Use only GPU native features (external libraries are not permitted)</li>
<li>The <code>solve</code> function signature must remain unchanged</li>
<li>The final result must be stored in matrix <code>C</code></li>
</ul>

<h2>Example</h2>
<p>
Input:<br>
Matrix \(A\) (\(3 \times 4\)):
\[
\begin{bmatrix}
2.0 & 0.0 & 0.0 & 1.0 \\
0.0 & 3.0 & 0.0 & 0.0 \\
0.0 & 0.0 & 4.0 & 0.0
\end{bmatrix}
\]
Matrix \(B\) (\(4 \times 2\)):
\[
\begin{bmatrix}
1.0 & 2.0 \\
3.0 & 4.0 \\
5.0 & 6.0 \\
7.0 & 8.0
\end{bmatrix}
\]
Output:<br>
Matrix \(C\) (\(3 \times 2\)):
\[
\begin{bmatrix}
9.0 & 12.0 \\
9.0 & 12.0 \\
20.0 & 24.0
\end{bmatrix}
\]
</p>

<h2>Constraints</h2>
<ul>
<li>1 &le; <code>M</code>, <code>N</code>, <code>K</code> &le; 8,192</li>
<li>All values in <code>A</code> and <code>B</code> are 32-bit floats in the range [&minus;10, 10]</li>
<li>The matrix <code>A</code> is approximately 60&ndash;70% sparse</li>
<li>Performance is measured with <code>M</code> = 4,096, <code>N</code> = 2,048, <code>K</code> = 512</li>
</ul>
Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
import ctypes
2+
from typing import Any, Dict, List
3+
4+
import torch
5+
from core.challenge_base import ChallengeBase
6+
7+
8+
class Challenge(ChallengeBase):
    """Sparse Matrix-Dense Matrix Multiplication (SpMM) challenge definition.

    Solvers compute C = A @ B where A (M x N) is approximately 60-70% sparse
    and B (N x K) is dense; all matrices are row-major float32 tensors on the
    GPU. `nnz` is the number of non-zero elements of A.
    """

    def __init__(self):
        super().__init__(
            name="Sparse Matrix-Dense Matrix Multiplication (SpMM)",
            atol=1e-03,
            rtol=1e-03,
            num_gpus=1,
            access_tier="free",
        )

    @staticmethod
    def _as_matrix(T: torch.Tensor, rows: int, cols: int, label: str) -> torch.Tensor:
        """Normalize *T* to shape (rows, cols).

        Accepts either a flat (rows*cols,) tensor or an already two-dimensional
        (rows, cols) tensor; raises AssertionError (with the same message text
        the harness has always used) for any other shape.
        """
        if T.shape == (rows * cols,):
            return T.view(rows, cols)
        if T.shape == (rows, cols):
            return T
        raise AssertionError(
            f"{label}.shape {T.shape} does not match expected {(rows * cols,)} or {(rows, cols)}"
        )

    def reference_impl(
        self,
        A: torch.Tensor,
        B: torch.Tensor,
        C: torch.Tensor,
        M: int,
        N: int,
        K: int,
        nnz: int,
    ):
        """Ground truth: writes C = A @ B in place using dense torch.matmul.

        *nnz* is part of the solver signature but is unused here because the
        reference treats A as a dense matrix.
        """
        A_matrix = self._as_matrix(A, M, N, "A")
        B_matrix = self._as_matrix(B, N, K, "B")
        assert C.shape == (M, K) or C.shape == (
            M * K,
        ), f"C.shape {C.shape} does not match expected {(M, K)} or {(M * K,)}"
        assert A_matrix.dtype == torch.float32
        assert B_matrix.dtype == torch.float32
        assert A_matrix.device.type == "cuda"
        assert B_matrix.device.type == "cuda"
        assert C.device.type == "cuda"
        result = torch.matmul(A_matrix, B_matrix)
        C.copy_(result.view(C.shape))

    def get_solve_signature(self) -> Dict[str, tuple]:
        """ctypes signature of the user-provided solve() entry point."""
        return {
            "A": (ctypes.POINTER(ctypes.c_float), "in"),
            "B": (ctypes.POINTER(ctypes.c_float), "in"),
            "C": (ctypes.POINTER(ctypes.c_float), "out"),
            "M": (ctypes.c_int, "in"),
            "N": (ctypes.c_int, "in"),
            "K": (ctypes.c_int, "in"),
            "nnz": (ctypes.c_int, "in"),
        }

    @staticmethod
    def _random_sparse_case(
        M: int,
        N: int,
        K: int,
        sparsity: float,
        a_bound: float,
        b_bound: float,
    ) -> Dict[str, Any]:
        """Build one randomized SpMM test case.

        A is drawn uniformly from [-a_bound, a_bound] and each element is
        zeroed with probability *sparsity*; B is uniform in
        [-b_bound, b_bound]. `nnz` is the exact surviving non-zero count.
        RNG calls happen in the same order as the original inline code
        (A values, then mask, then B) so generated data is unchanged.
        """
        dtype = torch.float32
        A_dense = torch.empty((M, N), device="cuda", dtype=dtype).uniform_(-a_bound, a_bound)
        mask = torch.rand((M, N), device="cuda") > sparsity
        A_sparse = A_dense * mask
        return {
            "A": A_sparse,
            "B": torch.empty((N, K), device="cuda", dtype=dtype).uniform_(-b_bound, b_bound),
            "C": torch.empty((M, K), device="cuda", dtype=dtype),
            "M": M,
            "N": N,
            "K": K,
            "nnz": int(mask.sum().item()),
        }

    def generate_example_test(self) -> Dict[str, Any]:
        """The worked example from the challenge description (3x4 times 4x2)."""
        dtype = torch.float32
        A = torch.tensor(
            [
                [2.0, 0.0, 0.0, 1.0],
                [0.0, 3.0, 0.0, 0.0],
                [0.0, 0.0, 4.0, 0.0],
            ],
            device="cuda",
            dtype=dtype,
        )
        B = torch.tensor(
            [
                [1.0, 2.0],
                [3.0, 4.0],
                [5.0, 6.0],
                [7.0, 8.0],
            ],
            device="cuda",
            dtype=dtype,
        )
        C = torch.empty((3, 2), device="cuda", dtype=dtype)
        return {
            "A": A,
            "B": B,
            "C": C,
            "M": 3,
            "N": 4,
            "K": 2,
            "nnz": 4,
        }

    def generate_functional_test(self) -> List[Dict[str, Any]]:
        """Hand-written edge cases plus randomized sparse cases of varied sizes."""
        dtype = torch.float32
        tests = []

        # edge_1x1x1
        tests.append(
            {
                "A": torch.tensor([[3.0]], device="cuda", dtype=dtype),
                "B": torch.tensor([[2.0]], device="cuda", dtype=dtype),
                "C": torch.empty((1, 1), device="cuda", dtype=dtype),
                "M": 1,
                "N": 1,
                "K": 1,
                "nnz": 1,
            }
        )

        # edge_2x2_k1_spmv_like
        tests.append(
            {
                "A": torch.tensor([[1.0, 0.0], [0.0, 2.0]], device="cuda", dtype=dtype),
                "B": torch.tensor([[3.0], [4.0]], device="cuda", dtype=dtype),
                "C": torch.empty((2, 1), device="cuda", dtype=dtype),
                "M": 2,
                "N": 2,
                "K": 1,
                "nnz": 2,
            }
        )

        # edge_zero_matrix (nnz == 0)
        tests.append(
            {
                "A": torch.zeros((3, 3), device="cuda", dtype=dtype),
                "B": torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], device="cuda", dtype=dtype),
                "C": torch.empty((3, 2), device="cuda", dtype=dtype),
                "M": 3,
                "N": 3,
                "K": 2,
                "nnz": 0,
            }
        )

        # edge_identity_a (C must equal B)
        tests.append(
            {
                "A": torch.eye(4, device="cuda", dtype=dtype),
                "B": torch.tensor(
                    [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]],
                    device="cuda",
                    dtype=dtype,
                ),
                "C": torch.empty((4, 3), device="cuda", dtype=dtype),
                "M": 4,
                "N": 4,
                "K": 3,
                "nnz": 4,
            }
        )

        # Randomized cases: (M, N, K, sparsity, A value bound, B value bound).
        # power_of_2_16x16x8
        tests.append(self._random_sparse_case(16, 16, 8, 0.65, 2.0, 1.0))
        # power_of_2_64x32x16
        tests.append(self._random_sparse_case(64, 32, 16, 0.70, 3.0, 1.0))
        # non_power_of_2_negative_values
        tests.append(self._random_sparse_case(30, 50, 20, 0.65, 5.0, 3.0))
        # non_power_of_2_255x100x33
        tests.append(self._random_sparse_case(255, 100, 33, 0.70, 2.0, 1.0))
        # realistic_1000x500x64
        tests.append(self._random_sparse_case(1000, 500, 64, 0.65, 1.0, 1.0))

        return tests

    def generate_performance_test(self) -> Dict[str, Any]:
        """Performance case from the description: M=4096, N=2048, K=512, ~65% sparse A."""
        return self._random_sparse_case(4096, 2048, 512, 0.65, 1.0, 1.0)
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#include <cuda_runtime.h>

// Starter stub for the SpMM challenge: implement C = A * B, where A is an
// M x N row-major float matrix with ~60-70% zeros (nnz non-zero entries),
// B is a dense N x K matrix, and the dense M x K result is written to C.
// A, B, C are device pointers
extern "C" void solve(const float* A, const float* B, float* C, int M, int N, int K, int nnz) {}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import cutlass
import cutlass.cute as cute


# A, B, C are tensors on the GPU
@cute.jit
def solve(
    A: cute.Tensor,
    B: cute.Tensor,
    C: cute.Tensor,
    M: cute.Int32,
    N: cute.Int32,
    K: cute.Int32,
    nnz: cute.Int32,
):
    """Starter stub for the SpMM challenge.

    Compute C = A @ B, where A (M x N) is ~60-70% sparse with nnz non-zero
    elements and B (N x K) is dense; store the result in C (M x K).
    """
    pass
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import jax
import jax.numpy as jnp


# A, B are tensors on GPU
@jax.jit
def solve(A: jax.Array, B: jax.Array, M: int, N: int, K: int, nnz: int) -> jax.Array:
    """Starter stub for the SpMM challenge.

    Compute and return C = A @ B, where A (M x N) is ~60-70% sparse with nnz
    non-zero elements and B (N x K) is dense.
    """
    # return output tensor directly
    pass
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from gpu.host import DeviceContext
from gpu.id import block_dim, block_idx, thread_idx
from memory import UnsafePointer
from math import ceildiv

# Starter stub for the SpMM challenge: compute C = A * B, where A (M x N,
# row-major) is ~60-70% sparse with nnz non-zero elements and B (N x K) is
# dense; write the dense M x K result into C.
# A, B, C are device pointers
@export
def solve(A: UnsafePointer[Float32], B: UnsafePointer[Float32], C: UnsafePointer[Float32], M: Int32, N: Int32, K: Int32, nnz: Int32):
    pass
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
import torch


# A, B, C are tensors on the GPU
def solve(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, M: int, N: int, K: int, nnz: int):
    """Starter stub for the SpMM challenge.

    Compute C = A @ B in place, where A (M x N) is ~60-70% sparse with nnz
    non-zero elements and B (N x K) is dense; C is the M x K output tensor.
    """
    pass
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import torch
2+
import triton
3+
import triton.language as tl
4+
5+
6+
# A, B, C are tensors on the GPU
7+
def solve(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, M: int, N: int, K: int, nnz: int):
8+
pass

0 commit comments

Comments
 (0)