Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
<p>
Implement a GPU program that multiplies a sparse matrix <code>A</code> of dimensions <code>M</code> &times; <code>N</code>
by a dense matrix <code>B</code> of dimensions <code>N</code> &times; <code>K</code>, producing a dense output matrix
<code>C</code> of dimensions <code>M</code> &times; <code>K</code>.
All matrices are stored in row-major order using 32-bit floats.
The matrix <code>A</code> is approximately 60&ndash;70% sparse (i.e., 60&ndash;70% of elements are zero),
and <code>nnz</code> gives the number of non-zero elements in <code>A</code>.
</p>

<p>
Mathematically, the operation is defined as:
\[
C_{ij} = \sum_{n=0}^{N-1} A_{in} \cdot B_{nj} \quad \text{for} \quad i = 0, \ldots, M-1,\; j = 0, \ldots, K-1
\]
</p>

<h2>Implementation Requirements</h2>
<ul>
<li>Use only GPU native features (external libraries are not permitted)</li>
<li>The <code>solve</code> function signature must remain unchanged</li>
<li>The final result must be stored in matrix <code>C</code></li>
</ul>

<h2>Example</h2>
<p>
Input:<br>
Matrix \(A\) (\(3 \times 4\)):
\[
\begin{bmatrix}
2.0 & 0.0 & 0.0 & 1.0 \\
0.0 & 3.0 & 0.0 & 0.0 \\
0.0 & 0.0 & 4.0 & 0.0
\end{bmatrix}
\]
Matrix \(B\) (\(4 \times 2\)):
\[
\begin{bmatrix}
1.0 & 2.0 \\
3.0 & 4.0 \\
5.0 & 6.0 \\
7.0 & 8.0
\end{bmatrix}
\]
Output:<br>
Matrix \(C\) (\(3 \times 2\)):
\[
\begin{bmatrix}
9.0 & 12.0 \\
9.0 & 12.0 \\
20.0 & 24.0
\end{bmatrix}
\]
</p>

<h2>Constraints</h2>
<ul>
<li>1 &le; <code>M</code>, <code>N</code>, <code>K</code> &le; 8,192</li>
<li>All values in <code>A</code> and <code>B</code> are 32-bit floats in the range [&minus;10, 10]</li>
<li>The matrix <code>A</code> is approximately 60&ndash;70% sparse</li>
<li>Performance is measured with <code>M</code> = 4,096, <code>N</code> = 2,048, <code>K</code> = 512</li>
</ul>
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
import ctypes
from typing import Any, Dict, List

import torch
from core.challenge_base import ChallengeBase


