
[CUDA] Implement SegmentedMM #3238

Open
Lyxot wants to merge 6 commits into ml-explore:main from Lyxot:cuda/segmented_mm

Conversation


@Lyxot Lyxot commented Mar 10, 2026

Proposed changes

Implement mx.segmented_mm for the CUDA backend using CUTLASS grouped GEMM.
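For context, segmented_mm splits the shared inner (K) dimension into segments and produces one output matrix per segment. The following is a minimal NumPy sketch of the intended semantics, assuming segments is a (num_segments, 2) array of [start, end) offsets along K; it is an illustrative reference, not the MLX implementation:

```python
import numpy as np

def segmented_mm_ref(a, b, segments):
    """Reference semantics: one (M, N) matmul per [start, end) slice of K.

    a: (M, K), b: (K, N), segments: (num_segments, 2) int array.
    Returns: (num_segments, M, N).
    """
    out = np.empty((len(segments), a.shape[0], b.shape[1]), dtype=a.dtype)
    for i, (start, end) in enumerate(segments):
        # An empty segment (start == end) yields a zero matrix, matching
        # the K=0 behavior described in the commit messages below.
        out[i] = a[:, start:end] @ b[start:end, :]
    return out

a = np.random.rand(4, 8).astype(np.float32)
b = np.random.rand(8, 5).astype(np.float32)
segments = np.array([[0, 3], [3, 3], [3, 8]])
out = segmented_mm_ref(a, b, segments)
```

Note that when the segments partition the full K range, the per-segment results sum to the ordinary a @ b product.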

Performance

Measured with MLX_ENABLE_TF32=0 and random segment boundaries.

float32

| Case              | MLX ms | Loop ms | Speedup |
|-------------------|--------|---------|---------|
| 128x128x1024x16   | 0.040  | 0.353   | 8.73x   |
| 128x128x1024x32   | 0.042  | 0.645   | 15.49x  |
| 256x256x2048x16   | 0.130  | 0.566   | 4.35x   |
| 512x512x4096x32   | 0.200  | 0.788   | 3.95x   |
| 1024x1024x4096x32 | 0.596  | 1.509   | 2.53x   |
| 1024x1024x8192x64 | 1.197  | 2.929   | 2.45x   |

bfloat16

| Case              | MLX ms | Loop ms | Speedup |
|-------------------|--------|---------|---------|
| 128x128x1024x16   | 0.041  | 0.184   | 4.54x   |
| 128x128x1024x32   | 0.032  | 0.327   | 10.35x  |
| 256x256x2048x16   | 0.061  | 0.186   | 3.05x   |
| 512x512x4096x32   | 0.161  | 0.395   | 2.46x   |
| 1024x1024x4096x32 | 0.515  | 0.791   | 1.54x   |
| 1024x1024x8192x64 | 1.095  | 1.593   | 1.46x   |

MLX ms = mx.segmented_mm (CUTLASS grouped GEMM), Loop ms = MLX loop-of-matmuls baseline.
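The timing harness behind these numbers is not shown here, but a benchmark of this shape is typically a warmup phase followed by an averaged timed loop. A hedged sketch of such a harness (names are illustrative; the actual script lives in benchmarks/python/segmented_mm_bench.py and times mx.segmented_mm against a loop of matmuls):

```python
import time
import numpy as np

def bench(fn, *args, warmup=2, iters=10):
    """Return average wall-clock milliseconds per call of fn(*args).

    Illustrative harness only: the real script benchmarks MLX ops and
    must account for lazy evaluation (e.g. mx.eval) when timing.
    """
    for _ in range(warmup):
        fn(*args)  # warm caches / JIT before timing
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    return (time.perf_counter() - start) / iters * 1e3

a = np.random.rand(128, 1024).astype(np.float32)
b = np.random.rand(1024, 128).astype(np.float32)
ms = bench(lambda x, y: x @ y, a, b)
```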

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

Lyxot added 6 commits March 8, 2026 16:47

  • Replace the host-side cuBLAS loop with a single CUTLASS grouped GEMM dispatch. A GPU-side prepare kernel builds per-segment problem sizes and pointer offsets from the segments array, eliminating the host sync that was previously required to read segment boundaries.
  • CUTLASS handles K=0 segments correctly: the mainloop iterates zero times and the epilogue writes zeros to the output.
  • Compare mx.segmented_mm (grouped GEMM) against the MLX loop-of-matmuls baseline. Remove the torch dependency. Add accuracy checks: fp32 against NumPy fp64, and fp16/bf16 against the script's own fp32 result.
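The prepare step described in the commits can be illustrated on the host: from the segments array it derives one (M, N, K_i) problem size per group plus offsets into the operand buffers. A Python sketch of that logic, assuming row-major A (M, K) and B (K, N) and a (G, M, N) output D (the actual kernel runs on the GPU precisely so the segments never need to be read back to the host):

```python
def prepare_grouped_gemm_args(M, N, segments, elem_size=4):
    """Host emulation of the GPU-side prepare kernel (illustrative names).

    For row-major operands and segment i = [start, end) along K:
      - the problem size is (M, N, end - start)
      - the A sub-block starts at element `start` (leading dimension K)
      - the B sub-block starts at element `start * N` (leading dimension N)
      - output block i starts at element `i * M * N` in D
    """
    args = []
    for i, (start, end) in enumerate(segments):
        k_i = int(end - start)  # may be 0; CUTLASS runs a zero-trip mainloop
        args.append({
            "problem_size": (M, N, k_i),
            "a_byte_offset": int(start) * elem_size,
            "b_byte_offset": int(start) * N * elem_size,
            "d_byte_offset": i * M * N * elem_size,
        })
    return args

args = prepare_grouped_gemm_args(M=4, N=5, segments=[(0, 3), (3, 3), (3, 8)])
```

Doing this per-group computation in a small kernel is what removes the host synchronization the first commit mentions.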
Copilot AI review requested due to automatic review settings March 10, 2026 09:39

Copilot AI left a comment


Pull request overview

Implements mx.segmented_mm on the CUDA backend by dispatching to a CUTLASS grouped GEMM path. This enables the existing Python BLAS segmented-mm tests to run on CUDA and adds a standalone benchmark script that compares against a loop-of-matmuls baseline.

Changes:

  • Enable CUDA support for the SegmentedMM primitive and remove the CUDA test skip.
  • Add SegmentedMM::eval_gpu implementation that prepares inputs and calls a new CUTLASS grouped-GEMM launcher.
  • Introduce a Python benchmark for mx.segmented_mm performance and numerical error checks.

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 3 comments.

| File | Description |
|-----------------------------------------------------|-------------|
| python/tests/cuda_skip.py                           | Removes the CUDA skip for the segmented_mm BLAS test. |
| mlx/backend/cuda/primitives.cpp                     | Marks SegmentedMM as supported on CUDA (removes NO_GPU). |
| mlx/backend/cuda/matmul.cpp                         | Adds the CUDA implementation of SegmentedMM::eval_gpu. |
| mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu    | Adds the segment-to-grouped-GEMM argument preparation kernel and CUTLASS dispatch wrapper. |
| mlx/backend/cuda/gemms/grouped_gemm.h               | Declares the cutlass_segmented_mm entry point. |
| benchmarks/python/segmented_mm_bench.py             | Adds a benchmarking and correctness-checking script for segmented_mm. |


@angeloskath angeloskath requested a review from zcbenz March 10, 2026 20:12