From 7bc03bde52c19342f3e65ae0ef846249424f0168 Mon Sep 17 00:00:00 2001 From: mercush Date: Sun, 16 Nov 2025 22:07:51 -0500 Subject: [PATCH 1/5] sparse matmul kernel --- mlx/backend/cpu/CMakeLists.txt | 1 + mlx/backend/cpu/sparse.cpp | 62 +++++++++++ mlx/backend/metal/CMakeLists.txt | 1 + mlx/backend/metal/jit_kernels.cpp | 6 ++ mlx/backend/metal/kernels.h | 4 + mlx/backend/metal/kernels/CMakeLists.txt | 1 + mlx/backend/metal/kernels/sparse.h | 61 +++++++++++ mlx/backend/metal/kernels/sparse.metal | 20 ++++ mlx/backend/metal/nojit_kernels.cpp | 6 ++ mlx/backend/metal/sparse.cpp | 43 ++++++++ mlx/ops.cpp | 31 ++++++ mlx/ops.h | 19 ++++ mlx/primitives.h | 22 ++++ tests/CMakeLists.txt | 1 + tests/sparse_tests.cpp | 131 +++++++++++++++++++++++ 15 files changed, 409 insertions(+) create mode 100644 mlx/backend/cpu/sparse.cpp create mode 100644 mlx/backend/metal/kernels/sparse.h create mode 100644 mlx/backend/metal/kernels/sparse.metal create mode 100644 mlx/backend/metal/sparse.cpp create mode 100644 tests/sparse_tests.cpp diff --git a/mlx/backend/cpu/CMakeLists.txt b/mlx/backend/cpu/CMakeLists.txt index 9d322c4c49..6f7bdba5f7 100644 --- a/mlx/backend/cpu/CMakeLists.txt +++ b/mlx/backend/cpu/CMakeLists.txt @@ -62,6 +62,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/sparse.cpp ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/luf.cpp diff --git a/mlx/backend/cpu/sparse.cpp b/mlx/backend/cpu/sparse.cpp new file mode 100644 index 0000000000..ba5db85e0d --- /dev/null +++ b/mlx/backend/cpu/sparse.cpp @@ -0,0 +1,62 @@ +// Copyright © 2025 Apple Inc. + +#include + +#include "mlx/backend/cpu/encoder.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +void SparseMatmulCSR::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 4); + + out.set_data(allocator::malloc(out.nbytes())); + + auto& row_ptr = inputs[0]; + auto& col_indices = inputs[1]; + auto& values = inputs[2]; + auto& dense_b = inputs[3]; + + auto& encoder = cpu::get_command_encoder(stream()); + encoder.set_input_array(row_ptr); + encoder.set_input_array(col_indices); + encoder.set_input_array(values); + encoder.set_input_array(dense_b); + encoder.set_output_array(out); + + const int* row_ptr_data = row_ptr.data(); + const int* col_indices_data = col_indices.data(); + const float* values_data = values.data(); + const float* dense_b_data = dense_b.data(); + float* out_data = out.data(); + + int n_rows = n_rows_; + int n_cols = n_cols_; + int dense_b_cols = dense_b.shape(1); + + encoder.dispatch([row_ptr_data, col_indices_data, values_data, dense_b_data, out_data, n_rows, n_cols, dense_b_cols]() { + for (int i = 0; i < n_rows * n_cols; i++) { + out_data[i] = 0.0f; + } + + for (int row = 0; row < n_rows; row++) { + int row_start = row_ptr_data[row]; + int row_end = row_ptr_data[row + 1]; + + for (int col = 0; col < n_cols; col++) { + float sum = 0.0f; + + for (int idx = row_start; idx < row_end; idx++) { + int k = col_indices_data[idx]; + float a_val = values_data[idx]; + float b_val = dense_b_data[k * dense_b_cols + col]; + sum += a_val * b_val; + } + + out_data[row * n_cols + col] = sum; + } + } + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index 0fd1834f63..48af65449d 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -116,6 +116,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/sparse.cpp ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp ${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index de391abc92..a4d7accdf7 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -326,6 +326,12 @@ MTL::ComputePipelineState* get_logsumexp_kernel( return d.get_kernel(kernel_name, lib); } +MTL::ComputePipelineState* get_sparse_kernel( + metal::Device& d, + const std::string& kernel_name) { + return d.get_kernel(kernel_name); +} + MTL::ComputePipelineState* get_scan_kernel( metal::Device& d, const std::string& kernel_name, diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h index 0656f3c0fa..11c9c767dd 100644 --- a/mlx/backend/metal/kernels.h +++ b/mlx/backend/metal/kernels.h @@ -64,6 +64,10 @@ MTL::ComputePipelineState* get_logsumexp_kernel( const std::string& kernel_name, const array& out); +MTL::ComputePipelineState* get_sparse_kernel( + metal::Device& d, + const std::string& kernel_name); + MTL::ComputePipelineState* get_scan_kernel( metal::Device& d, const std::string& kernel_name, diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 69ac2a5e99..f8ecf60321 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -116,6 +116,7 @@ if(NOT MLX_METAL_JIT) build_kernel(softmax softmax.h) build_kernel(logsumexp logsumexp.h) build_kernel(sort sort.h) + build_kernel(sparse sparse.h) build_kernel(ternary ternary.h ternary_ops.h) build_kernel(unary unary.h unary_ops.h) build_kernel(steel/conv/kernels/steel_conv ${STEEL_HEADERS}) diff --git a/mlx/backend/metal/kernels/sparse.h b/mlx/backend/metal/kernels/sparse.h new file mode 100644 index 0000000000..ca1536d5bf --- /dev/null +++ b/mlx/backend/metal/kernels/sparse.h @@ -0,0 +1,61 @@ +// Copyright © 2025 Apple Inc. + +// Sparse matrix-matrix multiplication: y = A @ B +// where A is sparse (CSR format) and B is a dense matrix +template +[[kernel]] void sparse_mm_csr( + const device int* row_ptr [[buffer(0)]], + const device int* col_indices [[buffer(1)]], + const device T* values [[buffer(2)]], + const device T* dense_matrix [[buffer(3)]], + device T* output [[buffer(4)]], + constant int& n_rows [[buffer(5)]], + constant int& n_cols [[buffer(6)]], + uint2 gid [[thread_position_in_grid]]) { + + int row = gid.y; + int col = gid.x; + + if (row >= n_rows || col >= n_cols) { + return; + } + + int row_start = row_ptr[row]; + int row_end = row_ptr[row + 1]; + + T sum = T(0); + for (int idx = row_start; idx < row_end; idx++) { + int k = col_indices[idx]; + sum += values[idx] * dense_matrix[k * n_cols + col]; + } + + output[row * n_cols + col] = sum; +} + +// Sparse matrix-vector multiplication: y = A @ x +// where A is sparse (CSR format) and x is a dense vector +template +[[kernel]] void sparse_mv_csr( + const device int* row_ptr [[buffer(0)]], + const device int* col_indices [[buffer(1)]], + const device T* values [[buffer(2)]], + const device T* vector [[buffer(3)]], + device T* output [[buffer(4)]], + constant int& n_rows [[buffer(5)]], + uint gid [[thread_position_in_grid]]) { + + int row = gid; + if (row >= n_rows) { + return; + } + + int row_start = row_ptr[row]; + int row_end = row_ptr[row + 1]; + + T sum = T(0); + for (int idx = row_start; idx < row_end; idx++) { + sum += values[idx] * vector[col_indices[idx]]; + } + + output[row] = sum; +} diff --git a/mlx/backend/metal/kernels/sparse.metal b/mlx/backend/metal/kernels/sparse.metal new file mode 100644 index 0000000000..55cd71bcdc --- /dev/null +++ b/mlx/backend/metal/kernels/sparse.metal @@ -0,0 +1,20 @@ +// Copyright © 2025 Apple Inc. + +#include +#include + +// clang-format off +#include "mlx/backend/metal/kernels/defines.h" +#include "mlx/backend/metal/kernels/utils.h" +#include "mlx/backend/metal/kernels/sparse.h" + +// Instantiate sparse matrix operations for common types +#define instantiate_sparse_ops(tname, type) \ + instantiate_kernel("sparse_mm_csr_" #tname, sparse_mm_csr, type) \ + instantiate_kernel("sparse_mv_csr_" #tname, sparse_mv_csr, type) \ + +// Instantiate for floating point types +instantiate_sparse_ops(float32, float) +instantiate_sparse_ops(float16, half) +instantiate_sparse_ops(bfloat16, bfloat16_t) +// clang-format on diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp index 109dd8df78..3066774926 100644 --- a/mlx/backend/metal/nojit_kernels.cpp +++ b/mlx/backend/metal/nojit_kernels.cpp @@ -79,6 +79,12 @@ MTL::ComputePipelineState* get_logsumexp_kernel( return d.get_kernel(kernel_name); } +MTL::ComputePipelineState* get_sparse_kernel( + metal::Device& d, + const std::string& kernel_name) { + return d.get_kernel(kernel_name); +} + MTL::ComputePipelineState* get_scan_kernel( metal::Device& d, const std::string& kernel_name, diff --git a/mlx/backend/metal/sparse.cpp b/mlx/backend/metal/sparse.cpp new file mode 100644 index 0000000000..c37b15e659 --- /dev/null +++ b/mlx/backend/metal/sparse.cpp @@ -0,0 +1,43 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/kernels.h" +#include "mlx/backend/metal/utils.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +void SparseMatmulCSR::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 4); + + auto& row_ptr = inputs[0]; + auto& col_indices = inputs[1]; + auto& values = inputs[2]; + auto& dense_b = inputs[3]; + + auto& s = stream(); + auto& d = metal::device(s.device); + + out.set_data(allocator::malloc(out.nbytes())); + + auto& compute_encoder = d.get_command_encoder(s.index); + + std::string kernel_name = "sparse_mm_csr_" + type_to_name(values); + auto kernel = get_sparse_kernel(d, kernel_name); + + MTL::Size grid_dims = MTL::Size(n_cols_, n_rows_, 1); + MTL::Size group_dims = + MTL::Size(std::min(32, n_cols_), std::min(32, n_rows_), 1); + + compute_encoder.set_compute_pipeline_state(kernel); + compute_encoder.set_input_array(row_ptr, 0); + compute_encoder.set_input_array(col_indices, 1); + compute_encoder.set_input_array(values, 2); + compute_encoder.set_input_array(dense_b, 3); + compute_encoder.set_output_array(out, 4); + compute_encoder.set_bytes(n_rows_, 5); + compute_encoder.set_bytes(n_cols_, 6); + compute_encoder.dispatch_threads(grid_dims, group_dims); +} + +} // namespace mlx::core diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 271462d56e..23727ad9dd 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3002,6 +3002,37 @@ array matmul( return axes.empty() ? out : squeeze(out, axes, s); } +array sparse_matmul_csr( + const array& row_ptr, + const array& col_indices, + const array& values, + const array& dense_b, + int n_rows, + int n_cols, + StreamOrDevice s /* = {} */) { + if (row_ptr.dtype() != int32) { + throw std::invalid_argument("[sparse_matmul_csr] row_ptr must be int32"); + } + if (col_indices.dtype() != int32) { + throw std::invalid_argument( + "[sparse_matmul_csr] col_indices must be int32"); + } + if (!issubdtype(values.dtype(), floating)) { + throw std::invalid_argument( + "[sparse_matmul_csr] values must be floating point"); + } + if (values.dtype() != dense_b.dtype()) { + throw std::invalid_argument( + "[sparse_matmul_csr] values and dense_b must have the same dtype"); + } + + return array( + {n_rows, n_cols}, + values.dtype(), + std::make_shared(to_stream(s), n_rows, n_cols), + {row_ptr, col_indices, values, dense_b}); +} + array gather( const array& a, const std::vector& indices, diff --git a/mlx/ops.h b/mlx/ops.h index 49c64e74f0..fbf9438f5d 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -957,6 +957,25 @@ inline array round(const array& a, StreamOrDevice s = {}) { /** Matrix-matrix multiplication. */ array matmul(const array& a, const array& b, StreamOrDevice s = {}); +/** Sparse matrix-dense matrix multiplication using CSR format. */ +array sparse_matmul_csr( + const array& row_ptr, + const array& col_indices, + const array& values, + const array& dense_b, + int n_rows, + int n_cols, + StreamOrDevice s = {}); + +/** Sparse matrix-vector multiplication using CSR format. */ +array sparse_matvec_csr( + const array& row_ptr, + const array& col_indices, + const array& values, + const array& vec, + int n_rows, + StreamOrDevice s = {}); + /** Gather array entries given indices and slices */ array gather( const array& a, diff --git a/mlx/primitives.h b/mlx/primitives.h index a1ad2425c0..396fe8566f 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -2481,4 +2481,26 @@ class LUF : public Primitive { DEFINE_NAME(LUF) }; +/* Sparse matrix operations using CSR format. */ +class SparseMatmulCSR : public UnaryPrimitive { + public: + explicit SparseMatmulCSR(Stream stream, int n_rows, int n_cols) + : UnaryPrimitive(stream), n_rows_(n_rows), n_cols_(n_cols) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_NAME(SparseMatmulCSR) + std::vector output_shapes(const std::vector& inputs) override { + return {{n_rows_, n_cols_}}; + } + auto state() const { + return std::make_pair(n_rows_, n_cols_); + } + + private: + int n_rows_; + int n_cols_; +}; + } // namespace mlx::core diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 9e977d6281..511c5169b7 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -31,6 +31,7 @@ target_sources( ops_tests.cpp random_tests.cpp scheduler_tests.cpp + sparse_tests.cpp utils_tests.cpp vmap_tests.cpp linalg_tests.cpp diff --git a/tests/sparse_tests.cpp b/tests/sparse_tests.cpp new file mode 100644 index 0000000000..745c9f447e --- /dev/null +++ b/tests/sparse_tests.cpp @@ -0,0 +1,131 @@ +// Copyright © 2025 Apple Inc. + +#include "doctest/doctest.h" + +#include "mlx/mlx.h" + +using namespace mlx::core; + +TEST_CASE("test sparse matrix-dense matrix multiplication") { + // Create a simple sparse matrix in CSR format: + // [[1, 0, 2], + // [0, 3, 0], + // [4, 0, 5]] + // CSR: row_ptr = [0, 2, 3, 5] + // col_indices = [0, 2, 1, 0, 2] + // values = [1, 2, 3, 4, 5] + + auto row_ptr = array({0, 2, 3, 5}, int32); + auto col_indices = array({0, 2, 1, 0, 2}, int32); + auto values = array({1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, float32); + + auto dense_a = array({1.0f, 0.0f, 2.0f, + 0.0f, 3.0f, 0.0f, + 4.0f, 0.0f, 5.0f}, {3, 3}); + + CHECK_EQ(row_ptr.size(), 4); + CHECK_EQ(col_indices.size(), 5); + CHECK_EQ(values.size(), 5); + + eval(row_ptr); + auto row_ptr_data = row_ptr.data(); + CHECK_EQ(row_ptr_data[0], 0); + CHECK_EQ(row_ptr_data[1], 2); + CHECK_EQ(row_ptr_data[2], 3); + CHECK_EQ(row_ptr_data[3], 5); + + auto dense_b = array({1.0f, 2.0f, + 3.0f, 4.0f, + 5.0f, 6.0f}, {3, 2}); + + // Expected result from dense @ dense: + // [[1, 0, 2], [[1, 2], [[11, 14], + // [0, 3, 0], @ [3, 4], = [9, 12], + // [4, 0, 5]] [5, 6]] [29, 38]] + auto expected = matmul(dense_a, dense_b); + + // Test on default device + auto result = sparse_matmul_csr(row_ptr, col_indices, values, dense_b, 3, 2); + CHECK(allclose(result, expected, 1e-5).item()); + + // Test explicitly on CPU + auto result_cpu = sparse_matmul_csr(row_ptr, col_indices, values, dense_b, 3, 2, Device::cpu); + eval(result_cpu); + CHECK(allclose(result_cpu, expected, 1e-5).item()); + + // Verify CPU result matches expected values + auto result_cpu_data = result_cpu.data(); + CHECK_EQ(result_cpu_data[0], 11.0f); // [0,0] + CHECK_EQ(result_cpu_data[1], 14.0f); // [0,1] + CHECK_EQ(result_cpu_data[2], 9.0f); // [1,0] + CHECK_EQ(result_cpu_data[3], 12.0f); // [1,1] + CHECK_EQ(result_cpu_data[4], 29.0f); // [2,0] + CHECK_EQ(result_cpu_data[5], 38.0f); // [2,1] +} + +TEST_CASE("test sparse matrix-vector multiplication") { + // Sparse matrix in CSR format (diagonal): + // [[2, 0, 0], + // [0, 3, 0], + // [0, 0, 4]] + auto row_ptr = array({0, 1, 2, 3}, int32); + auto col_indices = array({0, 1, 2}, int32); + auto values = array({2.0f, 3.0f, 4.0f}, float32); + + auto dense_a = array({2.0f, 0.0f, 0.0f, + 0.0f, 3.0f, 0.0f, + 0.0f, 0.0f, 4.0f}, {3, 3}); + + auto dense_b = array({1.0f, 2.0f, 3.0f}, {3, 1}); + + auto expected = matmul(dense_a, dense_b); + + // Test on default device + auto result = sparse_matmul_csr(row_ptr, col_indices, values, dense_b, 3, 1); + CHECK(allclose(result, expected, 1e-5).item()); + + // Test explicitly on CPU + auto result_cpu = sparse_matmul_csr(row_ptr, col_indices, values, dense_b, 3, 1, Device::cpu); + eval(result_cpu); + CHECK(allclose(result_cpu, expected, 1e-5).item()); + + // Verify CPU result values (diagonal matrix times vector) + auto result_cpu_data = result_cpu.data(); + CHECK_EQ(result_cpu_data[0], 2.0f); // 2 * 1 = 2 + CHECK_EQ(result_cpu_data[1], 6.0f); // 3 * 2 = 6 + CHECK_EQ(result_cpu_data[2], 12.0f); // 4 * 3 = 12 +} + +TEST_CASE("test random sparse matrix") { + int n_rows = 10; + int n_cols = 10; + int dense_cols = 5; + + std::vector row_ptr_vec = {0}; + std::vector col_indices_vec; + std::vector values_vec; + + for (int i = 0; i < n_rows; i++) { + int nnz_this_row = 3 + (i % 3); + for (int j = 0; j < nnz_this_row; j++) { + col_indices_vec.push_back((i * 3 + j * 2) % n_cols); + values_vec.push_back(static_cast(i + j + 1)); + } + row_ptr_vec.push_back(col_indices_vec.size()); + } + + auto row_ptr = array(row_ptr_vec.data(), {static_cast(row_ptr_vec.size())}, int32); + auto col_indices = array(col_indices_vec.data(), {static_cast(col_indices_vec.size())}, int32); + auto values = array(values_vec.data(), {static_cast(values_vec.size())}, float32); + + CHECK_EQ(row_ptr.size(), n_rows + 1); + CHECK(col_indices.size() > 0); + CHECK_EQ(col_indices.size(), values.size()); + + auto dense_b = ones({n_cols, dense_cols}); + + auto result = sparse_matmul_csr(row_ptr, col_indices, values, dense_b, n_rows, dense_cols); + CHECK_EQ(result.shape(0), n_rows); + CHECK_EQ(result.shape(1), dense_cols); +} + From 6b4f7456b16f58484d63cae98c3c50c947e94196 Mon Sep 17 00:00:00 2001 From: mercush Date: Wed, 19 Nov 2025 09:11:59 -0500 Subject: [PATCH 2/5] faster kernel using vectorized read --- mlx/backend/metal/kernels/sparse.h | 62 +++++++++++++++++++++++------- mlx/backend/metal/sparse.cpp | 5 +-- 2 files changed, 50 insertions(+), 17 deletions(-) diff --git a/mlx/backend/metal/kernels/sparse.h b/mlx/backend/metal/kernels/sparse.h index ca1536d5bf..ac501e2bfc 100644 --- a/mlx/backend/metal/kernels/sparse.h +++ b/mlx/backend/metal/kernels/sparse.h @@ -13,23 +13,57 @@ template constant int& n_cols [[buffer(6)]], uint2 gid [[thread_position_in_grid]]) { - int row = gid.y; - int col = gid.x; + uint row_tid = gid.y; + uint col_vec_idx = gid.x; + + // Vector size + constexpr int BM = 4; + int col_idx = col_vec_idx * BM; - if (row >= n_rows || col >= n_cols) { - return; - } - - int row_start = row_ptr[row]; - int row_end = row_ptr[row + 1]; + if (row_tid >= uint(n_rows) || col_idx >= int(n_cols)) return; - T sum = T(0); - for (int idx = row_start; idx < row_end; idx++) { - int k = col_indices[idx]; - sum += values[idx] * dense_matrix[k * n_cols + col]; + bool full_vector = (col_idx + BM <= n_cols); + + float4 sum = float4(0.0f); + + int row_start = row_ptr[row_tid]; + int row_end = row_ptr[row_tid + 1]; + + if (full_vector) { + for (int idx = row_start; idx < row_end; idx++) { + int k = col_indices[idx]; + float val_a = float(values[idx]); + + // Vectorized read + const device packed_vec* src = (const device packed_vec*)(dense_matrix + k * n_cols + col_idx); + vec val_x_t = *src; + + // Convert to float4 for math + float4 val_x = float4(val_x_t); + + sum += val_a * val_x; + } + + // Store + vec res = vec(sum); + *((device packed_vec*)(output + row_tid * n_cols + col_idx)) = res; + + } else { + // Tail loop + for (int idx = row_start; idx < row_end; idx++) { + int k = col_indices[idx]; + float val_a = float(values[idx]); + + for (int i = 0; i < n_cols - col_idx; i++) { + float val_x = float(dense_matrix[k * n_cols + col_idx + i]); + sum[i] += val_a * val_x; + } + } + + for (int i = 0; i < n_cols - col_idx; i++) { + output[row_tid * n_cols + col_idx + i] = T(sum[i]); + } } - - output[row * n_cols + col] = sum; } // Sparse matrix-vector multiplication: y = A @ x diff --git a/mlx/backend/metal/sparse.cpp b/mlx/backend/metal/sparse.cpp index c37b15e659..5d4da0915c 100644 --- a/mlx/backend/metal/sparse.cpp +++ b/mlx/backend/metal/sparse.cpp @@ -25,9 +25,8 @@ void SparseMatmulCSR::eval_gpu(const std::vector& inputs, array& out) { std::string kernel_name = "sparse_mm_csr_" + type_to_name(values); auto kernel = get_sparse_kernel(d, kernel_name); - MTL::Size grid_dims = MTL::Size(n_cols_, n_rows_, 1); - MTL::Size group_dims = - MTL::Size(std::min(32, n_cols_), std::min(32, n_rows_), 1); + MTL::Size grid_dims = MTL::Size((n_cols_ + 3) / 4, n_rows_, 1); + MTL::Size group_dims = MTL::Size(32, 1, 1); compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(row_ptr, 0); From 95b6e3910dd0566f61ca47a3a05930e5f86e8085 Mon Sep 17 00:00:00 2001 From: mercush Date: Wed, 19 Nov 2025 10:56:00 -0500 Subject: [PATCH 3/5] formatting --- mlx/backend/cpu/sparse.cpp | 9 +++- mlx/backend/metal/kernels/sparse.h | 74 +++++++++++++------------- mlx/backend/metal/kernels/sparse.metal | 2 +- tests/sparse_tests.cpp | 53 +++++++++--------- 4 files changed, 74 insertions(+), 64 deletions(-) diff --git a/mlx/backend/cpu/sparse.cpp b/mlx/backend/cpu/sparse.cpp index ba5db85e0d..cfca6bb625 100644 --- a/mlx/backend/cpu/sparse.cpp +++ b/mlx/backend/cpu/sparse.cpp @@ -34,7 +34,14 @@ void SparseMatmulCSR::eval_cpu(const std::vector& inputs, array& out) { int n_cols = n_cols_; int dense_b_cols = dense_b.shape(1); - encoder.dispatch([row_ptr_data, col_indices_data, values_data, dense_b_data, out_data, n_rows, n_cols, dense_b_cols]() { + encoder.dispatch([row_ptr_data, + col_indices_data, + values_data, + dense_b_data, + out_data, + n_rows, + n_cols, + dense_b_cols]() { for (int i = 0; i < n_rows * n_cols; i++) { out_data[i] = 0.0f; } diff --git a/mlx/backend/metal/kernels/sparse.h b/mlx/backend/metal/kernels/sparse.h index ac501e2bfc..f821071f0f 100644 --- a/mlx/backend/metal/kernels/sparse.h +++ b/mlx/backend/metal/kernels/sparse.h @@ -12,57 +12,58 @@ template constant int& n_rows [[buffer(5)]], constant int& n_cols [[buffer(6)]], uint2 gid [[thread_position_in_grid]]) { - uint row_tid = gid.y; uint col_vec_idx = gid.x; - + // Vector size constexpr int BM = 4; int col_idx = col_vec_idx * BM; - if (row_tid >= uint(n_rows) || col_idx >= int(n_cols)) return; + if (row_tid >= uint(n_rows) || col_idx >= int(n_cols)) + return; bool full_vector = (col_idx + BM <= n_cols); - + float4 sum = float4(0.0f); - + int row_start = row_ptr[row_tid]; int row_end = row_ptr[row_tid + 1]; - + if (full_vector) { - for (int idx = row_start; idx < row_end; idx++) { - int k = col_indices[idx]; - float val_a = float(values[idx]); - - // Vectorized read - const device packed_vec* src = (const device packed_vec*)(dense_matrix + k * n_cols + col_idx); - vec val_x_t = *src; - - // Convert to float4 for math - float4 val_x = float4(val_x_t); - - sum += val_a * val_x; - } - - // Store - vec res = vec(sum); - *((device packed_vec*)(output + row_tid * n_cols + col_idx)) = res; - + for (int idx = row_start; idx < row_end; idx++) { + int k = col_indices[idx]; + float val_a = float(values[idx]); + + // Vectorized read + const device packed_vec* src = + (const device packed_vec*)(dense_matrix + k * n_cols + col_idx); + vec val_x_t = *src; + + // Convert to float4 for math + float4 val_x = float4(val_x_t); + + sum += val_a * val_x; + } + + // Store + vec res = vec(sum); + *((device packed_vec*)(output + row_tid * n_cols + col_idx)) = res; + } else { - // Tail loop - for (int idx = row_start; idx < row_end; idx++) { - int k = col_indices[idx]; - float val_a = float(values[idx]); - - for (int i = 0; i < n_cols - col_idx; i++) { - float val_x = float(dense_matrix[k * n_cols + col_idx + i]); - sum[i] += val_a * val_x; - } - } - + // Tail loop + for (int idx = row_start; idx < row_end; idx++) { + int k = col_indices[idx]; + float val_a = float(values[idx]); + for (int i = 0; i < n_cols - col_idx; i++) { - output[row_tid * n_cols + col_idx + i] = T(sum[i]); + float val_x = float(dense_matrix[k * n_cols + col_idx + i]); + sum[i] += val_a * val_x; } + } + + for (int i = 0; i < n_cols - col_idx; i++) { + output[row_tid * n_cols + col_idx + i] = T(sum[i]); + } } } @@ -77,7 +78,6 @@ template device T* output [[buffer(4)]], constant int& n_rows [[buffer(5)]], uint gid [[thread_position_in_grid]]) { - int row = gid; if (row >= n_rows) { return; diff --git a/mlx/backend/metal/kernels/sparse.metal b/mlx/backend/metal/kernels/sparse.metal index 55cd71bcdc..785d92db5f 100644 --- a/mlx/backend/metal/kernels/sparse.metal +++ b/mlx/backend/metal/kernels/sparse.metal @@ -17,4 +17,4 @@ instantiate_sparse_ops(float32, float) instantiate_sparse_ops(float16, half) instantiate_sparse_ops(bfloat16, bfloat16_t) -// clang-format on + // clang-format on diff --git a/tests/sparse_tests.cpp b/tests/sparse_tests.cpp index 745c9f447e..df5d0ee3ac 100644 --- a/tests/sparse_tests.cpp +++ b/tests/sparse_tests.cpp @@ -19,9 +19,8 @@ TEST_CASE("test sparse matrix-dense matrix multiplication") { auto col_indices = array({0, 2, 1, 0, 2}, int32); auto values = array({1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, float32); - auto dense_a = array({1.0f, 0.0f, 2.0f, - 0.0f, 3.0f, 0.0f, - 4.0f, 0.0f, 5.0f}, {3, 3}); + auto dense_a = + array({1.0f, 0.0f, 2.0f, 0.0f, 3.0f, 0.0f, 4.0f, 0.0f, 5.0f}, {3, 3}); CHECK_EQ(row_ptr.size(), 4); CHECK_EQ(col_indices.size(), 5); @@ -34,9 +33,7 @@ TEST_CASE("test sparse matrix-dense matrix multiplication") { CHECK_EQ(row_ptr_data[2], 3); CHECK_EQ(row_ptr_data[3], 5); - auto dense_b = array({1.0f, 2.0f, - 3.0f, 4.0f, - 5.0f, 6.0f}, {3, 2}); + auto dense_b = array({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {3, 2}); // Expected result from dense @ dense: // [[1, 0, 2], [[1, 2], [[11, 14], @@ -49,18 +46,19 @@ TEST_CASE("test sparse matrix-dense matrix multiplication") { CHECK(allclose(result, expected, 1e-5).item()); // Test explicitly on CPU - auto result_cpu = sparse_matmul_csr(row_ptr, col_indices, values, dense_b, 3, 2, Device::cpu); + auto result_cpu = sparse_matmul_csr( + row_ptr, col_indices, values, dense_b, 3, 2, Device::cpu); eval(result_cpu); CHECK(allclose(result_cpu, expected, 1e-5).item()); // Verify CPU result matches expected values auto result_cpu_data = result_cpu.data(); - CHECK_EQ(result_cpu_data[0], 11.0f); // [0,0] - CHECK_EQ(result_cpu_data[1], 14.0f); // [0,1] - CHECK_EQ(result_cpu_data[2], 9.0f); // [1,0] - CHECK_EQ(result_cpu_data[3], 12.0f); // [1,1] - CHECK_EQ(result_cpu_data[4], 29.0f); // [2,0] - CHECK_EQ(result_cpu_data[5], 38.0f); // [2,1] + CHECK_EQ(result_cpu_data[0], 11.0f); // [0,0] + CHECK_EQ(result_cpu_data[1], 14.0f); // [0,1] + CHECK_EQ(result_cpu_data[2], 9.0f); // [1,0] + CHECK_EQ(result_cpu_data[3], 12.0f); // [1,1] + CHECK_EQ(result_cpu_data[4], 29.0f); // [2,0] + CHECK_EQ(result_cpu_data[5], 38.0f); // [2,1] } TEST_CASE("test sparse matrix-vector multiplication") { @@ -72,9 +70,8 @@ TEST_CASE("test sparse matrix-vector multiplication") { auto col_indices = array({0, 1, 2}, int32); auto values = array({2.0f, 3.0f, 4.0f}, float32); - auto dense_a = array({2.0f, 0.0f, 0.0f, - 0.0f, 3.0f, 0.0f, - 0.0f, 0.0f, 4.0f}, {3, 3}); + auto dense_a = + array({2.0f, 0.0f, 0.0f, 0.0f, 3.0f, 0.0f, 0.0f, 0.0f, 4.0f}, {3, 3}); auto dense_b = array({1.0f, 2.0f, 3.0f}, {3, 1}); @@ -85,15 +82,16 @@ TEST_CASE("test sparse matrix-vector multiplication") { CHECK(allclose(result, expected, 1e-5).item()); // Test explicitly on CPU - auto result_cpu = sparse_matmul_csr(row_ptr, col_indices, values, dense_b, 3, 1, Device::cpu); + auto result_cpu = sparse_matmul_csr( + row_ptr, col_indices, values, dense_b, 3, 1, Device::cpu); eval(result_cpu); CHECK(allclose(result_cpu, expected, 1e-5).item()); // Verify CPU result values (diagonal matrix times vector) auto result_cpu_data = result_cpu.data(); - CHECK_EQ(result_cpu_data[0], 2.0f); // 2 * 1 = 2 - CHECK_EQ(result_cpu_data[1], 6.0f); // 3 * 2 = 6 - CHECK_EQ(result_cpu_data[2], 12.0f); // 4 * 3 = 12 + CHECK_EQ(result_cpu_data[0], 2.0f); // 2 * 1 = 2 + CHECK_EQ(result_cpu_data[1], 6.0f); // 3 * 2 = 6 + CHECK_EQ(result_cpu_data[2], 12.0f); // 4 * 3 = 12 } TEST_CASE("test random sparse matrix") { @@ -114,9 +112,14 @@ TEST_CASE("test random sparse matrix") { row_ptr_vec.push_back(col_indices_vec.size()); } - auto row_ptr = array(row_ptr_vec.data(), {static_cast(row_ptr_vec.size())}, int32); - auto col_indices = array(col_indices_vec.data(), {static_cast(col_indices_vec.size())}, int32); - auto values = array(values_vec.data(), {static_cast(values_vec.size())}, float32); + auto row_ptr = + array(row_ptr_vec.data(), {static_cast(row_ptr_vec.size())}, int32); + auto col_indices = array( + col_indices_vec.data(), + {static_cast(col_indices_vec.size())}, + int32); + auto values = + array(values_vec.data(), {static_cast(values_vec.size())}, float32); CHECK_EQ(row_ptr.size(), n_rows + 1); CHECK(col_indices.size() > 0); @@ -124,8 +127,8 @@ TEST_CASE("test random sparse matrix") { auto dense_b = ones({n_cols, dense_cols}); - auto result = sparse_matmul_csr(row_ptr, col_indices, values, dense_b, n_rows, dense_cols); + auto result = sparse_matmul_csr( + row_ptr, col_indices, values, dense_b, n_rows, dense_cols); CHECK_EQ(result.shape(0), n_rows); CHECK_EQ(result.shape(1), dense_cols); } - From a8be946285c37306542ebef5ac3c443d0838d52e Mon Sep 17 00:00:00 2001 From: mercush Date: Wed, 19 Nov 2025 11:30:27 -0500 Subject: [PATCH 4/5] remove sparse matvec kernel --- mlx/backend/metal/kernels/sparse.h | 27 -------------------------- mlx/backend/metal/kernels/sparse.metal | 3 +-- 2 files changed, 1 insertion(+), 29 deletions(-) diff --git a/mlx/backend/metal/kernels/sparse.h b/mlx/backend/metal/kernels/sparse.h index f821071f0f..83ef2fec5c 100644 --- a/mlx/backend/metal/kernels/sparse.h +++ b/mlx/backend/metal/kernels/sparse.h @@ -66,30 +66,3 @@ template } } } - -// Sparse matrix-vector multiplication: y = A @ x -// where A is sparse (CSR format) and x is a dense vector -template -[[kernel]] void sparse_mv_csr( - const device int* row_ptr [[buffer(0)]], - const device int* col_indices [[buffer(1)]], - const device T* values [[buffer(2)]], - const device T* vector [[buffer(3)]], - device T* output [[buffer(4)]], - constant int& n_rows [[buffer(5)]], - uint gid [[thread_position_in_grid]]) { - int row = gid; - if (row >= n_rows) { - return; - } - - int row_start = row_ptr[row]; - int row_end = row_ptr[row + 1]; - - T sum = T(0); - for (int idx = row_start; idx < row_end; idx++) { - sum += values[idx] * vector[col_indices[idx]]; - } - - output[row] = sum; -} diff --git a/mlx/backend/metal/kernels/sparse.metal b/mlx/backend/metal/kernels/sparse.metal index 785d92db5f..bbe5359236 100644 --- a/mlx/backend/metal/kernels/sparse.metal +++ b/mlx/backend/metal/kernels/sparse.metal @@ -10,8 +10,7 @@ // Instantiate sparse matrix operations for common types #define instantiate_sparse_ops(tname, type) \ - instantiate_kernel("sparse_mm_csr_" #tname, sparse_mm_csr, type) \ - instantiate_kernel("sparse_mv_csr_" #tname, sparse_mv_csr, type) \ + instantiate_kernel("sparse_mm_csr_" #tname, sparse_mm_csr, type) // Instantiate for floating point types instantiate_sparse_ops(float32, float) From 499a2b8ddea9484e256d715620af7950d9af5b39 Mon Sep 17 00:00:00 2001 From: mercush Date: Sat, 29 Nov 2025 15:10:48 -0500 Subject: [PATCH 5/5] add cuda backend --- mlx/backend/cpu/sparse.cpp | 2 - mlx/backend/cuda/CMakeLists.txt | 1 + mlx/backend/cuda/sparse.cu | 103 ++++++++++++++++++++++++++++++++ mlx/backend/metal/sparse.cpp | 2 - mlx/ops.h | 9 --- 5 files changed, 104 insertions(+), 13 deletions(-) create mode 100644 mlx/backend/cuda/sparse.cu diff --git a/mlx/backend/cpu/sparse.cpp b/mlx/backend/cpu/sparse.cpp index cfca6bb625..5025aefbb9 100644 --- a/mlx/backend/cpu/sparse.cpp +++ b/mlx/backend/cpu/sparse.cpp @@ -1,7 +1,5 @@ // Copyright © 2025 Apple Inc. -#include - #include "mlx/backend/cpu/encoder.h" #include "mlx/primitives.h" diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 7f8f1aadea..facd3c9ab2 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -48,6 +48,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu + ${CMAKE_CURRENT_SOURCE_DIR}/sparse.cu ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu diff --git a/mlx/backend/cuda/sparse.cu b/mlx/backend/cuda/sparse.cu new file mode 100644 index 0000000000..31f97fae77 --- /dev/null +++ b/mlx/backend/cuda/sparse.cu @@ -0,0 +1,103 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/cast_op.cuh" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include + +#include + +namespace mlx::core { + +namespace cu { + +template +__global__ void sparse_matmul_csr_kernel( + const int* row_ptr, + const int* col_indices, + const T* values, + const T* dense_b, + T* out, + int n_rows, + int n_cols, + int dense_b_cols) { + // Each block processes one row of the sparse matrix + int row = blockIdx.x; + + if (row >= n_rows) { + return; + } + + int row_start = row_ptr[row]; + int row_end = row_ptr[row + 1]; + + // Each thread processes multiple columns of the output + for (int col = threadIdx.x; col < n_cols; col += BLOCK_SIZE) { + T sum = 0; + + // Iterate through nonzero elements in this row + for (int idx = row_start; idx < row_end; idx++) { + int k = col_indices[idx]; + T a_val = values[idx]; + T b_val = dense_b[k * dense_b_cols + col]; + sum += a_val * b_val; + } + + out[row * n_cols + col] = sum; + } +} + +} // namespace cu + +void SparseMatmulCSR::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("SparseMatmulCSR::eval_gpu"); + assert(inputs.size() == 4); + + const array& row_ptr = inputs[0]; + const array& col_indices = inputs[1]; + const array& values = inputs[2]; + const array& dense_b = inputs[3]; + + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + + out.set_data(allocator::malloc(out.nbytes())); + + encoder.set_input_array(row_ptr); + encoder.set_input_array(col_indices); + encoder.set_input_array(values); + encoder.set_input_array(dense_b); + encoder.set_output_array(out); + + int dense_b_cols = dense_b.shape(1); + + // Launch kernel + dispatch_float_types(values.dtype(), "sparse_matmul_csr", [&](auto type_tag) { + using DataType = cuda_type_t; + + constexpr int BLOCK_SIZE = 256; + dim3 block_dim(BLOCK_SIZE); + dim3 grid_dim(n_rows_); + + auto kernel = cu::sparse_matmul_csr_kernel; + + encoder.add_kernel_node( + kernel, + grid_dim, + block_dim, + 0, + row_ptr.data(), + col_indices.data(), + values.data(), + dense_b.data(), + out.data(), + n_rows_, + n_cols_, + dense_b_cols); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/metal/sparse.cpp b/mlx/backend/metal/sparse.cpp index 5d4da0915c..2d2eb6386c 100644 --- a/mlx/backend/metal/sparse.cpp +++ b/mlx/backend/metal/sparse.cpp @@ -1,9 +1,7 @@ // Copyright © 2025 Apple Inc. -#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h" -#include "mlx/primitives.h" namespace mlx::core { diff --git a/mlx/ops.h b/mlx/ops.h index fbf9438f5d..4c4fa95e5f 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -967,15 +967,6 @@ array sparse_matmul_csr( int n_cols, StreamOrDevice s = {}); -/** Sparse matrix-vector multiplication using CSR format. */ -array sparse_matvec_csr( - const array& row_ptr, - const array& col_indices, - const array& values, - const array& vec, - int n_rows, - StreamOrDevice s = {}); - /** Gather array entries given indices and slices */ array gather( const array& a,