class Challenge(ChallengeBase):
    """Sparse Matrix-Dense Matrix Multiplication (SpMM) challenge.

    Validates GPU implementations of C = A @ B where A (M x N) is a
    row-major float32 matrix with roughly 60-70% zero entries, B (N x K)
    is dense, and the result C (M x K) is dense.  ``nnz`` is the number
    of non-zero elements of A.
    """

    def __init__(self):
        super().__init__(
            name="Sparse Matrix-Dense Matrix Multiplication (SpMM)",
            atol=1e-03,
            rtol=1e-03,
            num_gpus=1,
            access_tier="free",
        )

    @staticmethod
    def _as_matrix(tensor: torch.Tensor, rows: int, cols: int, label: str) -> torch.Tensor:
        """Return ``tensor`` viewed as (rows, cols), accepting flat or 2-D layout.

        Raises AssertionError when the shape matches neither layout; the
        message format is kept identical to the original inline checks.
        """
        if tensor.shape == (rows * cols,):
            return tensor.view(rows, cols)
        if tensor.shape == (rows, cols):
            return tensor
        raise AssertionError(
            f"{label}.shape {tensor.shape} does not match expected {(rows * cols,)} or {(rows, cols)}"
        )

    def reference_impl(
        self,
        A: torch.Tensor,
        B: torch.Tensor,
        C: torch.Tensor,
        M: int,
        N: int,
        K: int,
        nnz: int,
    ):
        """Compute the ground-truth product and store it in-place into C.

        A and B may arrive flat (length M*N / N*K) or already 2-D; C may be
        (M, K) or flat (M*K,).  All tensors must be float32 CUDA tensors.
        """
        A_matrix = self._as_matrix(A, M, N, "A")
        B_matrix = self._as_matrix(B, N, K, "B")
        assert C.shape == (M, K) or C.shape == (
            M * K,
        ), f"C.shape {C.shape} does not match expected {(M, K)} or {(M * K,)}"
        assert A_matrix.dtype == torch.float32
        assert B_matrix.dtype == torch.float32
        assert A_matrix.device.type == "cuda"
        assert B_matrix.device.type == "cuda"
        assert C.device.type == "cuda"
        # Dense matmul is the reference; A's sparsity is only a performance
        # hint for contestant kernels, not a correctness concern here.
        result = torch.matmul(A_matrix, B_matrix)
        C.copy_(result.view(C.shape))

    def get_solve_signature(self) -> Dict[str, tuple]:
        """ctypes signature of the contestant's ``solve`` entry point."""
        return {
            "A": (ctypes.POINTER(ctypes.c_float), "in"),
            "B": (ctypes.POINTER(ctypes.c_float), "in"),
            "C": (ctypes.POINTER(ctypes.c_float), "out"),
            "M": (ctypes.c_int, "in"),
            "N": (ctypes.c_int, "in"),
            "K": (ctypes.c_int, "in"),
            "nnz": (ctypes.c_int, "in"),
        }

    @staticmethod
    def _make_sparse_test(
        M: int, N: int, K: int, a_bound: float, b_bound: float, sparsity: float
    ) -> Dict[str, Any]:
        """Build one random SpMM test case.

        A's entries are uniform in [-a_bound, a_bound], zeroed independently
        with probability ``sparsity``; B's entries are uniform in
        [-b_bound, b_bound].  ``nnz`` is derived from the zeroing mask.
        """
        dtype = torch.float32
        A_dense = torch.empty((M, N), device="cuda", dtype=dtype).uniform_(-a_bound, a_bound)
        mask = torch.rand((M, N), device="cuda") > sparsity
        return {
            "A": A_dense * mask,
            "B": torch.empty((N, K), device="cuda", dtype=dtype).uniform_(-b_bound, b_bound),
            "C": torch.empty((M, K), device="cuda", dtype=dtype),
            "M": M,
            "N": N,
            "K": K,
            # nnz counts kept positions; a kept continuous-uniform draw is
            # zero with probability ~0, so this matches A's non-zero count.
            "nnz": int(mask.sum().item()),
        }

    def generate_example_test(self) -> Dict[str, Any]:
        """The worked example from the problem statement (3x4 @ 4x2)."""
        dtype = torch.float32
        A = torch.tensor(
            [
                [2.0, 0.0, 0.0, 1.0],
                [0.0, 3.0, 0.0, 0.0],
                [0.0, 0.0, 4.0, 0.0],
            ],
            device="cuda",
            dtype=dtype,
        )
        B = torch.tensor(
            [
                [1.0, 2.0],
                [3.0, 4.0],
                [5.0, 6.0],
                [7.0, 8.0],
            ],
            device="cuda",
            dtype=dtype,
        )
        C = torch.empty((3, 2), device="cuda", dtype=dtype)
        return {
            "A": A,
            "B": B,
            "C": C,
            "M": 3,
            "N": 4,
            "K": 2,
            "nnz": 4,
        }

    def generate_functional_test(self) -> List[Dict[str, Any]]:
        """Hand-built edge cases plus randomized sparse cases of varied sizes."""
        dtype = torch.float32
        tests = []

        # edge_1x1x1
        tests.append(
            {
                "A": torch.tensor([[3.0]], device="cuda", dtype=dtype),
                "B": torch.tensor([[2.0]], device="cuda", dtype=dtype),
                "C": torch.empty((1, 1), device="cuda", dtype=dtype),
                "M": 1,
                "N": 1,
                "K": 1,
                "nnz": 1,
            }
        )

        # edge_2x2_k1_spmv_like: K=1 degenerates SpMM into SpMV
        tests.append(
            {
                "A": torch.tensor([[1.0, 0.0], [0.0, 2.0]], device="cuda", dtype=dtype),
                "B": torch.tensor([[3.0], [4.0]], device="cuda", dtype=dtype),
                "C": torch.empty((2, 1), device="cuda", dtype=dtype),
                "M": 2,
                "N": 2,
                "K": 1,
                "nnz": 2,
            }
        )

        # edge_zero_matrix: fully sparse A, nnz == 0
        tests.append(
            {
                "A": torch.zeros((3, 3), device="cuda", dtype=dtype),
                "B": torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], device="cuda", dtype=dtype),
                "C": torch.empty((3, 2), device="cuda", dtype=dtype),
                "M": 3,
                "N": 3,
                "K": 2,
                "nnz": 0,
            }
        )

        # edge_identity_a: A = I, so C must equal B
        tests.append(
            {
                "A": torch.eye(4, device="cuda", dtype=dtype),
                "B": torch.tensor(
                    [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]],
                    device="cuda",
                    dtype=dtype,
                ),
                "C": torch.empty((4, 3), device="cuda", dtype=dtype),
                "M": 4,
                "N": 4,
                "K": 3,
                "nnz": 4,
            }
        )

        # Randomized cases: (M, N, K, |A| bound, |B| bound, sparsity).
        tests.append(self._make_sparse_test(16, 16, 8, 2.0, 1.0, 0.65))  # power_of_2_16x16x8
        tests.append(self._make_sparse_test(64, 32, 16, 3.0, 1.0, 0.70))  # power_of_2_64x32x16
        tests.append(self._make_sparse_test(30, 50, 20, 5.0, 3.0, 0.65))  # non_power_of_2_negative_values
        tests.append(self._make_sparse_test(255, 100, 33, 2.0, 1.0, 0.70))  # non_power_of_2_255x100x33
        tests.append(self._make_sparse_test(1000, 500, 64, 1.0, 1.0, 0.65))  # realistic_1000x500x64

        return tests

    def generate_performance_test(self) -> Dict[str, Any]:
        """Benchmark case from the problem statement: M=4096, N=2048, K=512."""
        return self._make_sparse_test(4096, 2048, 512, 1.0, 1.0, 0.65)
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#include <cuda_runtime.h>

