Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlx/backend/cpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sparse.cpp
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/luf.cpp
Expand Down
67 changes: 67 additions & 0 deletions mlx/backend/cpu/sparse.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"

namespace mlx::core {

void SparseMatmulCSR::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 4);

out.set_data(allocator::malloc(out.nbytes()));

auto& row_ptr = inputs[0];
auto& col_indices = inputs[1];
auto& values = inputs[2];
auto& dense_b = inputs[3];

auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(row_ptr);
encoder.set_input_array(col_indices);
encoder.set_input_array(values);
encoder.set_input_array(dense_b);
encoder.set_output_array(out);

const int* row_ptr_data = row_ptr.data<int>();
const int* col_indices_data = col_indices.data<int>();
const float* values_data = values.data<float>();
const float* dense_b_data = dense_b.data<float>();
float* out_data = out.data<float>();

int n_rows = n_rows_;
int n_cols = n_cols_;
int dense_b_cols = dense_b.shape(1);

encoder.dispatch([row_ptr_data,
col_indices_data,
values_data,
dense_b_data,
out_data,
n_rows,
n_cols,
dense_b_cols]() {
for (int i = 0; i < n_rows * n_cols; i++) {
out_data[i] = 0.0f;
}

for (int row = 0; row < n_rows; row++) {
int row_start = row_ptr_data[row];
int row_end = row_ptr_data[row + 1];

for (int col = 0; col < n_cols; col++) {
float sum = 0.0f;

for (int idx = row_start; idx < row_end; idx++) {
int k = col_indices_data[idx];
float a_val = values_data[idx];
float b_val = dense_b_data[k * dense_b_cols + col];
sum += a_val * b_val;
}

out_data[row * n_cols + col] = sum;
}
}
});
}

} // namespace mlx::core
1 change: 1 addition & 0 deletions mlx/backend/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
${CMAKE_CURRENT_SOURCE_DIR}/sparse.cu
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
Expand Down
103 changes: 103 additions & 0 deletions mlx/backend/cuda/sparse.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/cast_op.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include <nvtx3/nvtx3.hpp>

#include <cassert>

namespace mlx::core {

namespace cu {

template <typename T, int BLOCK_SIZE = 256>
__global__ void sparse_matmul_csr_kernel(
const int* row_ptr,
const int* col_indices,
const T* values,
const T* dense_b,
T* out,
int n_rows,
int n_cols,
int dense_b_cols) {
// Each block processes one row of the sparse matrix
int row = blockIdx.x;

if (row >= n_rows) {
return;
}

int row_start = row_ptr[row];
int row_end = row_ptr[row + 1];

// Each thread processes multiple columns of the output
for (int col = threadIdx.x; col < n_cols; col += BLOCK_SIZE) {
T sum = 0;

// Iterate through nonzero elements in this row
for (int idx = row_start; idx < row_end; idx++) {
int k = col_indices[idx];
T a_val = values[idx];
T b_val = dense_b[k * dense_b_cols + col];
sum += a_val * b_val;
}

out[row * n_cols + col] = sum;
}
}

} // namespace cu

void SparseMatmulCSR::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("SparseMatmulCSR::eval_gpu");
assert(inputs.size() == 4);

const array& row_ptr = inputs[0];
const array& col_indices = inputs[1];
const array& values = inputs[2];
const array& dense_b = inputs[3];

auto& s = stream();
auto& encoder = cu::get_command_encoder(s);

out.set_data(allocator::malloc(out.nbytes()));

encoder.set_input_array(row_ptr);
encoder.set_input_array(col_indices);
encoder.set_input_array(values);
encoder.set_input_array(dense_b);
encoder.set_output_array(out);

int dense_b_cols = dense_b.shape(1);

// Launch kernel
dispatch_float_types(values.dtype(), "sparse_matmul_csr", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;

constexpr int BLOCK_SIZE = 256;
dim3 block_dim(BLOCK_SIZE);
dim3 grid_dim(n_rows_);

auto kernel = cu::sparse_matmul_csr_kernel<DataType, BLOCK_SIZE>;

encoder.add_kernel_node(
kernel,
grid_dim,
block_dim,
0,
row_ptr.data<int>(),
col_indices.data<int>(),
values.data<DataType>(),
dense_b.data<DataType>(),
out.data<DataType>(),
n_rows_,
n_cols_,
dense_b_cols);
});
}

} // namespace mlx::core
1 change: 1 addition & 0 deletions mlx/backend/metal/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sparse.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
Expand Down
6 changes: 6 additions & 0 deletions mlx/backend/metal/jit_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,12 @@ MTL::ComputePipelineState* get_logsumexp_kernel(
return d.get_kernel(kernel_name, lib);
}