// Starter stub for the SpMM challenge: to be implemented by the contestant.
// Compute C = A * B, where A is an M x N row-major float matrix (~60-70%
// zeros, with nnz non-zero entries), B is an N x K dense row-major matrix,
// and the dense M x K result is written to C.
// A, B, C are device pointers
extern "C" void solve(const float* A, const float* B, float* C, int M, int N, int K, int nnz) {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import cutlass
import cutlass.cute as cute


# A, B, C are tensors on the GPU
@cute.jit
def solve(
    A: cute.Tensor,
    B: cute.Tensor,
    C: cute.Tensor,
    M: cute.Int32,
    N: cute.Int32,
    K: cute.Int32,
    nnz: cute.Int32,
):
    """Starter stub for the SpMM challenge: to be implemented by the contestant.

    Write C = A @ B, where A (M x N) is ~60-70% sparse with nnz non-zero
    entries, B (N x K) is dense, and C (M x K) receives the dense result.
    """
    pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import jax
import jax.numpy as jnp


# A, B are tensors on GPU
@jax.jit
def solve(A: jax.Array, B: jax.Array, M: int, N: int, K: int, nnz: int) -> jax.Array:
    """Starter stub for the SpMM challenge: to be implemented by the contestant.

    Unlike the other backends, there is no output argument: compute and
    return C = A @ B, where A (M x N) is ~60-70% sparse with nnz non-zero
    entries and B is (N x K).
    """
    # return output tensor directly
    pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from gpu.host import DeviceContext
from gpu.id import block_dim, block_idx, thread_idx
from memory import UnsafePointer
from math import ceildiv

# Starter stub for the SpMM challenge: to be implemented by the contestant.
# Compute C = A * B, where A (M x N, ~60-70% zeros, nnz non-zero entries)
# and B (N x K) are row-major Float32 inputs and the dense M x K result is
# written to C.
# A, B, C are device pointers
@export
def solve(A: UnsafePointer[Float32], B: UnsafePointer[Float32], C: UnsafePointer[Float32], M: Int32, N: Int32, K: Int32, nnz: Int32):
    pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import torch


# A, B, C are tensors on the GPU
def solve(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, M: int, N: int, K: int, nnz: int):
    """Starter stub for the SpMM challenge: to be implemented by the contestant.

    Write C = A @ B in-place, where A (M x N) is ~60-70% sparse with nnz
    non-zero entries, B is (N x K), and C is the (M x K) output tensor.
    """
    pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import torch
import triton
import triton.language as tl


# A, B, C are tensors on the GPU
def solve(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, M: int, N: int, K: int, nnz: int):
    """Starter stub for the SpMM challenge: to be implemented by the contestant.

    Write C = A @ B in-place (typically by launching a Triton kernel), where
    A (M x N) is ~60-70% sparse with nnz non-zero entries, B is (N x K), and
    C is the (M x K) output tensor.
    """
    pass
Loading