MTL::ComputePipelineState* get_sparse_kernel(
metal::Device& d,
const std::string& kernel_name) {
return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_scan_kernel(
metal::Device& d,
const std::string& kernel_name,
Expand Down
4 changes: 4 additions & 0 deletions mlx/backend/metal/kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ MTL::ComputePipelineState* get_logsumexp_kernel(
const std::string& kernel_name,
const array& out);

MTL::ComputePipelineState* get_sparse_kernel(
metal::Device& d,
const std::string& kernel_name);

MTL::ComputePipelineState* get_scan_kernel(
metal::Device& d,
const std::string& kernel_name,
Expand Down
1 change: 1 addition & 0 deletions mlx/backend/metal/kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ if(NOT MLX_METAL_JIT)
build_kernel(softmax softmax.h)
build_kernel(logsumexp logsumexp.h)
build_kernel(sort sort.h)
build_kernel(sparse sparse.h)
build_kernel(ternary ternary.h ternary_ops.h)
build_kernel(unary unary.h unary_ops.h)
build_kernel(steel/conv/kernels/steel_conv ${STEEL_HEADERS})
Expand Down
68 changes: 68 additions & 0 deletions mlx/backend/metal/kernels/sparse.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright © 2025 Apple Inc.

// Sparse matrix-matrix multiplication: y = A @ B
// where A is sparse (CSR format) and B is a dense matrix
template <typename T>
[[kernel]] void sparse_mm_csr(
const device int* row_ptr [[buffer(0)]],
const device int* col_indices [[buffer(1)]],
const device T* values [[buffer(2)]],
const device T* dense_matrix [[buffer(3)]],
device T* output [[buffer(4)]],
constant int& n_rows [[buffer(5)]],
constant int& n_cols [[buffer(6)]],
uint2 gid [[thread_position_in_grid]]) {
uint row_tid = gid.y;
uint col_vec_idx = gid.x;

// Vector size
constexpr int BM = 4;
int col_idx = col_vec_idx * BM;

if (row_tid >= uint(n_rows) || col_idx >= int(n_cols))
return;

bool full_vector = (col_idx + BM <= n_cols);

float4 sum = float4(0.0f);

int row_start = row_ptr[row_tid];
int row_end = row_ptr[row_tid + 1];

if (full_vector) {
for (int idx = row_start; idx < row_end; idx++) {
int k = col_indices[idx];
float val_a = float(values[idx]);

// Vectorized read
const device packed_vec<T, 4>* src =
(const device packed_vec<T, 4>*)(dense_matrix + k * n_cols + col_idx);
vec<T, 4> val_x_t = *src;

// Convert to float4 for math
float4 val_x = float4(val_x_t);

sum += val_a * val_x;
}

// Store
vec<T, 4> res = vec<T, 4>(sum);
*((device packed_vec<T, 4>*)(output + row_tid * n_cols + col_idx)) = res;

} else {
// Tail loop
for (int idx = row_start; idx < row_end; idx++) {
int k = col_indices[idx];
float val_a = float(values[idx]);

for (int i = 0; i < n_cols - col_idx; i++) {
float val_x = float(dense_matrix[k * n_cols + col_idx + i]);
sum[i] += val_a * val_x;
}
}

for (int i = 0; i < n_cols - col_idx; i++) {
output[row_tid * n_cols + col_idx + i] = T(sum[i]);
}
}
}
19 changes: 19 additions & 0 deletions mlx/backend/metal/kernels/sparse.metal
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Copyright © 2025 Apple Inc.

#include <metal_integer>
#include <metal_math>

// clang-format off
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/sparse.h"

// Instantiate sparse matrix operations for common types
#define instantiate_sparse_ops(tname, type) \
instantiate_kernel("sparse_mm_csr_" #tname, sparse_mm_csr, type)

// Instantiate for floating point types
instantiate_sparse_ops(float32, float)
instantiate_sparse_ops(float16, half)
instantiate_sparse_ops(bfloat16, bfloat16_t)
// clang-format on
6 changes: 6 additions & 0 deletions mlx/backend/metal/nojit_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ MTL::ComputePipelineState* get_logsumexp_kernel(
return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_sparse_kernel(
metal::Device& d,
const std::string& kernel_name) {
return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_scan_kernel(
metal::Device& d,
const std::string& kernel_name,
Expand Down
40 changes: 40 additions & 0 deletions mlx/backend/metal/sparse.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"

namespace mlx::core {

void SparseMatmulCSR::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 4);

auto& row_ptr = inputs[0];
auto& col_indices = inputs[1];
auto& values = inputs[2];
auto& dense_b = inputs[3];

auto& s = stream();
auto& d = metal::device(s.device);

out.set_data(allocator::malloc(out.nbytes()));

auto& compute_encoder = d.get_command_encoder(s.index);

std::string kernel_name = "sparse_mm_csr_" + type_to_name(values);
auto kernel = get_sparse_kernel(d, kernel_name);

MTL::Size grid_dims = MTL::Size((n_cols_ + 3) / 4, n_rows_, 1);
MTL::Size group_dims = MTL::Size(32, 1, 1);

compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(row_ptr, 0);
compute_encoder.set_input_array(col_indices, 1);
compute_encoder.set_input_array(values, 2);
compute_encoder.set_input_array(dense_b, 3);
compute_encoder.set_output_array(out, 4);
compute_encoder.set_bytes(n_rows_, 5);
compute_encoder.set_bytes(n_cols_, 6);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}

} // namespace mlx::core
Loading