From 3ed3b1154b4af6d2ab92801c471569620ada5e75 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Tue, 18 Nov 2025 19:08:58 +0100 Subject: [PATCH 01/29] qqmm --- examples/cpp/qqmm.cu | 53 +++++ mlx/backend/cpu/quantized.cpp | 5 + mlx/backend/cuda/CMakeLists.txt | 3 + mlx/backend/cuda/cublas_utils.cpp | 160 ++++++++++++++ mlx/backend/cuda/cublas_utils.h | 55 +++++ mlx/backend/cuda/gemms/cublas_gemm.cpp | 140 ++---------- mlx/backend/cuda/gemms/cublas_gemm.h | 7 +- mlx/backend/cuda/matmul.cpp | 3 +- mlx/backend/cuda/quantized/cublas_qqmm.cpp | 243 +++++++++++++++++++++ mlx/backend/cuda/quantized/cublas_qqmm.h | 108 +++++++++ mlx/backend/cuda/quantized/qqmm_utils.cu | 169 ++++++++++++++ mlx/backend/cuda/quantized/qqmm_utils.h | 30 +++ mlx/backend/cuda/quantized/quantized.cpp | 145 +++++++++++- mlx/backend/no_cpu/primitives.cpp | 1 + mlx/backend/no_gpu/primitives.cpp | 1 + mlx/ops.cpp | 161 ++++++++++++-- mlx/ops.h | 12 + mlx/primitives.cpp | 19 ++ mlx/primitives.h | 37 ++++ 19 files changed, 1205 insertions(+), 147 deletions(-) create mode 100644 examples/cpp/qqmm.cu create mode 100644 mlx/backend/cuda/cublas_utils.cpp create mode 100644 mlx/backend/cuda/cublas_utils.h create mode 100644 mlx/backend/cuda/quantized/cublas_qqmm.cpp create mode 100644 mlx/backend/cuda/quantized/cublas_qqmm.h create mode 100644 mlx/backend/cuda/quantized/qqmm_utils.cu create mode 100644 mlx/backend/cuda/quantized/qqmm_utils.h diff --git a/examples/cpp/qqmm.cu b/examples/cpp/qqmm.cu new file mode 100644 index 0000000000..eec515dd48 --- /dev/null +++ b/examples/cpp/qqmm.cu @@ -0,0 +1,53 @@ +#include +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/quantized/cublas_qqmm.h" +#include "mlx/mlx.h" +#include "mlx/stream.h" + +namespace mx = mlx::core; + +int main() { + int group_size = 16; + int bits = 4; + int M = 128; + int N = 128; + int K = 256; + std::string quantization_mode = "nvfp4"; + + mx::Device device(mx::Device::gpu, 0); + auto s = mx::default_stream(device); + auto& encoder = mx::cu::get_command_encoder(s); + + mx::array a = mx::random::uniform({M, K}, mx::bfloat16); // (M, K) + mx::array b = mx::random::uniform({N, K}, mx::bfloat16); // (N, K) + + auto scaled_a = mx::quantize(a, group_size, bits, quantization_mode); + auto scaled_b = mx::quantize(b, group_size, bits, quantization_mode); + + mx::array a_quantized = scaled_a[0]; + mx::array a_scale = scaled_a[1]; + mx::array b_quantized = scaled_b[0]; + mx::array b_scale = scaled_b[1]; + + mx::array out = mx::qqmm( + a_quantized, + b_quantized, + a_scale, + b_scale, + true, + group_size, + bits, + quantization_mode); + + mx::array a_dequantized = + mx::dequantize(a_quantized, a_scale, {}, 16, 4, "nvfp4"); + mx::array b_dequantized = + mx::dequantize(b_quantized, b_scale, {}, 16, 4, "nvfp4"); + + mx::array reference_deq = + mx::matmul(a_dequantized, mx::transpose(b_dequantized)); + mx::array isclose = mx::allclose(out, reference_deq, 1e-1f); + + std::cout << isclose << std::endl; + return 0; +} \ No newline at end of file diff --git a/mlx/backend/cpu/quantized.cpp b/mlx/backend/cpu/quantized.cpp index 75a8e62337..d3964a28b1 100644 --- a/mlx/backend/cpu/quantized.cpp +++ b/mlx/backend/cpu/quantized.cpp @@ -1145,4 +1145,9 @@ void fast::ConvertFP8::eval_cpu( }); } +void DualQuantizedMatmul::eval_cpu( + const std::vector& inputs, + array& out) { + throw std::runtime_error("DualQuantizedMatmul not implemented on CPU."); +} } // namespace mlx::core diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 
a4606e5e34..5bece69be3 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -18,6 +18,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp ${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_conv.cu ${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_grouped_conv.cu + ${CMAKE_CURRENT_SOURCE_DIR}/cublas_utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp @@ -55,7 +56,9 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu ${CMAKE_CURRENT_SOURCE_DIR}/quantized/fp_quantize.cu ${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/quantized/qqmm_utils.cu ${CMAKE_CURRENT_SOURCE_DIR}/quantized/convert_fp8.cu + ${CMAKE_CURRENT_SOURCE_DIR}/quantized/cublas_qqmm.cpp ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary) diff --git a/mlx/backend/cuda/cublas_utils.cpp b/mlx/backend/cuda/cublas_utils.cpp new file mode 100644 index 0000000000..9af902d5a4 --- /dev/null +++ b/mlx/backend/cuda/cublas_utils.cpp @@ -0,0 +1,160 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/cublas_utils.h" +#include "mlx/backend/cuda/cuda.h" +#include "mlx/utils.h" + +namespace mlx::core { +namespace cublas_utils { + +namespace { + +struct CublasPreference { + CublasPreference(cu::Device& device) { + // The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB + // for Hopper+: + // https://docs.nvidia.com/cuda/cublas/#cublassetworkspace + uint64_t MiB = 1024 * 1024; + uint64_t workspace_size = + device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB; + + CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_)); + CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute( + pref_, + CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &workspace_size, + sizeof(uint64_t))); + } + + ~CublasPreference() { + CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceDestroy(pref_)); + } + + cublasLtMatmulPreference_t pref_{nullptr}; +}; + +} // namespace + +cublasLtMatmulPreference_t get_preference(cu::Device& device) { + static CublasPreference pref(device); + return pref.pref_; +} + +void* allocate_workspace(cu::CommandEncoder& encoder, size_t workspace_size) { + if (workspace_size == 0) { + return nullptr; + } + + // Ensure workspace is 256-byte aligned + int nbytes = cuda::ceil_div(workspace_size, 256) * 256; + array workspace( + cu::malloc_async(nbytes, encoder.stream()), + {static_cast(workspace_size)}, + int8); + encoder.add_temporary(workspace); + return gpu_ptr(workspace); +} + +cublasLtMatrixLayout_t create_matrix_layout( + cudaDataType_t type, + uint64_t rows, + uint64_t cols, + bool transposed, + int64_t ld, + int32_t batch_count, + int64_t batch_stride) { + cublasLtMatrixLayout_t desc; + if (transposed) { + std::swap(rows, cols); + } + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld)); + if (batch_count > 1) { + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( + desc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batch_count, + sizeof(int32_t))); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( + desc, + CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, + &batch_stride, + sizeof(int64_t))); + } + return desc; +} + +void execute_matmul( + cu::CommandEncoder& encoder, + cublasLtHandle_t handle, + cublasLtMatmulDesc_t matmul_desc, + cublasLtMatrixLayout_t a_desc, + cublasLtMatrixLayout_t b_desc, + cublasLtMatrixLayout_t c_desc, + cublasLtMatrixLayout_t out_desc, + cublasLtMatmulHeuristicResult_t& 
heuristic, + cublasLtMatmulPreference_t pref, + void* out, + const void* a, + const void* b, + const void* c, + const void* alpha_ptr, + const void* beta_ptr) { + if (heuristic.state != CUBLAS_STATUS_SUCCESS) { + int ret = 0; + CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic( + handle, + matmul_desc, + a_desc, + b_desc, + c ? c_desc : out_desc, + out_desc, + pref, + 1, + &heuristic, + &ret)); + if (ret == 0) { + throw std::runtime_error("Can not find algorithm for matmul."); + } + } + + void* workspace_ptr = allocate_workspace(encoder, heuristic.workspaceSize); + + // Execute matmul + auto capture = encoder.capture_context(); + CHECK_CUBLAS_ERROR(cublasLtMatmul( + handle, + matmul_desc, + alpha_ptr, + b, // a and b are swapped for row-major layout + a_desc, + a, + b_desc, + beta_ptr, + c ? c : out, + c ? c_desc : out_desc, + out, + out_desc, + &heuristic.algo, + workspace_ptr, + heuristic.workspaceSize, + encoder.stream())); +} + +void set_bias( + cu::CommandEncoder& encoder, + cublasLtMatmulDesc_t matmul_desc, + const array& bias) { + encoder.set_input_array(bias); + cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS; + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + matmul_desc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue))); + auto* bias_ptr = gpu_ptr(bias); + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + matmul_desc, + CUBLASLT_MATMUL_DESC_BIAS_POINTER, + &bias_ptr, + sizeof(bias_ptr))); +} + +} // namespace cublas_utils +} // namespace mlx::core diff --git a/mlx/backend/cuda/cublas_utils.h b/mlx/backend/cuda/cublas_utils.h new file mode 100644 index 0000000000..63fc6e733e --- /dev/null +++ b/mlx/backend/cuda/cublas_utils.h @@ -0,0 +1,55 @@ +// Copyright © 2025 Apple Inc. +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/cuda/device.h" + +#include + +namespace mlx::core { +namespace cublas_utils { + +// Get the shared cublas preference for a device +cublasLtMatmulPreference_t get_preference(cu::Device& device); + +// Allocate workspace for matmul if needed and return pointer +// The workspace array is added to the encoder's temporaries +void* allocate_workspace(cu::CommandEncoder& encoder, size_t workspace_size); + +// Create matrix layout +cublasLtMatrixLayout_t create_matrix_layout( + cudaDataType_t type, + uint64_t rows, + uint64_t cols, + bool transposed, + int64_t ld, + int32_t batch_count, + int64_t batch_stride); + +// Execute matmul with pre-configured descriptors +void execute_matmul( + cu::CommandEncoder& encoder, + cublasLtHandle_t handle, + cublasLtMatmulDesc_t matmul_desc, + cublasLtMatrixLayout_t a_desc, + cublasLtMatrixLayout_t b_desc, + cublasLtMatrixLayout_t c_desc, + cublasLtMatrixLayout_t out_desc, + cublasLtMatmulHeuristicResult_t& heuristic, + cublasLtMatmulPreference_t pref, + void* out, + const void* a, + const void* b, + const void* c, + const void* alpha_ptr, + const void* beta_ptr); + +// Set bias for matmul epilogue +void set_bias( + cu::CommandEncoder& encoder, + cublasLtMatmulDesc_t matmul_desc, + const array& bias); + +} // namespace cublas_utils + +} // namespace mlx::core diff --git a/mlx/backend/cuda/gemms/cublas_gemm.cpp b/mlx/backend/cuda/gemms/cublas_gemm.cpp index 60ca2ccae0..5efb4af7d6 100644 --- a/mlx/backend/cuda/gemms/cublas_gemm.cpp +++ b/mlx/backend/cuda/gemms/cublas_gemm.cpp @@ -1,6 +1,7 @@ // Copyright © 2025 Apple Inc. 
#include "mlx/backend/cuda/gemms/cublas_gemm.h" +#include "mlx/backend/cuda/cublas_utils.h" #include "mlx/backend/cuda/device.h" #include "mlx/dtype_utils.h" #include "mlx/utils.h" @@ -11,35 +12,6 @@ namespace mlx::core { namespace { -struct CublasPreference { - CublasPreference(cu::Device& device) { - // The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB - // for Hopper+: - // https://docs.nvidia.com/cuda/cublas/#cublassetworkspace - uint64_t MiB = 1024 * 1024; - uint64_t workspace_size = - device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB; - - CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_)); - CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute( - pref_, - CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, - &workspace_size, - sizeof(uint64_t))); - } - - ~CublasPreference() { - CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceDestroy(pref_)); - } - - cublasLtMatmulPreference_t pref_{nullptr}; -}; - -cublasLtMatmulPreference_t cublas_preference(cu::Device& device) { - static CublasPreference pref(device); - return pref.pref_; -} - cublasComputeType_t dtype_to_compute_type(Dtype dtype) { switch (dtype) { case float16: @@ -78,34 +50,6 @@ cudaDataType_t dtype_to_cublas_type(Dtype dtype) { } } -cublasLtMatrixLayout_t create_matrix_layout( - cudaDataType_t type, - uint64_t rows, - uint64_t cols, - bool transposed, - int64_t ld, - int32_t batch_count, - int64_t batch_stride) { - cublasLtMatrixLayout_t desc; - if (transposed) { - std::swap(rows, cols); - } - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld)); - if (batch_count > 1) { - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( - desc, - CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, - &batch_count, - sizeof(int32_t))); - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( - desc, - CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, - &batch_stride, - sizeof(int64_t))); - } - return desc; -} - } // namespace CublasGemm::CublasGemm( @@ -123,7 +67,7 @@ CublasGemm::CublasGemm( int64_t a_batch_stride, int64_t b_batch_stride) : handle_(device.lt_handle()), - pref_(cublas_preference(device)), + pref_(cublas_utils::get_preference(device)), M_(a_rows), N_(b_cols) { heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED; @@ -163,11 +107,11 @@ CublasGemm::CublasGemm( sizeof(cublasOperation_t))); auto type = dtype_to_cublas_type(dtype); - a_desc_ = create_matrix_layout( + a_desc_ = cublas_utils::create_matrix_layout( type, b_cols, b_rows, b_transposed, ldb, batch_count, b_batch_stride); - b_desc_ = create_matrix_layout( + b_desc_ = cublas_utils::create_matrix_layout( type, a_cols, a_rows, a_transposed, lda, batch_count, a_batch_stride); - out_desc_ = create_matrix_layout( + out_desc_ = cublas_utils::create_matrix_layout( type, b_cols, a_rows, false, b_cols, batch_count, a_rows * b_cols); } @@ -202,7 +146,7 @@ CublasGemm::CublasGemm( a_batch_stride, b_batch_stride) { auto type = dtype_to_cublas_type(dtype); - c_desc_ = create_matrix_layout( + c_desc_ = cublas_utils::create_matrix_layout( type, b_cols, a_rows, false, ldc, batch_count, c_batch_stride); } @@ -223,7 +167,7 @@ void CublasGemm::set_out( int32_t batch_count, int64_t batch_stride) { CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_)); - out_desc_ = create_matrix_layout( + out_desc_ = cublas_utils::create_matrix_layout( dtype_to_cublas_type(dtype), cols, rows, @@ -233,22 +177,6 @@ void CublasGemm::set_out( batch_stride); } -void CublasGemm::set_bias(cu::CommandEncoder& encoder, const array& bias) { - encoder.set_input_array(bias); - 
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS; - CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( - matmul_desc_, - CUBLASLT_MATMUL_DESC_EPILOGUE, - &epilogue, - sizeof(epilogue))); - auto* bias_ptr = gpu_ptr(bias); - CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( - matmul_desc_, - CUBLASLT_MATMUL_DESC_BIAS_POINTER, - &bias_ptr, - sizeof(bias_ptr))); -} - void CublasGemm::run( cu::CommandEncoder& encoder, array& out, @@ -337,24 +265,6 @@ void CublasGemm::execute( const void* c, float alpha /* = 1 */, float beta /* = 0 */) { - if (heuristic_.state != CUBLAS_STATUS_SUCCESS) { - int ret = 0; - CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic( - handle_, - matmul_desc_, - a_desc_, - b_desc_, - c ? c_desc_ : out_desc_, - out_desc_, - pref_, - 1, - &heuristic_, - &ret)); - if (ret == 0) { - throw std::runtime_error("Can not find algorithm for matmul."); - } - } - const void* alpha_ptr = α const void* beta_ptr = β complex64_t alpha_c, beta_c; @@ -365,36 +275,22 @@ void CublasGemm::execute( beta_ptr = &beta_c; } - void* workspace_ptr = nullptr; - if (heuristic_.workspaceSize > 0) { - // Ensure workspace is 256-byte aligned - int nbytes = cuda::ceil_div(heuristic_.workspaceSize, 256) * 256; - array workspace( - cu::malloc_async(nbytes, encoder.stream()), - {static_cast(heuristic_.workspaceSize)}, - int8); - encoder.add_temporary(workspace); - workspace_ptr = gpu_ptr(workspace); - } - - auto capture = encoder.capture_context(); - CHECK_CUBLAS_ERROR(cublasLtMatmul( + cublas_utils::execute_matmul( + encoder, handle_, matmul_desc_, - alpha_ptr, - b, // a and b are swapped a_desc_, - a, b_desc_, - beta_ptr, - c ? c : out, - c ? c_desc_ : out_desc_, - out, + c_desc_, out_desc_, - &heuristic_.algo, - workspace_ptr, - heuristic_.workspaceSize, - encoder.stream())); + heuristic_, + pref_, + out, + a, + b, + c, + alpha_ptr, + beta_ptr); } } // namespace mlx::core diff --git a/mlx/backend/cuda/gemms/cublas_gemm.h b/mlx/backend/cuda/gemms/cublas_gemm.h index d6d2189b95..eb02452f2d 100644 --- a/mlx/backend/cuda/gemms/cublas_gemm.h +++ b/mlx/backend/cuda/gemms/cublas_gemm.h @@ -55,8 +55,6 @@ class CublasGemm { int32_t batch_count, int64_t batch_stride); - void set_bias(cu::CommandEncoder& encoder, const array& bias); - void run( cu::CommandEncoder& encoder, array& out, @@ -80,6 +78,11 @@ class CublasGemm { float alpha, float beta); + // Get the matmul descriptor for setting attributes like bias + cublasLtMatmulDesc_t matmul_desc() const { + return matmul_desc_; + } + private: void run_batched( cu::CommandEncoder& encoder, diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp index 8ccf3c4665..6d371eee36 100644 --- a/mlx/backend/cuda/matmul.cpp +++ b/mlx/backend/cuda/matmul.cpp @@ -1,6 +1,7 @@ // Copyright © 2025 Apple Inc. 
#include "mlx/backend/common/matmul.h" +#include "mlx/backend/cuda/cublas_utils.h" #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/gemms/cublas_gemm.h" #include "mlx/backend/cuda/gemms/gemv.h" @@ -97,7 +98,7 @@ void gemm_and_bias( throw std::runtime_error( "[gemm_and_bias] complex64 bias epilogue isn’t supported in cublasLtMatmul."); } - gemm.set_bias(encoder, *bias); + cublas_utils::set_bias(encoder, gemm.matmul_desc(), *bias); } gemm.run( encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides, alpha); diff --git a/mlx/backend/cuda/quantized/cublas_qqmm.cpp b/mlx/backend/cuda/quantized/cublas_qqmm.cpp new file mode 100644 index 0000000000..c49e0286bf --- /dev/null +++ b/mlx/backend/cuda/quantized/cublas_qqmm.cpp @@ -0,0 +1,243 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/quantized/cublas_qqmm.h" + +#include +#include "mlx/backend/cuda/cublas_utils.h" + +#include "mlx/backend/cuda/device.h" +#include "mlx/dtype_utils.h" +#include "mlx/utils.h" + +namespace mlx::core { + +namespace { + +// Currently cublas supports only mxfp8 and nvfp4 +// quantization modes for block scaled quantization +cudaDataType_t qmode_to_cublas_scale_dtype(std::string_view mode) { + if (mode == "mxfp8") { + return CUDA_R_8F_UE8M0; + } else if (mode == "nvfp4") { + return CUDA_R_8F_UE4M3; + } else { + throw std::runtime_error( + fmt::format("Unsupported quantization mode in CublasQQMM: {}.", mode)); + } +} + +cudaDataType_t qmode_to_cublas_dtype(std::string_view mode) { + if (mode == "mxfp8") { + return CUDA_R_8F_E4M3; + } else if (mode == "nvfp4") { + return CUDA_R_4F_E2M1; + } else { + throw std::runtime_error( + fmt::format("Unsupported quantization mode in CublasQQMM: {}.", mode)); + } +} + +cublasLtMatmulMatrixScale_t qmode_to_cublas_scale_mode(std::string_view mode) { + if (mode == "mxfp8") { + return CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0; + } else if (mode == "nvfp4") { + return CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3; + } else { + throw std::runtime_error( + fmt::format("Unsupported quantization mode in CublasQQMM: {}.", mode)); + } +} + +} // namespace + +CublasQQMM::CublasQQMM( + cu::Device& device, + bool a_transposed, + uint64_t a_rows, + uint64_t a_cols, + int64_t lda, + bool b_transposed, + uint64_t b_rows, + uint64_t b_cols, + int64_t ldb, + std::string_view qmode, + int32_t batch_count, + int64_t a_batch_stride, + int64_t b_batch_stride) + : handle_(device.lt_handle()), + pref_(cublas_utils::get_preference(device)), + M_(a_transposed ? a_cols : a_rows), + N_(b_transposed ? b_rows : b_cols) { + a_scale_mode_ = qmode_to_cublas_scale_mode(qmode); + b_scale_mode_ = qmode_to_cublas_scale_mode(qmode); + + heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED; + + cublasComputeType_t gemm_compute_type = + CUBLAS_COMPUTE_32F; // always for narrow precision + CHECK_CUBLAS_ERROR( + cublasLtMatmulDescCreate(&matmul_desc_, gemm_compute_type, CUDA_R_32F)); + + cublasOperation_t a_op = b_transposed ? CUBLAS_OP_T : CUBLAS_OP_N; + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + matmul_desc_, + CUBLASLT_MATMUL_DESC_TRANSA, + &a_op, + sizeof(cublasOperation_t))); + cublasOperation_t b_op = a_transposed ? CUBLAS_OP_T : CUBLAS_OP_N; + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + matmul_desc_, + CUBLASLT_MATMUL_DESC_TRANSB, + &b_op, + sizeof(cublasOperation_t))); + + // alpha, beta pointer mode set to host ? 
(TODO) + int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST; + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + matmul_desc_, + CUBLASLT_MATMUL_DESC_POINTER_MODE, + &pointer_mode, + sizeof(int32_t))); + + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + matmul_desc_, + CUBLASLT_MATMUL_DESC_B_SCALE_MODE, + &a_scale_mode_, + sizeof(a_scale_mode_))); + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + matmul_desc_, + CUBLASLT_MATMUL_DESC_A_SCALE_MODE, + &b_scale_mode_, + sizeof(b_scale_mode_))); + + // a and b are swaped + a_desc_ = cublas_utils::create_matrix_layout( + qmode_to_cublas_dtype(qmode), + b_cols, + b_rows, + b_transposed, + ldb, + batch_count, + b_batch_stride); + b_desc_ = cublas_utils::create_matrix_layout( + qmode_to_cublas_dtype(qmode), + a_cols, + a_rows, + a_transposed, + lda, + batch_count, + a_batch_stride); + out_desc_ = cublas_utils::create_matrix_layout( + CUDA_R_16BF, // output in bf16 (TODO) + b_transposed ? b_rows : b_cols, // n + a_transposed ? a_cols : a_rows, // m + false, + b_transposed ? b_rows : b_cols, + batch_count, + a_rows * b_cols); + + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate( + &a_desc_, qmode_to_cublas_dtype(qmode), b_cols, b_rows, ldb)); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate( + &b_desc_, qmode_to_cublas_dtype(qmode), a_cols, a_rows, lda)); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate( + &out_desc_, + CUDA_R_16BF, // output in bf16 + b_transposed ? b_rows : b_cols, // m + a_rows, // asume that never transposed (supported only TN layout) + b_transposed ? b_rows : b_cols)); +} + +CublasQQMM::~CublasQQMM() { + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_)); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_)); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_)); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_)); + CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_)); +} + +void CublasQQMM::run( + cu::CommandEncoder& encoder, + array& out, + const array& a, + const array& b, + const array& a_scale, + const array& b_scale, + const Shape& batch_shape, + const Strides& a_batch_strides, + const Strides& b_batch_strides, + float alpha) { + int batch_count = out.size() / (M_ * N_); + // if (batch_count / batch_shape.back() > 1) { + // run_batched( + // encoder, + // out, + // a, + // b, + // batch_shape, + // a_batch_strides, + // b_batch_strides, + // alpha); + // return; + // } + + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_input_array(a_scale); + encoder.set_input_array(b_scale); + encoder.set_output_array(out); + + execute( + encoder, + gpu_ptr(out), + gpu_ptr(a), + gpu_ptr(b), + gpu_ptr(a_scale), + gpu_ptr(b_scale), + nullptr, + alpha); +} + +void CublasQQMM::execute( + cu::CommandEncoder& encoder, + void* out, + const void* a, + const void* b, + const void* a_scale, + const void* b_scale, + const void* c, + float alpha /* = 1 */, + float beta /* = 0 */) { + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + matmul_desc_, + CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, + &b_scale, + sizeof(b_scale))); + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + matmul_desc_, + CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, + &a_scale, + sizeof(a_scale))); + + const void* alpha_ptr = α + const void* beta_ptr = β + + cublas_utils::execute_matmul( + encoder, + handle_, + matmul_desc_, + a_desc_, + b_desc_, + c_desc_, + out_desc_, + heuristic_, + pref_, + out, + a, + b, + c, + alpha_ptr, + beta_ptr); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/cublas_qqmm.h 
b/mlx/backend/cuda/quantized/cublas_qqmm.h new file mode 100644 index 0000000000..9fa9645a34 --- /dev/null +++ b/mlx/backend/cuda/quantized/cublas_qqmm.h @@ -0,0 +1,108 @@ +// Copyright © 2025 Apple Inc. +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/cuda/device.h" + +#include + +namespace mlx::core { + +class CublasQQMM { + public: + CublasQQMM( + cu::Device& device, + bool a_transposed, + uint64_t a_rows, + uint64_t a_cols, + int64_t lda, + bool b_transposed, + uint64_t b_rows, + uint64_t b_cols, + int64_t ldb, + std::string_view quantization_mode, + int32_t batch_count, + int64_t a_batch_stride, + int64_t b_batch_stride); + + ~CublasQQMM(); + + void run( + cu::CommandEncoder& encoder, + array& out, + const array& a, + const array& b, + const array& a_scale, + const array& b_scale, + const Shape& batch_shape, + const Strides& a_batch_strides, + const Strides& b_batch_strides, + float alpha = 1.0f); + + // void run( + // cu::CommandEncoder& encoder, + // array& out, + // const array& a, + // const array& b, + // const array& c, + // const Shape& batch_shape, + // const Strides& a_batch_strides, + // const Strides& b_batch_strides, + // const Strides& c_batch_strides, + // float alpha, + // float beta); + + // private: + // void run_batched( + // cu::CommandEncoder& encoder, + // array& out, + // const array& a, + // const array& b, + // const Shape& batch_shape, + // const Strides& a_batch_strides, + // const Strides& b_batch_strides, + // float alpha); + + // void run_batched( + // cu::CommandEncoder& encoder, + // array& out, + // const array& a, + // const array& b, + // const array& c, + // const Shape& batch_shape, + // const Strides& a_batch_strides, + // const Strides& b_batch_strides, + // const Strides& c_batch_strides, + // float alpha, + // float beta); + + void execute( + cu::CommandEncoder& encoder, + void* out, + const void* a, + const void* b, + const void* a_scale, + const void* b_scale, + const void* c, + float alpha = 1, + float beta = 0); + + uint64_t M_; + uint64_t N_; + std::string quantization_mode_; + cudaDataType_t scale_type_; + cublasLtMatmulPreference_t pref_{nullptr}; + cublasLtHandle_t handle_{nullptr}; + cublasLtMatmulDesc_t matmul_desc_{nullptr}; + cublasLtMatrixLayout_t a_desc_{nullptr}; + cublasLtMatrixLayout_t b_desc_{nullptr}; + cublasLtMatrixLayout_t c_desc_{nullptr}; + cublasLtMatrixLayout_t out_desc_{nullptr}; + cublasLtMatmulMatrixScale_t a_scale_mode_; + cublasLtMatmulMatrixScale_t b_scale_mode_; + cublasLtMatmulMatrixScale_t c_scale_mode_; + cublasLtMatmulMatrixScale_t out_scale_mode_; + cublasLtMatmulHeuristicResult_t heuristic_; +}; + +} // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/qqmm_utils.cu b/mlx/backend/cuda/quantized/qqmm_utils.cu new file mode 100644 index 0000000000..ff19057b08 --- /dev/null +++ b/mlx/backend/cuda/quantized/qqmm_utils.cu @@ -0,0 +1,169 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/cuda/quantized/qqmm_utils.h" + +#include + +namespace mlx::core { + +namespace cg = cooperative_groups; + +// To pass scales to tensor cores, they need to be repacked into a tiled layout +// https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout +// Tiled layout for scale factors is very well described in CUTLASS +// documentation: +// https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/blackwell_functionality.md#scale-factor-layouts +// Conceptually, it should be like this: +// q_w = mx.zeros(shape=(M, N)) <-- zeros just for an example +// s.shape = (M, N // 16) -- packed in row contigous order, group_size = 16 +// cbg_cnt = N // 16 // 4 +// rb_cnt = M // 128 +// tmp = x.reshape(rb_cnt, 4, 32, cbg_cnt, 4) +// repacked_scales = tmp.transpose(0, 3, 2, 1, 4) +// example: indecis of intial tile 128 x 4 of scales (packed in row major tensor +// (M, K // 16), where M = 128, K = 64): array([[0, 1, 2, 3], +// [4, 5, 6, 7], +// [8, 9, 10, 11], +// ..., +// [500, 501, 502, 503], +// [504, 505, 506, 507], +// [508, 509, 510, 511]] +// packed scales within tile 128 x 4: +// array([[[[[0, 1, 2, 3], <-- s_0,0..s_0,3 scales +// [128, 129, 130, 131], <-- s_32,0..s_32,3 scales +// [256, 257, 258, 259], <-- s_64,0..s_64,3 scales +// [384, 385, 386, 387]], <-- s_96,0..s_96,3 scales +// [[4, 5, 6, 7], <-- s_1,0..s_1,3 scales +// [132, 133, 134, 135], ... +// [260, 261, 262, 263], +// [388, 389, 390, 391]], +// [[124, 125, 126, 127], +// [252, 253, 254, 255], +// [380, 381, 382, 383], +// [508, 509, 510, 511]]]]], +__device__ size_t +scale_tiled_offset(size_t scale_index, size_t num_rows, size_t num_scale_cols) { + // Compute the tiled layout offset for scale factors used in tensor cores + // This function maps from a linear scale index to the tiled layout expected + // by tensor cores (and cublaslt). + // + // Input: linear scale index (e.g., for a matrix M x K with group_size, + // scale_index ranges from 0 to (M * K/group_size - 1)) + // + // The tiled layout organizes scales into tiles of 128 rows x 4 columns, + // where each tile is subdivided into 4 sub-blocks of 32 rows x 4 columns. 
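+  //
+  // A tiny worked example (a sketch, assuming num_scale_cols == 4, i.e. the
+  // M = 128, K = 64, group_size = 16 case illustrated above): scale_index 130
+  // is s_32,2, so row = 32 and col = 2. Then sub_block_row = 1,
+  // row_in_sub_block = 0, col_in_sub_block = 2 and tile_idx = 0, so the
+  // returned offset is 0 * 512 + (0 * 16 + 1 * 4 + 2) = 6, matching the
+  // repacked tile shown above where positions 4..7 hold s_32,0..s_32,3.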
+ size_t row = scale_index / num_scale_cols; + size_t col = scale_index % num_scale_cols; + + constexpr size_t rows_per_tile = 128; + constexpr size_t rows_per_sub_block = 32; + constexpr size_t cols_per_sub_block = 4; + constexpr size_t sub_blocks_per_tile = 4; // Vertically stacked + + // Decompose row position + size_t tile_row = row / rows_per_tile; // Which tile row + size_t row_in_tile = row % rows_per_tile; // Row within tile + size_t sub_block_row = + row_in_tile / rows_per_sub_block; // Sub-block within tile + size_t row_in_sub_block = + row_in_tile % rows_per_sub_block; // Row in sub-block + + // Decompose column position + size_t col_tile = col / cols_per_sub_block; // Which column tile + size_t col_in_sub_block = col % cols_per_sub_block; // Column within sub-block + + // Compute tile index and offset within tile + size_t num_col_tiles = cuda::ceil_div(num_scale_cols, cols_per_sub_block); + size_t tile_idx = tile_row * num_col_tiles + col_tile; + + size_t offset_in_tile = + (row_in_sub_block * sub_blocks_per_tile * cols_per_sub_block) + + (sub_block_row * cols_per_sub_block) + col_in_sub_block; + + constexpr size_t tile_size = rows_per_tile * cols_per_sub_block; + return tile_idx * tile_size + offset_in_tile; +} + +namespace cu { + +__global__ void repack_scales( + const uint8_t* scales_linear, + uint8_t* scales_tiled, + size_t input_rows, + size_t input_cols, + size_t output_rows, + size_t output_cols) { + auto block_size = cg::this_thread_block().dim_threads(); + auto block_idx = cg::this_thread_block().group_index(); + auto idx_in_block = cg::this_thread_block().thread_index(); + + auto tidx = block_idx.x * block_size.x + idx_in_block.x; + auto tidy = block_idx.y * block_size.y + idx_in_block.y; + + auto grid_dim_x = + cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x; + + size_t output_index = tidx + grid_dim_x * size_t(tidy); + size_t output_size = output_rows * output_cols; + + if (output_index >= output_size) { + return; + } + + size_t tiled_offset = + scale_tiled_offset(output_index, output_rows, output_cols); + + size_t row = output_index / output_cols; + size_t col = output_index % output_cols; + + // Probably this can be done better with 2 separated paths for valid and + // padding + if (row < input_rows && col < input_cols) { + size_t input_index = row * input_cols + col; + scales_tiled[tiled_offset] = scales_linear[input_index]; + } else { + // Zero-fill padding region + scales_tiled[tiled_offset] = 0; + } +} + +} // namespace cu + +void repack_scales( + const array& scales, + array& scales_tiled, + cu::CommandEncoder& enc, + const Stream& s) { + enc.set_input_array(scales); + enc.set_output_array(scales_tiled); + + // Note: scales_tiled is padded to full tiles so if num_rows or num_cols + // are not multiples of tile sizes, the extra space is filled with zeros + + size_t input_rows = scales.shape(-2); + size_t input_cols = scales.shape(-1); + + size_t output_rows = scales_tiled.shape(-2); + size_t output_cols = scales_tiled.shape(-1); + size_t output_size = output_rows * output_cols; + + bool large = output_size > UINT_MAX; + auto [num_blocks, block_dims] = get_launch_args( + output_size, scales_tiled.shape(), scales_tiled.strides(), large); + + enc.add_kernel_node( + cu::repack_scales, + num_blocks, + block_dims, + 0, + gpu_ptr(scales), + gpu_ptr(scales_tiled), + input_rows, + input_cols, + output_rows, + output_cols); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/qqmm_utils.h b/mlx/backend/cuda/quantized/qqmm_utils.h new 
file mode 100644 index 0000000000..126cc298b2 --- /dev/null +++ b/mlx/backend/cuda/quantized/qqmm_utils.h @@ -0,0 +1,30 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/cuda/device.h" + +namespace mlx::core { + +// Compute padded dimensions for tiled layout +// Tiles are 128 rows × 4 columns, must allocate full tiles +inline std::pair get_padded_scale_dims(int num_rows, int num_cols) { + constexpr int rows_per_tile = 128; + constexpr int cols_per_tile = 4; + + int padded_rows = + ((num_rows + rows_per_tile - 1) / rows_per_tile) * rows_per_tile; + int padded_cols = + ((num_cols + cols_per_tile - 1) / cols_per_tile) * cols_per_tile; + + return {padded_rows, padded_cols}; +} + +void repack_scales( + const array& scales, + array& scales_tiled, + cu::CommandEncoder& enc, + const Stream& s); + +} // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index f75064d4e9..7e7ab7d328 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -1,12 +1,14 @@ // Copyright © 2025 Apple Inc. #include "mlx/backend/cuda/quantized/quantized.h" +#include +#include "mlx/backend/common/matmul.h" #include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/quantized/cublas_qqmm.h" +#include "mlx/backend/cuda/quantized/qqmm_utils.h" #include "mlx/backend/gpu/copy.h" #include "mlx/fast_primitives.h" -#include - namespace mlx::core { namespace { @@ -44,6 +46,29 @@ inline array ensure_row_contiguous_matrix( return x_copy; } +array pad_and_repack_scales( + const array& scale, + cu::CommandEncoder& encoder, + const Stream& s) { + // Compute padded dimensions for full tiles (128 rows × 4 cols) + auto [pad_outer, pad_inner] = + get_padded_scale_dims(scale.shape(-2), scale.shape(-1)); + // cuBLAS requirements for scale factor layout: + // 1. Dimensions must be padded to full tiles (128 rows × 4 cols) + // 2. Out-of-bounds values must be filled with zeros + // 3. 
Starting addresses must be 16-byte aligned + // https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + // Note: cu::malloc_async already provides 256-byte alignment + array scale_tiled( + cu::malloc_async(pad_outer * pad_inner, encoder.stream()), + Shape{pad_outer, pad_inner}, + scale.dtype()); + repack_scales(scale, scale_tiled, encoder, s); + + encoder.add_temporary(scale_tiled); + return scale_tiled; +} + } // namespace void fast::Quantize::eval_gpu( @@ -84,4 +109,120 @@ void fast::Quantize::eval_gpu( } } +namespace { +void qqmm_impl( + cu::CommandEncoder& encoder, + int M, + int N, + int K, + bool a_transposed, + int64_t lda, + bool b_transposed, + int64_t ldb, + array& out, + const array& a, + const array& b, + const array& a_scale, + const array& b_scale, + QuantizationMode mode, + float alpha = 1.0f) { + // Invoke CublasQQMM + auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b); + auto batch_count = out.size() / (M * N); + + std::string_view qmode = quantization_mode_to_string(mode); + if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 && + a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K && + b_batch_strides.back() == 0) { + M *= batch_shape.back(); + batch_count = 1; + + a_batch_strides = {0}; + b_batch_strides = {0}; + batch_shape = {1}; + } + + CublasQQMM qqmm( + encoder.device(), + a_transposed, + M, + K, + lda, + b_transposed, + N, + K, + ldb, + qmode, + batch_shape.back(), + a_batch_strides.back(), + b_batch_strides.back()); + + qqmm.run( + encoder, + out, + a, + b, + a_scale, + b_scale, + batch_shape, + a_batch_strides, + b_batch_strides, + alpha); +} +} // namespace + +void DualQuantizedMatmul::eval_gpu( + const std::vector& inputs, + array& out) { + nvtx3::scoped_range r("DualQuantizedMatmul::eval_gpu"); + // WIP need to add primitive + // TODO: for now minimalistic implementation without batching support + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + + assert(inputs.size() == 4); + auto& a = inputs[0]; + auto& b = inputs[1]; + auto& scale_a_pre = inputs[2]; + auto& scale_b_pre = inputs[3]; + // Return 0s if either input is empty. 
+ if (a.size() == 0 || b.size() == 0) { + array zero(0, a.dtype()); + encoder.add_temporary(zero); + fill_gpu(zero, out, s); + return; + } + out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); + + int M = a.shape(-2); + int N = b.shape(-2); // b always transposed + int K_packed = a.shape(-1); + int K = K_packed * (32 / bits_); + + // Repack scales from linear to tiled layout for tensor cores + array scale_a_tiled = pad_and_repack_scales(scale_a_pre, encoder, s); + array scale_b_tiled = pad_and_repack_scales(scale_b_pre, encoder, s); + + bool a_transposed = false; // a is normal (M x K) + bool b_transposed = true; // b is transposed (N x K -> K x N) + int64_t lda = K; // Leading dimension of a (packed) + int64_t ldb = K; // Leading dimension of b (packed) + + qqmm_impl( + encoder, + M, + N, + K, + a_transposed, + lda, + b_transposed, + ldb, + out, + a, + b, + scale_a_tiled, + scale_b_tiled, + mode_); +} + } // namespace mlx::core diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp index b32e074e8f..63db747392 100644 --- a/mlx/backend/no_cpu/primitives.cpp +++ b/mlx/backend/no_cpu/primitives.cpp @@ -96,6 +96,7 @@ NO_CPU(Partition) NO_CPU(Power) NO_CPU_MULTI(QRF) NO_CPU(QuantizedMatmul) +NO_CPU(DualQuantizedMatmul) NO_CPU(RandomBits) NO_CPU(Real) NO_CPU(Reduce) diff --git a/mlx/backend/no_gpu/primitives.cpp b/mlx/backend/no_gpu/primitives.cpp index 406a627b9c..dcc7e04e78 100644 --- a/mlx/backend/no_gpu/primitives.cpp +++ b/mlx/backend/no_gpu/primitives.cpp @@ -112,6 +112,7 @@ NO_GPU(Partition) NO_GPU(Power) NO_GPU_MULTI(QRF) NO_GPU(QuantizedMatmul) +NO_GPU(DualQuantizedMatmul) NO_GPU(RandomBits) NO_GPU(Real) NO_GPU(Reduce) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index f58c73a6ce..15621f4899 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -70,22 +70,24 @@ array indices_or_default( return reshape(arange(total, uint32, s), std::move(shape), s); } -std::pair extract_quantized_matmul_dims( +void validate_quantized_input( std::string_view tag, - const array& x, - const array& w, + const array& matrix, const array& scales, - const std::optional& biases, - bool transpose, + const std::string& matrix_name, + const std::string& scales_name, int group_size, - int bits) { - if (w.dtype() != uint32) { + int bits, + const std::optional& biases = std::nullopt) { + // If matrix is quantized + if (matrix.dtype() != uint32) { std::ostringstream msg; - msg << "[" << tag << "] The weight matrix should be uint32 " - << "but received " << w.dtype(); + msg << "[" << tag << "] The " << matrix_name << " should be uint32 " + << "but received " << matrix.dtype(); throw std::invalid_argument(msg.str()); } + // If biases and scales have same shape if biases provided if (biases && scales.shape() != biases->shape()) { std::ostringstream msg; msg << "[" << tag << "] Scales and biases should have the same shape. " @@ -94,24 +96,43 @@ std::pair extract_quantized_matmul_dims( throw std::invalid_argument(msg.str()); } + // Batch shapes match if (!std::equal( - w.shape().begin(), w.shape().end() - 2, scales.shape().begin())) { + matrix.shape().begin(), + matrix.shape().end() - 2, + scales.shape().begin())) { std::ostringstream msg; - msg << "[" << tag - << "] Weight and scales should have the same batch shape. " - << "Received weight with shape " << w.shape() << ", scales with " - << scales.shape() << "."; + msg << "[" << tag << "] " << matrix_name << " and " << scales_name + << " should have the same batch shape. 
" + << "Received " << matrix_name << " with shape " << matrix.shape() + << ", " << scales_name << " with " << scales.shape() << "."; throw std::invalid_argument(msg.str()); } - if (w.shape(-1) * 32 / bits != scales.shape(-1) * group_size) { + // Shape compatibility based on bits and group_size + if (matrix.shape(-1) * 32 / bits != scales.shape(-1) * group_size) { std::ostringstream msg; - msg << "[" << tag << "] The shapes of the weight and scales are " - << "incompatible based on bits and group_size. w.shape() == " - << w.shape() << " and scales.shape() == " << scales.shape() - << " with group_size=" << group_size << " and bits=" << bits; + msg << "[" << tag << "] The shapes of the " << matrix_name << " and " + << scales_name << " are " + << "incompatible based on bits and group_size. " << matrix_name + << ".shape() == " << matrix.shape() << " and " << scales_name + << ".shape() == " << scales.shape() << " with group_size=" << group_size + << " and bits=" << bits; throw std::invalid_argument(msg.str()); } +} + +std::pair extract_quantized_matmul_dims( + std::string_view tag, + const array& x, + const array& w, + const array& scales, + const std::optional& biases, + bool transpose, + int group_size, + int bits) { + validate_quantized_input( + tag, w, scales, "weight", "scales", group_size, bits, biases); int x_inner_dims = x.shape(-1); @@ -133,6 +154,45 @@ std::pair extract_quantized_matmul_dims( return {w_inner_dims, w_outer_dims}; } +std::pair, std::pair> extract_qqmm_dims( + std::string_view tag, + const array& x, + const array& w, + const array& scales_x, + const array& scales_w, + bool transpose, + int group_size, + int bits) { + // Validate x and scales_x + validate_quantized_input( + tag, x, scales_x, "x matrix", "scales_x", group_size, bits); + + // Validate w and scales_w + validate_quantized_input( + tag, w, scales_w, "weight matrix", "scales_w", group_size, bits); + + // For narrow precision types (mxfp4, nvfp4) the only supported layout is TN + // A is MxK, B is NxK (transposed) + int x_inner_dims = x.shape(-1); // K // (32 / bits) + int x_outer_dims = x.shape(-2); // M + + // Calculate the expanded w's dimensions + int w_inner_dims = (transpose) ? w.shape(-1) : w.shape(-2); + int w_outer_dims = (transpose) ? 
w.shape(-2) : w.shape(-1); + + if (w_inner_dims != x_inner_dims) { + std::ostringstream msg; + msg << "[" << tag << "] Last dimension of first quantized input with " + << "shape (..., " << x_inner_dims << ") does not match " + << "the quantized matrix (" << w_inner_dims << ", " << w_outer_dims + << ") computed with transpose=" << std::boolalpha << transpose; + + throw std::invalid_argument(msg.str()); + } + + return {{x_inner_dims, x_outer_dims}, {w_inner_dims, w_outer_dims}}; +} + } // namespace array arange( @@ -4146,7 +4206,6 @@ array quantized_matmul( } else { inputs = {x, w, scales}; } - if (x.ndim() > 2 && w.ndim() > 2) { inputs = broadcast_arrays(inputs, {-2, -1}, s); } @@ -4161,6 +4220,68 @@ array quantized_matmul( std::move(inputs)); } +array qqmm( + array x, + array w, + array scales_x, + array scales_w, + bool transpose /* = true */, + std::optional group_size_ /* = std::nullopt */, + std::optional bits_ /* = std::nullopt */, + const std::string& mode /* = "nvfp4" */, + bool quantize_output /* = false */, + StreamOrDevice s /* = {} */) { + // currently only simetric quantization is supported for qqmm + auto qmode = string_to_quantization_mode(mode, "qqmm"); + // For narrow precision MMAs on B200 only TN layout is supported: + // https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html + // TODO: handle it better + if ((qmode == QuantizationMode::Nvfp4 || qmode == QuantizationMode::Mxfp4) && + !transpose) { + std::ostringstream msg; + msg << "[qqmm] transpose must be set to true with " << mode + << " quantization but " + << "transpose == false was provided."; + throw std::invalid_argument(msg.str()); + } + if (qmode == QuantizationMode::Affine) { + std::ostringstream msg; + msg << "[qqmm] Affine quantization is not supported for qqmm."; + throw std::invalid_argument(msg.str()); + } + auto [group_size, bits] = + quantization_params_from_mode(qmode, group_size_, bits_); + // Check and extract the quantized matrix shape against x + + auto [x_dims, w_dims] = extract_qqmm_dims( + "qqmm", x, w, scales_x, scales_w, transpose, group_size, bits); + auto [x_inner_dims, x_outer_dims] = x_dims; + auto [w_inner_dims, w_outer_dims] = w_dims; + + std::vector inputs = {x, w, scales_x, scales_w}; + + if (x.ndim() > 2 && w.ndim() > 2) { + inputs = broadcast_arrays(inputs, {-2, -1}, s); + } + + auto out_shape = inputs[0].shape(); + if (!quantize_output) { + out_shape.back() = w_outer_dims; // result should be the same shape (M, N) + // if not packed in uint32 + } else { + out_shape.back() = w_outer_dims / (32 / bits); // packed output + } + + // out dtype can be only bf16 if not quantized + auto dtype = quantize_output ? 
x.dtype() : bfloat16; + return array( + std::move(out_shape), + dtype, + std::make_shared( + to_stream(s), group_size, bits, qmode, transpose, quantize_output), + std::move(inputs)); +} + array pack_and_quantize( array& packed_w, const array& scales, diff --git a/mlx/ops.h b/mlx/ops.h index 6c44c032c0..47effbd580 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1403,6 +1403,18 @@ std::vector quantize( const std::string& mode = "affine", StreamOrDevice s = {}); +array qqmm( + const array x, + const array w, + const array x_scales, + const array w_scales, + bool transpose = true, + std::optional group_size = std::nullopt, + std::optional bits = std::nullopt, + const std::string& mode = "nvfp4", + bool quantize_output = false, + StreamOrDevice s = {}); + /** Dequantize a matrix produced by quantize() */ array dequantize( const array& w, diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 976200ba34..242722cdf3 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -10,6 +10,7 @@ #include #include +#include #include "mlx/backend/common/utils.h" #include "mlx/fft.h" #include "mlx/linalg.h" @@ -3468,6 +3469,24 @@ std::vector QuantizedMatmul::output_shapes( return {std::move(out_shape)}; } +bool DualQuantizedMatmul::is_equivalent(const Primitive& other) const { + const DualQuantizedMatmul& qm_other = + static_cast(other); + return group_size_ == qm_other.group_size_ && bits_ == qm_other.bits_ && + mode_ == qm_other.mode_ && transpose_ == qm_other.transpose_; +} + +std::vector DualQuantizedMatmul::output_shapes( + const std::vector& inputs) { + auto out_shape = inputs[0].shape(); + auto& w = inputs[1]; + int w_outer_dims = (transpose_) ? w.shape(-2) : w.shape(-1); + w_outer_dims /= quantize_output_ ? (32 / bits_) : 1; + out_shape.back() = w_outer_dims; + std::cout << "DualQuantizedMatmul output shape: " << out_shape << std::endl; + return {std::move(out_shape)}; +} + std::pair, std::vector> GatherQMM::vmap( const std::vector& inputs, const std::vector& axes) { diff --git a/mlx/primitives.h b/mlx/primitives.h index d3260f0f4c..f77f6922a4 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1642,6 +1642,43 @@ class QuantizedMatmul : public UnaryPrimitive { bool transpose_; }; +class DualQuantizedMatmul : public UnaryPrimitive { + public: + explicit DualQuantizedMatmul( + Stream stream, + int group_size, + int bits, + QuantizationMode mode, + bool transpose, + bool quantize_output) + : UnaryPrimitive(stream), + group_size_(group_size), + bits_(bits), + mode_(mode), + transpose_(transpose), + quantize_output_(quantize_output) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + // DEFINE_VMAP() + // DEFINE_GRADS() + DEFINE_NAME(DualQuantizedMatmul) + bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; + auto state() const { + return std::make_tuple( + group_size_, bits_, mode_, transpose_, quantize_output_); + } + + private: + int group_size_; + int bits_; + QuantizationMode mode_; + bool transpose_; + bool quantize_output_; +}; + class GatherQMM : public UnaryPrimitive { public: explicit GatherQMM( From 54f99589949b0f87b436d92ab29be8684a683554 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Tue, 18 Nov 2025 19:26:39 +0100 Subject: [PATCH 02/29] merge main --- mlx/backend/cuda/cublas_utils.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx/backend/cuda/cublas_utils.cpp b/mlx/backend/cuda/cublas_utils.cpp 
index 9af902d5a4..3393633b0e 100644 --- a/mlx/backend/cuda/cublas_utils.cpp +++ b/mlx/backend/cuda/cublas_utils.cpp @@ -48,7 +48,7 @@ void* allocate_workspace(cu::CommandEncoder& encoder, size_t workspace_size) { // Ensure workspace is 256-byte aligned int nbytes = cuda::ceil_div(workspace_size, 256) * 256; array workspace( - cu::malloc_async(nbytes, encoder.stream()), + cu::malloc_async(nbytes, encoder), {static_cast(workspace_size)}, int8); encoder.add_temporary(workspace); From 61e30ea7f04b87178bbe1d54ff58102e45051ac1 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Wed, 19 Nov 2025 18:17:23 +0100 Subject: [PATCH 03/29] refactoring --- mlx/backend/cuda/cublas_utils.cpp | 142 +++++++++++++----- mlx/backend/cuda/cublas_utils.h | 81 +++++++---- mlx/backend/cuda/gemms/cublas_gemm.cpp | 86 +++-------- mlx/backend/cuda/gemms/cublas_gemm.h | 17 +-- mlx/backend/cuda/quantized/cublas_qqmm.cpp | 162 +++++++++------------ mlx/backend/cuda/quantized/cublas_qqmm.h | 35 +++-- mlx/ops.cpp | 14 +- mlx/ops.h | 1 - mlx/primitives.cpp | 1 - mlx/primitives.h | 10 +- 10 files changed, 282 insertions(+), 267 deletions(-) diff --git a/mlx/backend/cuda/cublas_utils.cpp b/mlx/backend/cuda/cublas_utils.cpp index 3393633b0e..4ccb0816e2 100644 --- a/mlx/backend/cuda/cublas_utils.cpp +++ b/mlx/backend/cuda/cublas_utils.cpp @@ -83,78 +83,154 @@ cublasLtMatrixLayout_t create_matrix_layout( return desc; } -void execute_matmul( +} // namespace cublas_utils + +CublasMatmulBase::~CublasMatmulBase() { + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_)); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_)); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_)); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_)); + CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_)); +} + +void CublasMatmulBase::init_base( + cu::Device& device, + cudaDataType_t scale_type, + cublasComputeType_t compute_type, + cudaDataType_t data_type, + bool a_transposed, + uint64_t a_rows, + uint64_t a_cols, + int64_t lda, + bool b_transposed, + uint64_t b_rows, + uint64_t b_cols, + int64_t ldb, + int32_t batch_count, + int64_t a_batch_stride, + int64_t b_batch_stride) { + M_ = a_rows; + N_ = b_cols; + scale_type_ = scale_type; + handle_ = device.lt_handle(); + pref_ = cublas_utils::get_preference(device); + heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED; + + CHECK_CUBLAS_ERROR( + cublasLtMatmulDescCreate(&matmul_desc_, compute_type, scale_type)); + + int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST; + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + matmul_desc_, + CUBLASLT_MATMUL_DESC_POINTER_MODE, + &pointer_mode, + sizeof(int32_t))); + + // In cublasLt matrices use column-major layout, while it is possible to use + // the CUBLASLT_ORDER_ROW option to switch to row-major layout, the bias + // epilogue does not work with the option. So instead we swap A and B to make + // cublasLt return the row-major result, which works because: + // - the data of a matrix in row-major layout is identical to its transpose in + // column-major layout + // - C^T = (A @ B)^T = B^T @ A^T + cublasOperation_t a_op = b_transposed ? CUBLAS_OP_T : CUBLAS_OP_N; + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + matmul_desc_, + CUBLASLT_MATMUL_DESC_TRANSA, + &a_op, + sizeof(cublasOperation_t))); + cublasOperation_t b_op = a_transposed ? 
CUBLAS_OP_T : CUBLAS_OP_N; + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + matmul_desc_, + CUBLASLT_MATMUL_DESC_TRANSB, + &b_op, + sizeof(cublasOperation_t))); + + a_desc_ = cublas_utils::create_matrix_layout( + data_type, + b_cols, + b_rows, + b_transposed, + ldb, + batch_count, + b_batch_stride); + b_desc_ = cublas_utils::create_matrix_layout( + data_type, + a_cols, + a_rows, + a_transposed, + lda, + batch_count, + a_batch_stride); + out_desc_ = cublas_utils::create_matrix_layout( + data_type, b_cols, a_rows, false, b_cols, batch_count, a_rows * b_cols); +} + +void CublasMatmulBase::execute_matmul( cu::CommandEncoder& encoder, - cublasLtHandle_t handle, - cublasLtMatmulDesc_t matmul_desc, - cublasLtMatrixLayout_t a_desc, - cublasLtMatrixLayout_t b_desc, - cublasLtMatrixLayout_t c_desc, - cublasLtMatrixLayout_t out_desc, - cublasLtMatmulHeuristicResult_t& heuristic, - cublasLtMatmulPreference_t pref, void* out, const void* a, const void* b, const void* c, const void* alpha_ptr, const void* beta_ptr) { - if (heuristic.state != CUBLAS_STATUS_SUCCESS) { + if (heuristic_.state != CUBLAS_STATUS_SUCCESS) { int ret = 0; CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic( - handle, - matmul_desc, - a_desc, - b_desc, - c ? c_desc : out_desc, - out_desc, - pref, + handle_, + matmul_desc_, + a_desc_, + b_desc_, + c ? c_desc_ : out_desc_, + out_desc_, + pref_, 1, - &heuristic, + &heuristic_, &ret)); if (ret == 0) { throw std::runtime_error("Can not find algorithm for matmul."); } } - void* workspace_ptr = allocate_workspace(encoder, heuristic.workspaceSize); + void* workspace_ptr = + cublas_utils::allocate_workspace(encoder, heuristic_.workspaceSize); // Execute matmul auto capture = encoder.capture_context(); CHECK_CUBLAS_ERROR(cublasLtMatmul( - handle, - matmul_desc, + handle_, + matmul_desc_, alpha_ptr, b, // a and b are swapped for row-major layout - a_desc, + a_desc_, a, - b_desc, + b_desc_, beta_ptr, c ? c : out, - c ? c_desc : out_desc, + c ? 
c_desc_ : out_desc_, out, - out_desc, - &heuristic.algo, + out_desc_, + &heuristic_.algo, workspace_ptr, - heuristic.workspaceSize, + heuristic_.workspaceSize, encoder.stream())); } -void set_bias( +void CublasMatmulBase::set_bias( cu::CommandEncoder& encoder, - cublasLtMatmulDesc_t matmul_desc, const array& bias) { encoder.set_input_array(bias); cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS; CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( - matmul_desc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue))); + matmul_desc_, + CUBLASLT_MATMUL_DESC_EPILOGUE, + &epilogue, + sizeof(epilogue))); auto* bias_ptr = gpu_ptr(bias); CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( - matmul_desc, + matmul_desc_, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr))); } - -} // namespace cublas_utils } // namespace mlx::core diff --git a/mlx/backend/cuda/cublas_utils.h b/mlx/backend/cuda/cublas_utils.h index 63fc6e733e..e349386b1d 100644 --- a/mlx/backend/cuda/cublas_utils.h +++ b/mlx/backend/cuda/cublas_utils.h @@ -12,11 +12,8 @@ namespace cublas_utils { // Get the shared cublas preference for a device cublasLtMatmulPreference_t get_preference(cu::Device& device); -// Allocate workspace for matmul if needed and return pointer -// The workspace array is added to the encoder's temporaries void* allocate_workspace(cu::CommandEncoder& encoder, size_t workspace_size); -// Create matrix layout cublasLtMatrixLayout_t create_matrix_layout( cudaDataType_t type, uint64_t rows, @@ -26,30 +23,60 @@ cublasLtMatrixLayout_t create_matrix_layout( int32_t batch_count, int64_t batch_stride); -// Execute matmul with pre-configured descriptors -void execute_matmul( - cu::CommandEncoder& encoder, - cublasLtHandle_t handle, - cublasLtMatmulDesc_t matmul_desc, - cublasLtMatrixLayout_t a_desc, - cublasLtMatrixLayout_t b_desc, - cublasLtMatrixLayout_t c_desc, - cublasLtMatrixLayout_t out_desc, - cublasLtMatmulHeuristicResult_t& heuristic, - cublasLtMatmulPreference_t pref, - void* out, - const void* a, - const void* b, - const void* c, - const void* alpha_ptr, - const void* beta_ptr); - -// Set bias for matmul epilogue -void set_bias( - cu::CommandEncoder& encoder, - cublasLtMatmulDesc_t matmul_desc, - const array& bias); - } // namespace cublas_utils +class CublasMatmulBase { + public: + virtual ~CublasMatmulBase(); + + cublasLtMatmulDesc_t matmul_desc() const { + return matmul_desc_; + } + + protected: + CublasMatmulBase() = default; + + // Common member variables shared by all matmul types + uint64_t M_; + uint64_t N_; + cudaDataType_t scale_type_; + cublasLtMatmulPreference_t pref_{nullptr}; + cublasLtHandle_t handle_{nullptr}; + cublasLtMatmulDesc_t matmul_desc_{nullptr}; + cublasLtMatrixLayout_t a_desc_{nullptr}; + cublasLtMatrixLayout_t b_desc_{nullptr}; + cublasLtMatrixLayout_t c_desc_{nullptr}; + cublasLtMatrixLayout_t out_desc_{nullptr}; + cublasLtMatmulHeuristicResult_t heuristic_; + + void init_base( + cu::Device& device, + cudaDataType_t scale_type, + cublasComputeType_t compute_type, + cudaDataType_t data_type, + bool a_transposed, + uint64_t a_rows, + uint64_t a_cols, + int64_t lda, + bool b_transposed, + uint64_t b_rows, + uint64_t b_cols, + int64_t ldb, + int32_t batch_count, + int64_t a_batch_stride, + int64_t b_batch_stride); + + // Execute matmul using the configured descriptors + void execute_matmul( + cu::CommandEncoder& encoder, + void* out, + const void* a, + const void* b, + const void* c, + const void* alpha_ptr, + const void* beta_ptr); +}; + +void 
set_bias(cu::CommandEncoder& encoder, const array& bias); + } // namespace mlx::core diff --git a/mlx/backend/cuda/gemms/cublas_gemm.cpp b/mlx/backend/cuda/gemms/cublas_gemm.cpp index 5efb4af7d6..af7b0b8fd3 100644 --- a/mlx/backend/cuda/gemms/cublas_gemm.cpp +++ b/mlx/backend/cuda/gemms/cublas_gemm.cpp @@ -65,54 +65,27 @@ CublasGemm::CublasGemm( int64_t ldb, int32_t batch_count, int64_t a_batch_stride, - int64_t b_batch_stride) - : handle_(device.lt_handle()), - pref_(cublas_utils::get_preference(device)), - M_(a_rows), - N_(b_cols) { - heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED; - + int64_t b_batch_stride) { scale_type_ = dtype_to_cublas_type(dtype); if (dtype == bfloat16 || dtype == float16) { scale_type_ = CUDA_R_32F; } - - CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate( - &matmul_desc_, dtype_to_compute_type(dtype), scale_type_)); - int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST; - CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( - matmul_desc_, - CUBLASLT_MATMUL_DESC_POINTER_MODE, - &pointer_mode, - sizeof(int32_t))); - - // In cublasLt matrices use column-major layout, while it is possible to use - // the CUBLASLT_ORDER_ROW option to switch to row-major layout, the bias - // epilogue does not work with the option. So instead we swap A and B to make - // cublasLt return the row-major result, which works because: - // - the data of a matrix in row-major layout is identical to its transpose in - // column-major layout - // - C^T = (A @ B)^T = B^T @ A^T - cublasOperation_t a_op = b_transposed ? CUBLAS_OP_T : CUBLAS_OP_N; - CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( - matmul_desc_, - CUBLASLT_MATMUL_DESC_TRANSA, - &a_op, - sizeof(cublasOperation_t))); - cublasOperation_t b_op = a_transposed ? CUBLAS_OP_T : CUBLAS_OP_N; - CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( - matmul_desc_, - CUBLASLT_MATMUL_DESC_TRANSB, - &b_op, - sizeof(cublasOperation_t))); - - auto type = dtype_to_cublas_type(dtype); - a_desc_ = cublas_utils::create_matrix_layout( - type, b_cols, b_rows, b_transposed, ldb, batch_count, b_batch_stride); - b_desc_ = cublas_utils::create_matrix_layout( - type, a_cols, a_rows, a_transposed, lda, batch_count, a_batch_stride); - out_desc_ = cublas_utils::create_matrix_layout( - type, b_cols, a_rows, false, b_cols, batch_count, a_rows * b_cols); + init_base( + device, + scale_type_, + dtype_to_compute_type(dtype), + dtype_to_cublas_type(dtype), + a_transposed, + a_rows, + a_cols, + lda, + b_transposed, + b_rows, + b_cols, + ldb, + batch_count, + a_batch_stride, + b_batch_stride); } CublasGemm::CublasGemm( @@ -150,14 +123,6 @@ CublasGemm::CublasGemm( type, b_cols, a_rows, false, ldc, batch_count, c_batch_stride); } -CublasGemm::~CublasGemm() { - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_)); - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_)); - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_)); - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_)); - CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_)); -} - void CublasGemm::set_out( Dtype dtype, bool transposed, @@ -275,22 +240,7 @@ void CublasGemm::execute( beta_ptr = &beta_c; } - cublas_utils::execute_matmul( - encoder, - handle_, - matmul_desc_, - a_desc_, - b_desc_, - c_desc_, - out_desc_, - heuristic_, - pref_, - out, - a, - b, - c, - alpha_ptr, - beta_ptr); + execute_matmul(encoder, out, a, b, c, alpha_ptr, beta_ptr); } } // namespace mlx::core diff --git a/mlx/backend/cuda/gemms/cublas_gemm.h b/mlx/backend/cuda/gemms/cublas_gemm.h index eb02452f2d..aca4f54dc0 100644 
--- a/mlx/backend/cuda/gemms/cublas_gemm.h +++ b/mlx/backend/cuda/gemms/cublas_gemm.h @@ -2,13 +2,14 @@ #pragma once #include "mlx/array.h" +#include "mlx/backend/cuda/cublas_utils.h" #include "mlx/backend/cuda/device.h" #include namespace mlx::core { -class CublasGemm { +class CublasGemm : public CublasMatmulBase { public: CublasGemm( cu::Device& device, @@ -42,8 +43,6 @@ class CublasGemm { int64_t b_batch_stride, int64_t c_batch_stride); - ~CublasGemm(); - // The output's descriptor is inferred from inputs by default, use this method // for unusual output. void set_out( @@ -115,18 +114,6 @@ class CublasGemm { const void* c, float alpha = 1, float beta = 0); - - uint64_t M_; - uint64_t N_; - cudaDataType_t scale_type_; - cublasLtMatmulPreference_t pref_{nullptr}; - cublasLtHandle_t handle_{nullptr}; - cublasLtMatmulDesc_t matmul_desc_{nullptr}; - cublasLtMatrixLayout_t a_desc_{nullptr}; - cublasLtMatrixLayout_t b_desc_{nullptr}; - cublasLtMatrixLayout_t c_desc_{nullptr}; - cublasLtMatrixLayout_t out_desc_{nullptr}; - cublasLtMatmulHeuristicResult_t heuristic_; }; } // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/cublas_qqmm.cpp b/mlx/backend/cuda/quantized/cublas_qqmm.cpp index c49e0286bf..c5a59900da 100644 --- a/mlx/backend/cuda/quantized/cublas_qqmm.cpp +++ b/mlx/backend/cuda/quantized/cublas_qqmm.cpp @@ -60,44 +60,36 @@ CublasQQMM::CublasQQMM( uint64_t b_rows, uint64_t b_cols, int64_t ldb, - std::string_view qmode, int32_t batch_count, int64_t a_batch_stride, - int64_t b_batch_stride) - : handle_(device.lt_handle()), - pref_(cublas_utils::get_preference(device)), - M_(a_transposed ? a_cols : a_rows), - N_(b_transposed ? b_rows : b_cols) { - a_scale_mode_ = qmode_to_cublas_scale_mode(qmode); - b_scale_mode_ = qmode_to_cublas_scale_mode(qmode); - - heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED; - + int64_t b_batch_stride, + std::string_view qmode) { + cudaDataType_t scale_type = CUDA_R_32F; cublasComputeType_t gemm_compute_type = CUBLAS_COMPUTE_32F; // always for narrow precision - CHECK_CUBLAS_ERROR( - cublasLtMatmulDescCreate(&matmul_desc_, gemm_compute_type, CUDA_R_32F)); + cudaDataType_t data_type = qmode_to_cublas_dtype(qmode); - cublasOperation_t a_op = b_transposed ? CUBLAS_OP_T : CUBLAS_OP_N; - CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( - matmul_desc_, - CUBLASLT_MATMUL_DESC_TRANSA, - &a_op, - sizeof(cublasOperation_t))); - cublasOperation_t b_op = a_transposed ? CUBLAS_OP_T : CUBLAS_OP_N; - CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( - matmul_desc_, - CUBLASLT_MATMUL_DESC_TRANSB, - &b_op, - sizeof(cublasOperation_t))); + quantization_mode_ = std::string(qmode); - // alpha, beta pointer mode set to host ? 
(TODO) - int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST; - CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( - matmul_desc_, - CUBLASLT_MATMUL_DESC_POINTER_MODE, - &pointer_mode, - sizeof(int32_t))); + init_base( + device, + scale_type, + gemm_compute_type, + data_type, + a_transposed, + a_rows, + a_cols, + lda, + b_transposed, + b_rows, + b_cols, + ldb, + batch_count, + a_batch_stride, + b_batch_stride); + + a_scale_mode_ = qmode_to_cublas_scale_mode(qmode); + b_scale_mode_ = qmode_to_cublas_scale_mode(qmode); CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( matmul_desc_, @@ -110,50 +102,55 @@ CublasQQMM::CublasQQMM( &b_scale_mode_, sizeof(b_scale_mode_))); - // a and b are swaped - a_desc_ = cublas_utils::create_matrix_layout( - qmode_to_cublas_dtype(qmode), - b_cols, - b_rows, - b_transposed, - ldb, - batch_count, - b_batch_stride); - b_desc_ = cublas_utils::create_matrix_layout( - qmode_to_cublas_dtype(qmode), - a_cols, - a_rows, - a_transposed, - lda, - batch_count, - a_batch_stride); - out_desc_ = cublas_utils::create_matrix_layout( - CUDA_R_16BF, // output in bf16 (TODO) - b_transposed ? b_rows : b_cols, // n - a_transposed ? a_cols : a_rows, // m - false, - b_transposed ? b_rows : b_cols, - batch_count, - a_rows * b_cols); - - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate( - &a_desc_, qmode_to_cublas_dtype(qmode), b_cols, b_rows, ldb)); - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate( - &b_desc_, qmode_to_cublas_dtype(qmode), a_cols, a_rows, lda)); - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate( - &out_desc_, - CUDA_R_16BF, // output in bf16 - b_transposed ? b_rows : b_cols, // m - a_rows, // asume that never transposed (supported only TN layout) - b_transposed ? b_rows : b_cols)); + // out_desc_ = create_matrix_layout( + // CUDA_R_16BF, // output in bf16 + // b_transposed ? b_rows : b_cols, + // a_transposed ? a_cols : a_rows, + // false, + // b_transposed ? b_rows : b_cols, + // batch_count, + // (a_transposed ? a_cols : a_rows) * (b_transposed ? b_rows : b_cols)); } -CublasQQMM::~CublasQQMM() { - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_)); - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_)); - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_)); - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_)); - CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_)); +CublasQQMM::CublasQQMM( + cu::Device& device, + bool a_transposed, + uint64_t a_rows, + uint64_t a_cols, + int64_t lda, + bool b_transposed, + uint64_t b_rows, + uint64_t b_cols, + int64_t ldb, + int64_t ldc, + int32_t batch_count, + int64_t a_batch_stride, + int64_t b_batch_stride, + int64_t c_batch_stride, + std::string_view qmode) + : CublasQQMM( + device, + a_transposed, + a_rows, + a_cols, + lda, + b_transposed, + b_rows, + b_cols, + ldb, + batch_count, + a_batch_stride, + b_batch_stride, + qmode) { + auto type = CUDA_R_16BF; // always c in bf16 + c_desc_ = cublas_utils::create_matrix_layout( + type, + b_transposed ? b_rows : b_cols, + a_transposed ? 
a_cols : a_rows, + false, + ldc, + batch_count, + c_batch_stride); } void CublasQQMM::run( @@ -222,22 +219,7 @@ void CublasQQMM::execute( const void* alpha_ptr = α const void* beta_ptr = β - cublas_utils::execute_matmul( - encoder, - handle_, - matmul_desc_, - a_desc_, - b_desc_, - c_desc_, - out_desc_, - heuristic_, - pref_, - out, - a, - b, - c, - alpha_ptr, - beta_ptr); + execute_matmul(encoder, out, a, b, c, alpha_ptr, beta_ptr); } } // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/cublas_qqmm.h b/mlx/backend/cuda/quantized/cublas_qqmm.h index 9fa9645a34..a5650c9a17 100644 --- a/mlx/backend/cuda/quantized/cublas_qqmm.h +++ b/mlx/backend/cuda/quantized/cublas_qqmm.h @@ -2,13 +2,14 @@ #pragma once #include "mlx/array.h" +#include "mlx/backend/cuda/cublas_utils.h" #include "mlx/backend/cuda/device.h" #include namespace mlx::core { -class CublasQQMM { +class CublasQQMM : public CublasMatmulBase { public: CublasQQMM( cu::Device& device, @@ -20,12 +21,27 @@ class CublasQQMM { uint64_t b_rows, uint64_t b_cols, int64_t ldb, - std::string_view quantization_mode, int32_t batch_count, int64_t a_batch_stride, - int64_t b_batch_stride); + int64_t b_batch_stride, + std::string_view quantization_mode); - ~CublasQQMM(); + CublasQQMM( + cu::Device& device, + bool a_transposed, + uint64_t a_rows, + uint64_t a_cols, + int64_t lda, + bool b_transposed, + uint64_t b_rows, + uint64_t b_cols, + int64_t ldb, + int64_t ldc, + int32_t batch_count, + int64_t a_batch_stride, + int64_t b_batch_stride, + int64_t c_batch_stride, + std::string_view quantization_mode); void run( cu::CommandEncoder& encoder, @@ -87,22 +103,11 @@ class CublasQQMM { float alpha = 1, float beta = 0); - uint64_t M_; - uint64_t N_; std::string quantization_mode_; - cudaDataType_t scale_type_; - cublasLtMatmulPreference_t pref_{nullptr}; - cublasLtHandle_t handle_{nullptr}; - cublasLtMatmulDesc_t matmul_desc_{nullptr}; - cublasLtMatrixLayout_t a_desc_{nullptr}; - cublasLtMatrixLayout_t b_desc_{nullptr}; - cublasLtMatrixLayout_t c_desc_{nullptr}; - cublasLtMatrixLayout_t out_desc_{nullptr}; cublasLtMatmulMatrixScale_t a_scale_mode_; cublasLtMatmulMatrixScale_t b_scale_mode_; cublasLtMatmulMatrixScale_t c_scale_mode_; cublasLtMatmulMatrixScale_t out_scale_mode_; - cublasLtMatmulHeuristicResult_t heuristic_; }; } // namespace mlx::core diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 15621f4899..5b245c5a24 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -4229,7 +4229,6 @@ array qqmm( std::optional group_size_ /* = std::nullopt */, std::optional bits_ /* = std::nullopt */, const std::string& mode /* = "nvfp4" */, - bool quantize_output /* = false */, StreamOrDevice s /* = {} */) { // currently only simetric quantization is supported for qqmm auto qmode = string_to_quantization_mode(mode, "qqmm"); @@ -4265,20 +4264,15 @@ array qqmm( } auto out_shape = inputs[0].shape(); - if (!quantize_output) { - out_shape.back() = w_outer_dims; // result should be the same shape (M, N) - // if not packed in uint32 - } else { - out_shape.back() = w_outer_dims / (32 / bits); // packed output - } + out_shape.back() = w_outer_dims; - // out dtype can be only bf16 if not quantized - auto dtype = quantize_output ? 
x.dtype() : bfloat16; + // out dtype can be only bf16 + auto dtype = bfloat16; return array( std::move(out_shape), dtype, std::make_shared( - to_stream(s), group_size, bits, qmode, transpose, quantize_output), + to_stream(s), group_size, bits, qmode, transpose), std::move(inputs)); } diff --git a/mlx/ops.h b/mlx/ops.h index 47effbd580..6ca5dd43b6 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1412,7 +1412,6 @@ array qqmm( std::optional group_size = std::nullopt, std::optional bits = std::nullopt, const std::string& mode = "nvfp4", - bool quantize_output = false, StreamOrDevice s = {}); /** Dequantize a matrix produced by quantize() */ diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 242722cdf3..bca5935f36 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3481,7 +3481,6 @@ std::vector DualQuantizedMatmul::output_shapes( auto out_shape = inputs[0].shape(); auto& w = inputs[1]; int w_outer_dims = (transpose_) ? w.shape(-2) : w.shape(-1); - w_outer_dims /= quantize_output_ ? (32 / bits_) : 1; out_shape.back() = w_outer_dims; std::cout << "DualQuantizedMatmul output shape: " << out_shape << std::endl; return {std::move(out_shape)}; diff --git a/mlx/primitives.h b/mlx/primitives.h index f77f6922a4..6f2821794a 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1649,14 +1649,12 @@ class DualQuantizedMatmul : public UnaryPrimitive { int group_size, int bits, QuantizationMode mode, - bool transpose, - bool quantize_output) + bool transpose) : UnaryPrimitive(stream), group_size_(group_size), bits_(bits), mode_(mode), - transpose_(transpose), - quantize_output_(quantize_output) {} + transpose_(transpose) {} void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; @@ -1667,8 +1665,7 @@ class DualQuantizedMatmul : public UnaryPrimitive { bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override; auto state() const { - return std::make_tuple( - group_size_, bits_, mode_, transpose_, quantize_output_); + return std::make_tuple(group_size_, bits_, mode_, transpose_); } private: @@ -1676,7 +1673,6 @@ class DualQuantizedMatmul : public UnaryPrimitive { int bits_; QuantizationMode mode_; bool transpose_; - bool quantize_output_; }; class GatherQMM : public UnaryPrimitive { From 9e4879b405d2631913e5cc0dbf8462584bdbe043 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Wed, 19 Nov 2025 20:02:17 +0000 Subject: [PATCH 04/29] refactoring --- mlx/backend/cuda/cublas_utils.cpp | 4 +++- mlx/backend/cuda/cublas_utils.h | 6 +++--- mlx/backend/cuda/gemms/cublas_gemm.cpp | 1 + mlx/backend/cuda/matmul.cpp | 2 +- mlx/backend/cuda/quantized/cublas_qqmm.cpp | 11 ++--------- mlx/backend/cuda/quantized/quantized.cpp | 10 +++++----- 6 files changed, 15 insertions(+), 19 deletions(-) diff --git a/mlx/backend/cuda/cublas_utils.cpp b/mlx/backend/cuda/cublas_utils.cpp index 4ccb0816e2..9a2717fdeb 100644 --- a/mlx/backend/cuda/cublas_utils.cpp +++ b/mlx/backend/cuda/cublas_utils.cpp @@ -98,6 +98,7 @@ void CublasMatmulBase::init_base( cudaDataType_t scale_type, cublasComputeType_t compute_type, cudaDataType_t data_type, + cudaDataType_t output_type, bool a_transposed, uint64_t a_rows, uint64_t a_cols, @@ -163,7 +164,7 @@ void CublasMatmulBase::init_base( batch_count, a_batch_stride); out_desc_ = cublas_utils::create_matrix_layout( - data_type, b_cols, a_rows, false, b_cols, batch_count, a_rows * b_cols); + output_type, b_cols, a_rows, false, b_cols, batch_count, b_cols * 
a_rows); } void CublasMatmulBase::execute_matmul( @@ -233,4 +234,5 @@ void CublasMatmulBase::set_bias( &bias_ptr, sizeof(bias_ptr))); } + } // namespace mlx::core diff --git a/mlx/backend/cuda/cublas_utils.h b/mlx/backend/cuda/cublas_utils.h index e349386b1d..ebd422454f 100644 --- a/mlx/backend/cuda/cublas_utils.h +++ b/mlx/backend/cuda/cublas_utils.h @@ -33,6 +33,8 @@ class CublasMatmulBase { return matmul_desc_; } + void set_bias(cu::CommandEncoder& encoder, const array& bias); + protected: CublasMatmulBase() = default; @@ -54,6 +56,7 @@ class CublasMatmulBase { cudaDataType_t scale_type, cublasComputeType_t compute_type, cudaDataType_t data_type, + cudaDataType_t output_type, bool a_transposed, uint64_t a_rows, uint64_t a_cols, @@ -66,7 +69,6 @@ class CublasMatmulBase { int64_t a_batch_stride, int64_t b_batch_stride); - // Execute matmul using the configured descriptors void execute_matmul( cu::CommandEncoder& encoder, void* out, @@ -77,6 +79,4 @@ class CublasMatmulBase { const void* beta_ptr); }; -void set_bias(cu::CommandEncoder& encoder, const array& bias); - } // namespace mlx::core diff --git a/mlx/backend/cuda/gemms/cublas_gemm.cpp b/mlx/backend/cuda/gemms/cublas_gemm.cpp index af7b0b8fd3..05d1fecbc2 100644 --- a/mlx/backend/cuda/gemms/cublas_gemm.cpp +++ b/mlx/backend/cuda/gemms/cublas_gemm.cpp @@ -75,6 +75,7 @@ CublasGemm::CublasGemm( scale_type_, dtype_to_compute_type(dtype), dtype_to_cublas_type(dtype), + dtype_to_cublas_type(dtype), a_transposed, a_rows, a_cols, diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp index bd01c6b284..12862f0c39 100644 --- a/mlx/backend/cuda/matmul.cpp +++ b/mlx/backend/cuda/matmul.cpp @@ -98,7 +98,7 @@ void gemm_and_bias( throw std::runtime_error( "[gemm_and_bias] complex64 bias epilogue isn’t supported in cublasLtMatmul."); } - cublas_utils::set_bias(encoder, gemm.matmul_desc(), *bias); + gemm.set_bias(encoder, *bias); } gemm.run( encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides, alpha); diff --git a/mlx/backend/cuda/quantized/cublas_qqmm.cpp b/mlx/backend/cuda/quantized/cublas_qqmm.cpp index c5a59900da..a04cfefc14 100644 --- a/mlx/backend/cuda/quantized/cublas_qqmm.cpp +++ b/mlx/backend/cuda/quantized/cublas_qqmm.cpp @@ -65,6 +65,7 @@ CublasQQMM::CublasQQMM( int64_t b_batch_stride, std::string_view qmode) { cudaDataType_t scale_type = CUDA_R_32F; + cudaDataType_t output_type = CUDA_R_16BF; // always output in bf16 cublasComputeType_t gemm_compute_type = CUBLAS_COMPUTE_32F; // always for narrow precision cudaDataType_t data_type = qmode_to_cublas_dtype(qmode); @@ -76,6 +77,7 @@ CublasQQMM::CublasQQMM( scale_type, gemm_compute_type, data_type, + output_type, a_transposed, a_rows, a_cols, @@ -101,15 +103,6 @@ CublasQQMM::CublasQQMM( CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &b_scale_mode_, sizeof(b_scale_mode_))); - - // out_desc_ = create_matrix_layout( - // CUDA_R_16BF, // output in bf16 - // b_transposed ? b_rows : b_cols, - // a_transposed ? a_cols : a_rows, - // false, - // b_transposed ? b_rows : b_cols, - // batch_count, - // (a_transposed ? a_cols : a_rows) * (b_transposed ? 
b_rows : b_cols)); } CublasQQMM::CublasQQMM( diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index 9a2a7192f0..088fcbc7b2 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -60,7 +60,7 @@ array pad_and_repack_scales( // https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout // Note: cu::malloc_async already provides 256-byte alignment array scale_tiled( - cu::malloc_async(pad_outer * pad_inner, encoder.stream()), + cu::malloc_async(pad_outer * pad_inner, encoder), Shape{pad_outer, pad_inner}, scale.dtype()); repack_scales(scale, scale_tiled, encoder, s); @@ -149,13 +149,13 @@ void qqmm_impl( K, lda, b_transposed, - N, K, + N, ldb, - qmode, batch_shape.back(), a_batch_strides.back(), - b_batch_strides.back()); + b_batch_strides.back(), + qmode); qqmm.run( encoder, @@ -192,7 +192,7 @@ void DualQuantizedMatmul::eval_gpu( fill_gpu(zero, out, s); return; } - out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); + out.set_data(cu::malloc_async(out.nbytes(), encoder)); int M = a.shape(-2); int N = b.shape(-2); // b always transposed From df45b39861d821956b881e5dbe8e0fc733710ab5 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Thu, 20 Nov 2025 00:34:00 +0100 Subject: [PATCH 05/29] quantize activations on the fly --- mlx/backend/cuda/quantized/quantized.cpp | 10 ++++++++-- mlx/ops.cpp | 16 +++++++++------- mlx/ops.h | 5 ++--- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index 088fcbc7b2..e279b3d3c5 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -181,9 +181,15 @@ void DualQuantizedMatmul::eval_gpu( auto& encoder = cu::get_command_encoder(s); assert(inputs.size() == 4); - auto& a = inputs[0]; + auto& a_pre = inputs[0]; // activations are not quantized, only weights are auto& b = inputs[1]; - auto& scale_a_pre = inputs[2]; + + auto a_q = quantize(a_pre, group_size_, bits_, mode_, s); + encoder.add_temporary(a_q[0]); + encoder.add_temporary(a_q[1]); + + auto& a = a_q[0]; + auto& scale_a_pre = a_q[1]; auto& scale_b_pre = inputs[3]; // Return 0s if either input is empty. 
if (a.size() == 0 || b.size() == 0) { diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 5b245c5a24..9fad3bf214 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -4222,8 +4222,7 @@ array quantized_matmul( array qqmm( array x, - array w, - array scales_x, + array w_q, array scales_w, bool transpose /* = true */, std::optional group_size_ /* = std::nullopt */, @@ -4248,18 +4247,21 @@ array qqmm( msg << "[qqmm] Affine quantization is not supported for qqmm."; throw std::invalid_argument(msg.str()); } + auto quantized_x = quantize(x, group_size_, bits, mode, s); + auto x_q = quantized_x[0]; + auto scales_x = quantized_x[1]; + encoder.add_temporary(x_q); + encoder.add_temporary(scales_x); auto [group_size, bits] = quantization_params_from_mode(qmode, group_size_, bits_); // Check and extract the quantized matrix shape against x - auto [x_dims, w_dims] = extract_qqmm_dims( - "qqmm", x, w, scales_x, scales_w, transpose, group_size, bits); + "qqmm", x_q, w_q, scales_x, scales_w, transpose, group_size, bits); auto [x_inner_dims, x_outer_dims] = x_dims; auto [w_inner_dims, w_outer_dims] = w_dims; - std::vector inputs = {x, w, scales_x, scales_w}; - - if (x.ndim() > 2 && w.ndim() > 2) { + std::vector inputs = {x_q, w_q, scales_x, scales_w}; + if (x_q.ndim() > 2 && w_q.ndim() > 2) { inputs = broadcast_arrays(inputs, {-2, -1}, s); } diff --git a/mlx/ops.h b/mlx/ops.h index 6ca5dd43b6..d5af7e7e18 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1404,9 +1404,8 @@ std::vector quantize( StreamOrDevice s = {}); array qqmm( - const array x, - const array w, - const array x_scales, + const array x, // input activations + const array w_q, // quantized weights const array w_scales, bool transpose = true, std::optional group_size = std::nullopt, From 7a012a7682ef7b883c6b8e75a8f50be69873eaec Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Thu, 20 Nov 2025 01:08:55 +0100 Subject: [PATCH 06/29] quantize in eval --- mlx/backend/cuda/quantized/quantized.cpp | 9 ++++- mlx/ops.cpp | 48 ++++++++++-------------- 2 files changed, 26 insertions(+), 31 deletions(-) diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index e279b3d3c5..a7958fbfa7 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -180,11 +180,16 @@ void DualQuantizedMatmul::eval_gpu( auto& s = stream(); auto& encoder = cu::get_command_encoder(s); - assert(inputs.size() == 4); + assert(inputs.size() == 3); auto& a_pre = inputs[0]; // activations are not quantized, only weights are auto& b = inputs[1]; - auto a_q = quantize(a_pre, group_size_, bits_, mode_, s); + auto a_q = fp_quantize( + a_pre, + group_size_, + bits_, + mode_, + s); // here i assume that ist is only for nvfp4/mxfp8 encoder.add_temporary(a_q[0]); encoder.add_temporary(a_q[1]); diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 9fad3bf214..a2f9e1002e 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -154,27 +154,21 @@ std::pair extract_quantized_matmul_dims( return {w_inner_dims, w_outer_dims}; } -std::pair, std::pair> extract_qqmm_dims( +std::pair extract_qqmm_dims( std::string_view tag, const array& x, const array& w, - const array& scales_x, const array& scales_w, bool transpose, int group_size, int bits) { - // Validate x and scales_x - validate_quantized_input( - tag, x, scales_x, "x matrix", "scales_x", group_size, bits); - // Validate w and scales_w validate_quantized_input( tag, w, scales_w, "weight matrix", "scales_w", group_size, bits); // For narrow precision types (mxfp4, nvfp4) the only supported 
layout is TN // A is MxK, B is NxK (transposed) - int x_inner_dims = x.shape(-1); // K // (32 / bits) - int x_outer_dims = x.shape(-2); // M + int x_inner_dims = x.shape(-1) / (32 / bits); // K // Calculate the expanded w's dimensions int w_inner_dims = (transpose) ? w.shape(-1) : w.shape(-2); @@ -182,15 +176,19 @@ std::pair, std::pair> extract_qqmm_dims( if (w_inner_dims != x_inner_dims) { std::ostringstream msg; - msg << "[" << tag << "] Last dimension of first quantized input with " - << "shape (..., " << x_inner_dims << ") does not match " - << "the quantized matrix (" << w_inner_dims << ", " << w_outer_dims - << ") computed with transpose=" << std::boolalpha << transpose; + msg << "[" << tag << "] Inner dimension of second input with " + << "shape (" << w_inner_dims << ", " << w_outer_dims << ")" + << " computed with transpose=" << std::boolalpha << transpose + << " does not match the packed inner dimension of the first" + << "input (...," << x_inner_dims << ") computed with bits=" << bits + << " and transpose=" << std::boolalpha << transpose; throw std::invalid_argument(msg.str()); } - return {{x_inner_dims, x_outer_dims}, {w_inner_dims, w_outer_dims}}; + return { + w_inner_dims, w_outer_dims + } } } // namespace @@ -4231,9 +4229,8 @@ array qqmm( StreamOrDevice s /* = {} */) { // currently only simetric quantization is supported for qqmm auto qmode = string_to_quantization_mode(mode, "qqmm"); - // For narrow precision MMAs on B200 only TN layout is supported: - // https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html - // TODO: handle it better + // here we need to check that inputs and otputs will be quantized in the same + // way... if ((qmode == QuantizationMode::Nvfp4 || qmode == QuantizationMode::Mxfp4) && !transpose) { std::ostringstream msg; @@ -4247,21 +4244,14 @@ array qqmm( msg << "[qqmm] Affine quantization is not supported for qqmm."; throw std::invalid_argument(msg.str()); } - auto quantized_x = quantize(x, group_size_, bits, mode, s); - auto x_q = quantized_x[0]; - auto scales_x = quantized_x[1]; - encoder.add_temporary(x_q); - encoder.add_temporary(scales_x); auto [group_size, bits] = quantization_params_from_mode(qmode, group_size_, bits_); - // Check and extract the quantized matrix shape against x - auto [x_dims, w_dims] = extract_qqmm_dims( - "qqmm", x_q, w_q, scales_x, scales_w, transpose, group_size, bits); - auto [x_inner_dims, x_outer_dims] = x_dims; - auto [w_inner_dims, w_outer_dims] = w_dims; + // + auto [w_inner_dims, w_outer_dims] = + extract_qqmm_dims("qqmm", x, w_q, scales_w, transpose, group_size, bits); - std::vector inputs = {x_q, w_q, scales_x, scales_w}; - if (x_q.ndim() > 2 && w_q.ndim() > 2) { + std::vector inputs = {x, w_q, scales_w}; + if (x.ndim() > 2 && w_q.ndim() > 2) { inputs = broadcast_arrays(inputs, {-2, -1}, s); } @@ -5969,4 +5959,4 @@ array contiguous( {a}); } -} // namespace mlx::core +} // namespace mlx::core \ No newline at end of file From e34157213159bd611daa0d4234eb5d9558228d28 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Thu, 20 Nov 2025 01:22:50 +0100 Subject: [PATCH 07/29] Revert "quantize in eval" This reverts commit 7a012a7682ef7b883c6b8e75a8f50be69873eaec. 
--- mlx/backend/cuda/quantized/quantized.cpp | 9 +---- mlx/ops.cpp | 48 ++++++++++++++---------- 2 files changed, 31 insertions(+), 26 deletions(-) diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index a7958fbfa7..e279b3d3c5 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -180,16 +180,11 @@ void DualQuantizedMatmul::eval_gpu( auto& s = stream(); auto& encoder = cu::get_command_encoder(s); - assert(inputs.size() == 3); + assert(inputs.size() == 4); auto& a_pre = inputs[0]; // activations are not quantized, only weights are auto& b = inputs[1]; - auto a_q = fp_quantize( - a_pre, - group_size_, - bits_, - mode_, - s); // here i assume that ist is only for nvfp4/mxfp8 + auto a_q = quantize(a_pre, group_size_, bits_, mode_, s); encoder.add_temporary(a_q[0]); encoder.add_temporary(a_q[1]); diff --git a/mlx/ops.cpp b/mlx/ops.cpp index a2f9e1002e..9fad3bf214 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -154,21 +154,27 @@ std::pair extract_quantized_matmul_dims( return {w_inner_dims, w_outer_dims}; } -std::pair extract_qqmm_dims( +std::pair, std::pair> extract_qqmm_dims( std::string_view tag, const array& x, const array& w, + const array& scales_x, const array& scales_w, bool transpose, int group_size, int bits) { + // Validate x and scales_x + validate_quantized_input( + tag, x, scales_x, "x matrix", "scales_x", group_size, bits); + // Validate w and scales_w validate_quantized_input( tag, w, scales_w, "weight matrix", "scales_w", group_size, bits); // For narrow precision types (mxfp4, nvfp4) the only supported layout is TN // A is MxK, B is NxK (transposed) - int x_inner_dims = x.shape(-1) / (32 / bits); // K + int x_inner_dims = x.shape(-1); // K // (32 / bits) + int x_outer_dims = x.shape(-2); // M // Calculate the expanded w's dimensions int w_inner_dims = (transpose) ? w.shape(-1) : w.shape(-2); @@ -176,19 +182,15 @@ std::pair extract_qqmm_dims( if (w_inner_dims != x_inner_dims) { std::ostringstream msg; - msg << "[" << tag << "] Inner dimension of second input with " - << "shape (" << w_inner_dims << ", " << w_outer_dims << ")" - << " computed with transpose=" << std::boolalpha << transpose - << " does not match the packed inner dimension of the first" - << "input (...," << x_inner_dims << ") computed with bits=" << bits - << " and transpose=" << std::boolalpha << transpose; + msg << "[" << tag << "] Last dimension of first quantized input with " + << "shape (..., " << x_inner_dims << ") does not match " + << "the quantized matrix (" << w_inner_dims << ", " << w_outer_dims + << ") computed with transpose=" << std::boolalpha << transpose; throw std::invalid_argument(msg.str()); } - return { - w_inner_dims, w_outer_dims - } + return {{x_inner_dims, x_outer_dims}, {w_inner_dims, w_outer_dims}}; } } // namespace @@ -4229,8 +4231,9 @@ array qqmm( StreamOrDevice s /* = {} */) { // currently only simetric quantization is supported for qqmm auto qmode = string_to_quantization_mode(mode, "qqmm"); - // here we need to check that inputs and otputs will be quantized in the same - // way... 
+ // For narrow precision MMAs on B200 only TN layout is supported: + // https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html + // TODO: handle it better if ((qmode == QuantizationMode::Nvfp4 || qmode == QuantizationMode::Mxfp4) && !transpose) { std::ostringstream msg; @@ -4244,14 +4247,21 @@ array qqmm( msg << "[qqmm] Affine quantization is not supported for qqmm."; throw std::invalid_argument(msg.str()); } + auto quantized_x = quantize(x, group_size_, bits, mode, s); + auto x_q = quantized_x[0]; + auto scales_x = quantized_x[1]; + encoder.add_temporary(x_q); + encoder.add_temporary(scales_x); auto [group_size, bits] = quantization_params_from_mode(qmode, group_size_, bits_); - // - auto [w_inner_dims, w_outer_dims] = - extract_qqmm_dims("qqmm", x, w_q, scales_w, transpose, group_size, bits); + // Check and extract the quantized matrix shape against x + auto [x_dims, w_dims] = extract_qqmm_dims( + "qqmm", x_q, w_q, scales_x, scales_w, transpose, group_size, bits); + auto [x_inner_dims, x_outer_dims] = x_dims; + auto [w_inner_dims, w_outer_dims] = w_dims; - std::vector inputs = {x, w_q, scales_w}; - if (x.ndim() > 2 && w_q.ndim() > 2) { + std::vector inputs = {x_q, w_q, scales_x, scales_w}; + if (x_q.ndim() > 2 && w_q.ndim() > 2) { inputs = broadcast_arrays(inputs, {-2, -1}, s); } @@ -5959,4 +5969,4 @@ array contiguous( {a}); } -} // namespace mlx::core \ No newline at end of file +} // namespace mlx::core From de49f80e81ac2e900e38a13f6c1430662b4f98a1 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Thu, 20 Nov 2025 01:24:42 +0100 Subject: [PATCH 08/29] Revert "Revert "quantize in eval"" This reverts commit e34157213159bd611daa0d4234eb5d9558228d28. --- mlx/backend/cuda/quantized/quantized.cpp | 9 ++++- mlx/ops.cpp | 48 ++++++++++-------------- 2 files changed, 26 insertions(+), 31 deletions(-) diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index e279b3d3c5..a7958fbfa7 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -180,11 +180,16 @@ void DualQuantizedMatmul::eval_gpu( auto& s = stream(); auto& encoder = cu::get_command_encoder(s); - assert(inputs.size() == 4); + assert(inputs.size() == 3); auto& a_pre = inputs[0]; // activations are not quantized, only weights are auto& b = inputs[1]; - auto a_q = quantize(a_pre, group_size_, bits_, mode_, s); + auto a_q = fp_quantize( + a_pre, + group_size_, + bits_, + mode_, + s); // here i assume that ist is only for nvfp4/mxfp8 encoder.add_temporary(a_q[0]); encoder.add_temporary(a_q[1]); diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 9fad3bf214..a2f9e1002e 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -154,27 +154,21 @@ std::pair extract_quantized_matmul_dims( return {w_inner_dims, w_outer_dims}; } -std::pair, std::pair> extract_qqmm_dims( +std::pair extract_qqmm_dims( std::string_view tag, const array& x, const array& w, - const array& scales_x, const array& scales_w, bool transpose, int group_size, int bits) { - // Validate x and scales_x - validate_quantized_input( - tag, x, scales_x, "x matrix", "scales_x", group_size, bits); - // Validate w and scales_w validate_quantized_input( tag, w, scales_w, "weight matrix", "scales_w", group_size, bits); // For narrow precision types (mxfp4, nvfp4) the only supported layout is TN // A is MxK, B is NxK (transposed) - int x_inner_dims = x.shape(-1); // K // (32 / bits) - int x_outer_dims = x.shape(-2); // M + int x_inner_dims = x.shape(-1) / (32 / bits); // K // Calculate 
the expanded w's dimensions int w_inner_dims = (transpose) ? w.shape(-1) : w.shape(-2); @@ -182,15 +176,19 @@ std::pair, std::pair> extract_qqmm_dims( if (w_inner_dims != x_inner_dims) { std::ostringstream msg; - msg << "[" << tag << "] Last dimension of first quantized input with " - << "shape (..., " << x_inner_dims << ") does not match " - << "the quantized matrix (" << w_inner_dims << ", " << w_outer_dims - << ") computed with transpose=" << std::boolalpha << transpose; + msg << "[" << tag << "] Inner dimension of second input with " + << "shape (" << w_inner_dims << ", " << w_outer_dims << ")" + << " computed with transpose=" << std::boolalpha << transpose + << " does not match the packed inner dimension of the first" + << "input (...," << x_inner_dims << ") computed with bits=" << bits + << " and transpose=" << std::boolalpha << transpose; throw std::invalid_argument(msg.str()); } - return {{x_inner_dims, x_outer_dims}, {w_inner_dims, w_outer_dims}}; + return { + w_inner_dims, w_outer_dims + } } } // namespace @@ -4231,9 +4229,8 @@ array qqmm( StreamOrDevice s /* = {} */) { // currently only simetric quantization is supported for qqmm auto qmode = string_to_quantization_mode(mode, "qqmm"); - // For narrow precision MMAs on B200 only TN layout is supported: - // https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html - // TODO: handle it better + // here we need to check that inputs and otputs will be quantized in the same + // way... if ((qmode == QuantizationMode::Nvfp4 || qmode == QuantizationMode::Mxfp4) && !transpose) { std::ostringstream msg; @@ -4247,21 +4244,14 @@ array qqmm( msg << "[qqmm] Affine quantization is not supported for qqmm."; throw std::invalid_argument(msg.str()); } - auto quantized_x = quantize(x, group_size_, bits, mode, s); - auto x_q = quantized_x[0]; - auto scales_x = quantized_x[1]; - encoder.add_temporary(x_q); - encoder.add_temporary(scales_x); auto [group_size, bits] = quantization_params_from_mode(qmode, group_size_, bits_); - // Check and extract the quantized matrix shape against x - auto [x_dims, w_dims] = extract_qqmm_dims( - "qqmm", x_q, w_q, scales_x, scales_w, transpose, group_size, bits); - auto [x_inner_dims, x_outer_dims] = x_dims; - auto [w_inner_dims, w_outer_dims] = w_dims; + // + auto [w_inner_dims, w_outer_dims] = + extract_qqmm_dims("qqmm", x, w_q, scales_w, transpose, group_size, bits); - std::vector inputs = {x_q, w_q, scales_x, scales_w}; - if (x_q.ndim() > 2 && w_q.ndim() > 2) { + std::vector inputs = {x, w_q, scales_w}; + if (x.ndim() > 2 && w_q.ndim() > 2) { inputs = broadcast_arrays(inputs, {-2, -1}, s); } @@ -5969,4 +5959,4 @@ array contiguous( {a}); } -} // namespace mlx::core +} // namespace mlx::core \ No newline at end of file From 545dded6d7055a323ebc1f97fc821f8b778dc52d Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Thu, 20 Nov 2025 02:09:20 +0100 Subject: [PATCH 09/29] on the fly activation quantization --- examples/cpp/qqmm.cu | 12 ++-- mlx/backend/cuda/quantized/quantized.cpp | 70 ++++++++++++------------ mlx/ops.cpp | 3 +- 3 files changed, 39 insertions(+), 46 deletions(-) diff --git a/examples/cpp/qqmm.cu b/examples/cpp/qqmm.cu index eec515dd48..7598b9eb13 100644 --- a/examples/cpp/qqmm.cu +++ b/examples/cpp/qqmm.cu @@ -21,28 +21,24 @@ int main() { mx::array a = mx::random::uniform({M, K}, mx::bfloat16); // (M, K) mx::array b = mx::random::uniform({N, K}, mx::bfloat16); // (N, K) - auto scaled_a = mx::quantize(a, group_size, bits, quantization_mode); auto scaled_b = mx::quantize(b, 
group_size, bits, quantization_mode); - - mx::array a_quantized = scaled_a[0]; - mx::array a_scale = scaled_a[1]; mx::array b_quantized = scaled_b[0]; mx::array b_scale = scaled_b[1]; mx::array out = mx::qqmm( - a_quantized, + a, b_quantized, - a_scale, b_scale, true, group_size, bits, quantization_mode); + auto aq = mx::quantize(a, group_size, bits, quantization_mode); mx::array a_dequantized = - mx::dequantize(a_quantized, a_scale, {}, 16, 4, "nvfp4"); + mx::dequantize(aq[0], aq[1], {}, group_size, bits, quantization_mode); mx::array b_dequantized = - mx::dequantize(b_quantized, b_scale, {}, 16, 4, "nvfp4"); + mx::dequantize(b_quantized, b_scale, {}, group_size, bits, quantization_mode); mx::array reference_deq = mx::matmul(a_dequantized, mx::transpose(b_dequantized)); diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index a7958fbfa7..f89ddbbdb4 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -179,43 +179,41 @@ void DualQuantizedMatmul::eval_gpu( // TODO: for now minimalistic implementation without batching support auto& s = stream(); auto& encoder = cu::get_command_encoder(s); - + assert(inputs.size() == 3); - auto& a_pre = inputs[0]; // activations are not quantized, only weights are - auto& b = inputs[1]; - - auto a_q = fp_quantize( - a_pre, - group_size_, - bits_, - mode_, - s); // here i assume that ist is only for nvfp4/mxfp8 - encoder.add_temporary(a_q[0]); - encoder.add_temporary(a_q[1]); - - auto& a = a_q[0]; - auto& scale_a_pre = a_q[1]; - auto& scale_b_pre = inputs[3]; - // Return 0s if either input is empty. - if (a.size() == 0 || b.size() == 0) { - array zero(0, a.dtype()); - encoder.add_temporary(zero); - fill_gpu(zero, out, s); - return; - } + auto& x = inputs[0]; // activations should be quantized on the fly + auto& w_q = inputs[1]; // quantized weights + + auto quantize_activation = [&](const array& input, cu::CommandEncoder& encoder, const Stream& s) { + auto x = ensure_row_contiguous(input, encoder, s); + auto xq_shape = x.shape(); + xq_shape.back() = x.shape(-1) * bits_ / 32; + auto sshape = x.shape(); + sshape.back() = x.shape(-1) / group_size_; + array x_q(cu::malloc_async(x.size() * bits_ / 8, encoder), xq_shape, uint32); + array scales_x(cu::malloc_async(x.size() / group_size_ * sizeof(uint8), encoder), sshape, uint8); + fp_quantize(x, x_q, scales_x, group_size_, bits_, encoder, s); + encoder.add_temporary(scales_x); + encoder.add_temporary(x_q); + return std::make_pair(x_q, scales_x); + }; + + auto [x_q, scale_x_pre] = quantize_activation(inputs[0], encoder, s); + auto& scale_w_pre = inputs[2]; + out.set_data(cu::malloc_async(out.nbytes(), encoder)); - int M = a.shape(-2); - int N = b.shape(-2); // b always transposed - int K_packed = a.shape(-1); + int M = x_q.shape(-2); + int N = w_q.shape(-2); // b always transposed + int K_packed = x_q.shape(-1); int K = K_packed * (32 / bits_); // Repack scales from linear to tiled layout for tensor cores - array scale_a_tiled = pad_and_repack_scales(scale_a_pre, encoder, s); - array scale_b_tiled = pad_and_repack_scales(scale_b_pre, encoder, s); + array scale_x = pad_and_repack_scales(scale_x_pre, encoder, s); + array scale_w = pad_and_repack_scales(scale_w_pre, encoder, s); - bool a_transposed = false; // a is normal (M x K) - bool b_transposed = true; // b is transposed (N x K -> K x N) + bool x_transposed = false; // a is normal (M x K) + bool w_transposed = true; // b is transposed (N x K -> K x N) int64_t lda = K; // 
Leading dimension of a (packed) int64_t ldb = K; // Leading dimension of b (packed) @@ -224,15 +222,15 @@ void DualQuantizedMatmul::eval_gpu( M, N, K, - a_transposed, + x_transposed, lda, - b_transposed, + w_transposed, ldb, out, - a, - b, - scale_a_tiled, - scale_b_tiled, + x_q, + w_q, + scale_x, + scale_w, mode_); } diff --git a/mlx/ops.cpp b/mlx/ops.cpp index a2f9e1002e..fbc09c4fdd 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -188,7 +188,7 @@ std::pair extract_qqmm_dims( return { w_inner_dims, w_outer_dims - } + }; } } // namespace @@ -4246,7 +4246,6 @@ array qqmm( } auto [group_size, bits] = quantization_params_from_mode(qmode, group_size_, bits_); - // auto [w_inner_dims, w_outer_dims] = extract_qqmm_dims("qqmm", x, w_q, scales_w, transpose, group_size, bits); From 5318b38ccc5bcac1ae603f435d82de47fdf90e95 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Thu, 20 Nov 2025 02:10:11 +0100 Subject: [PATCH 10/29] pre-commit --- mlx/backend/cuda/quantized/quantized.cpp | 35 ++++++++++++++---------- mlx/ops.cpp | 4 +-- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index f89ddbbdb4..7024825f5e 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -179,28 +179,33 @@ void DualQuantizedMatmul::eval_gpu( // TODO: for now minimalistic implementation without batching support auto& s = stream(); auto& encoder = cu::get_command_encoder(s); - + assert(inputs.size() == 3); auto& x = inputs[0]; // activations should be quantized on the fly auto& w_q = inputs[1]; // quantized weights - auto quantize_activation = [&](const array& input, cu::CommandEncoder& encoder, const Stream& s) { - auto x = ensure_row_contiguous(input, encoder, s); - auto xq_shape = x.shape(); - xq_shape.back() = x.shape(-1) * bits_ / 32; - auto sshape = x.shape(); - sshape.back() = x.shape(-1) / group_size_; - array x_q(cu::malloc_async(x.size() * bits_ / 8, encoder), xq_shape, uint32); - array scales_x(cu::malloc_async(x.size() / group_size_ * sizeof(uint8), encoder), sshape, uint8); - fp_quantize(x, x_q, scales_x, group_size_, bits_, encoder, s); - encoder.add_temporary(scales_x); - encoder.add_temporary(x_q); - return std::make_pair(x_q, scales_x); - }; + auto quantize_activation = + [&](const array& input, cu::CommandEncoder& encoder, const Stream& s) { + auto x = ensure_row_contiguous(input, encoder, s); + auto xq_shape = x.shape(); + xq_shape.back() = x.shape(-1) * bits_ / 32; + auto sshape = x.shape(); + sshape.back() = x.shape(-1) / group_size_; + array x_q( + cu::malloc_async(x.size() * bits_ / 8, encoder), xq_shape, uint32); + array scales_x( + cu::malloc_async(x.size() / group_size_ * sizeof(uint8), encoder), + sshape, + uint8); + fp_quantize(x, x_q, scales_x, group_size_, bits_, encoder, s); + encoder.add_temporary(scales_x); + encoder.add_temporary(x_q); + return std::make_pair(x_q, scales_x); + }; auto [x_q, scale_x_pre] = quantize_activation(inputs[0], encoder, s); auto& scale_w_pre = inputs[2]; - + out.set_data(cu::malloc_async(out.nbytes(), encoder)); int M = x_q.shape(-2); diff --git a/mlx/ops.cpp b/mlx/ops.cpp index fbc09c4fdd..c4ee472e7e 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -186,9 +186,7 @@ std::pair extract_qqmm_dims( throw std::invalid_argument(msg.str()); } - return { - w_inner_dims, w_outer_dims - }; + return {w_inner_dims, w_outer_dims}; } } // namespace From 2184744e54e04f509e765370ef38a43153810d85 Mon Sep 17 00:00:00 2001 From: Anastasiia 
Filippova Date: Mon, 24 Nov 2025 17:07:11 +0100 Subject: [PATCH 11/29] qqmm inputs bf16 second arg --- mlx/backend/cuda/quantized/quantized.cpp | 12 +++-- mlx/ops.cpp | 30 +++++++---- mlx/primitives.cpp | 64 +++++++++++++++++++++++- 3 files changed, 89 insertions(+), 17 deletions(-) diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index 7024825f5e..f79b146971 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -175,14 +175,16 @@ void DualQuantizedMatmul::eval_gpu( const std::vector& inputs, array& out) { nvtx3::scoped_range r("DualQuantizedMatmul::eval_gpu"); - // WIP need to add primitive - // TODO: for now minimalistic implementation without batching support + // for now it is size of 4: bf16 x, bf16 w, w_q, scale_w, <- this is like this + // for the inference & vjp auto& s = stream(); auto& encoder = cu::get_command_encoder(s); - assert(inputs.size() == 3); - auto& x = inputs[0]; // activations should be quantized on the fly - auto& w_q = inputs[1]; // quantized weights + assert(inputs.size() == 4); + auto& x = inputs[0]; // activations bf16 + auto& w = inputs[1]; // weights bf16 + auto& w_q = inputs[2]; // quantized weights + auto& scale_w_pre = inputs[3]; auto quantize_activation = [&](const array& input, cu::CommandEncoder& encoder, const Stream& s) { diff --git a/mlx/ops.cpp b/mlx/ops.cpp index c4ee472e7e..f48ac5b51e 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -158,21 +158,31 @@ std::pair extract_qqmm_dims( std::string_view tag, const array& x, const array& w, + const array& w_q, const array& scales_w, bool transpose, int group_size, int bits) { - // Validate w and scales_w + // Validate w_q and scales_w validate_quantized_input( - tag, w, scales_w, "weight matrix", "scales_w", group_size, bits); - + tag, w_q, scales_w, "weight matrix", "scales_w", group_size, bits); + // Calculate the expanded w's dimensions + if (w.shape(-1) != w_q.shape(-1) * 32 / bits || + w.shape(-2) != w_q.shape(-2)) { + std::ostringstream msg; + msg << "[" << tag << "] The shape of the weight matrix and its " + << "quantized version are incompatible. Received weight matrix " + << "with shape " << w.shape() << " and quantized weight matrix " + << "with shape " << w_q.shape() << " with bits=" << bits; + throw std::invalid_argument(msg.str()); + } // For narrow precision types (mxfp4, nvfp4) the only supported layout is TN // A is MxK, B is NxK (transposed) int x_inner_dims = x.shape(-1) / (32 / bits); // K // Calculate the expanded w's dimensions - int w_inner_dims = (transpose) ? w.shape(-1) : w.shape(-2); - int w_outer_dims = (transpose) ? w.shape(-2) : w.shape(-1); + int w_inner_dims = (transpose) ? w_q.shape(-1) : w_q.shape(-2); + int w_outer_dims = (transpose) ? w_q.shape(-2) : w_q.shape(-1); if (w_inner_dims != x_inner_dims) { std::ostringstream msg; @@ -4218,6 +4228,7 @@ array quantized_matmul( array qqmm( array x, + array w, array w_q, array scales_w, bool transpose /* = true */, @@ -4227,8 +4238,7 @@ array qqmm( StreamOrDevice s /* = {} */) { // currently only simetric quantization is supported for qqmm auto qmode = string_to_quantization_mode(mode, "qqmm"); - // here we need to check that inputs and otputs will be quantized in the same - // way... 
+ // here we need to check that w_q, w and scales_w are compatible with if ((qmode == QuantizationMode::Nvfp4 || qmode == QuantizationMode::Mxfp4) && !transpose) { std::ostringstream msg; @@ -4244,10 +4254,10 @@ array qqmm( } auto [group_size, bits] = quantization_params_from_mode(qmode, group_size_, bits_); - auto [w_inner_dims, w_outer_dims] = - extract_qqmm_dims("qqmm", x, w_q, scales_w, transpose, group_size, bits); + auto [w_inner_dims, w_outer_dims] = extract_qqmm_dims( + "qqmm", x, w, w_q, scales_w, transpose, group_size, bits); - std::vector inputs = {x, w_q, scales_w}; + std::vector inputs = {x, w, w_q, scales_w}; if (x.ndim() > 2 && w_q.ndim() > 2) { inputs = broadcast_arrays(inputs, {-2, -1}, s); } diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index bca5935f36..b5635c9242 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3479,13 +3479,73 @@ bool DualQuantizedMatmul::is_equivalent(const Primitive& other) const { std::vector DualQuantizedMatmul::output_shapes( const std::vector& inputs) { auto out_shape = inputs[0].shape(); - auto& w = inputs[1]; - int w_outer_dims = (transpose_) ? w.shape(-2) : w.shape(-1); + auto& w_q = inputs[2]; + int w_outer_dims = (transpose_) ? w_q.shape(-2) : w_q.shape(-1); out_shape.back() = w_outer_dims; std::cout << "DualQuantizedMatmul output shape: " << out_shape << std::endl; return {std::move(out_shape)}; } +std::vector DualQuantizedMatmul::vjp( + const std::vector& primals, // bf16 x, quantized weights + const std::vector& cotangents, // bf16 grads + const std::vector& argnums, + const std::vector&) { + std::vector vjps; + auto& cotan = cotangents[0]; + std::vector reorder(cotan.ndim()); + std::iota(reorder.begin(), reorder.end(), 0); + std::iter_swap(reorder.end() - 1, reorder.end() - 2); + auto& s = stream(); + // primal[1] -- bf16 weights + // primal[2] -- quantized weights (row wise) + // primal[3] -- scales_w + // primal[0] -- bf16 activations (M, K) + // cotan -- bf16 activation grads (M, N) + for (auto arg : argnums) { + if (arg == 0) { // gradient wrt to x + // We transpose weights -> quantize along N -> qqmm (cotan quantized in + // eval_gpu) + auto wtq = quantize( + transpose(primals[1], reorder, s), + group_size_, + bits_, + mode_, + s); // (K, N_packed), scales + vjps.push_back(qqmm( + cotan, // M X N + primals[1], // bf16 weights (for compatability) + wtq[0], // K X N_packed + wtq[1], // scales + true, + group_size_, + bits_, + mode_, + s)); + } else if (arg == 1) { // gradient wrt to weights + // it is a bit complicated -- we need to quantize along M but cotan is + // (M,N) so we transpose cotan -> quantize along M -> qqmm + auto ctq = quantize( + transpose(cotan, reorder, s), + group_size_, + bits_, + mode_, + s); // (N, M_packed) + vjps.push_back(qqmm( + transpose(primals[0], reorder, s), // + cotan, // (M, N) + ctq[0], // (N, M_packed) + ctq[1], // scales + true, + group_size_, + bits_, + mode_, + s)); // (K, M), (N, M_packed) + } + } + return vjps; +} + std::pair, std::vector> GatherQMM::vmap( const std::vector& inputs, const std::vector& axes) { From 61d09eaa86798e6f829f99a134425a0a90b1c681 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Mon, 24 Nov 2025 20:33:55 +0100 Subject: [PATCH 12/29] fix --- mlx/backend/cuda/quantized/quantized.cpp | 8 +++--- mlx/ops.cpp | 7 +++++- mlx/ops.h | 7 +++--- mlx/primitives.cpp | 31 ++++++++++++++---------- mlx/primitives.h | 2 +- 5 files changed, 32 insertions(+), 23 deletions(-) diff --git a/mlx/backend/cuda/quantized/quantized.cpp 
b/mlx/backend/cuda/quantized/quantized.cpp index f79b146971..e0fb147510 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -175,7 +175,7 @@ void DualQuantizedMatmul::eval_gpu( const std::vector& inputs, array& out) { nvtx3::scoped_range r("DualQuantizedMatmul::eval_gpu"); - // for now it is size of 4: bf16 x, bf16 w, w_q, scale_w, <- this is like this + // for now it is size of 4: bf16 x, bf16 w, w_q, scale_w // for the inference & vjp auto& s = stream(); auto& encoder = cu::get_command_encoder(s); @@ -206,8 +206,6 @@ void DualQuantizedMatmul::eval_gpu( }; auto [x_q, scale_x_pre] = quantize_activation(inputs[0], encoder, s); - auto& scale_w_pre = inputs[2]; - out.set_data(cu::malloc_async(out.nbytes(), encoder)); int M = x_q.shape(-2); @@ -221,8 +219,8 @@ void DualQuantizedMatmul::eval_gpu( bool x_transposed = false; // a is normal (M x K) bool w_transposed = true; // b is transposed (N x K -> K x N) - int64_t lda = K; // Leading dimension of a (packed) - int64_t ldb = K; // Leading dimension of b (packed) + int64_t lda = K; // Leading dimension of a + int64_t ldb = K; // Leading dimension of b qqmm_impl( encoder, diff --git a/mlx/ops.cpp b/mlx/ops.cpp index f48ac5b51e..9f83a2bfb3 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -4257,7 +4257,12 @@ array qqmm( auto [w_inner_dims, w_outer_dims] = extract_qqmm_dims( "qqmm", x, w, w_q, scales_w, transpose, group_size, bits); - std::vector inputs = {x, w, w_q, scales_w}; + std::vector inputs = { + x, + w, + stop_gradient(w_q), + stop_gradient( + scales_w)}; // we don't backprope through qunatized w and scales if (x.ndim() > 2 && w_q.ndim() > 2) { inputs = broadcast_arrays(inputs, {-2, -1}, s); } diff --git a/mlx/ops.h b/mlx/ops.h index d5af7e7e18..ef5907f9be 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1404,9 +1404,10 @@ std::vector quantize( StreamOrDevice s = {}); array qqmm( - const array x, // input activations - const array w_q, // quantized weights - const array w_scales, + array x, // input activations + array w, + array w_q, // quantized weights + array w_scales, bool transpose = true, std::optional group_size = std::nullopt, std::optional bits = std::nullopt, diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index b5635c9242..45cd4c19e6 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3502,6 +3502,7 @@ std::vector DualQuantizedMatmul::vjp( // primal[3] -- scales_w // primal[0] -- bf16 activations (M, K) // cotan -- bf16 activation grads (M, N) + auto qmode = quantization_mode_to_string(mode_); for (auto arg : argnums) { if (arg == 0) { // gradient wrt to x // We transpose weights -> quantize along N -> qqmm (cotan quantized in @@ -3510,7 +3511,7 @@ std::vector DualQuantizedMatmul::vjp( transpose(primals[1], reorder, s), group_size_, bits_, - mode_, + qmode, s); // (K, N_packed), scales vjps.push_back(qqmm( cotan, // M X N @@ -3520,32 +3521,36 @@ std::vector DualQuantizedMatmul::vjp( true, group_size_, bits_, - mode_, + qmode, s)); } else if (arg == 1) { // gradient wrt to weights // it is a bit complicated -- we need to quantize along M but cotan is // (M,N) so we transpose cotan -> quantize along M -> qqmm - auto ctq = quantize( - transpose(cotan, reorder, s), - group_size_, - bits_, - mode_, - s); // (N, M_packed) + auto xt = transpose(primals[0], reorder, s); // (K, M) + auto xtq = quantize(xt, group_size_, bits_, qmode, + s); // (N, M_packed) vjps.push_back(qqmm( - transpose(primals[0], reorder, s), // - cotan, // (M, N) - ctq[0], // (N, M_packed) - ctq[1], // scales + 
transpose(cotan, reorder, s), // (N, M) + xt, // (K, M) + xtq[0], // (N, M_packed) + xtq[1], // scales true, group_size_, bits_, - mode_, + qmode, s)); // (K, M), (N, M_packed) } } return vjps; } +std::vector DualQuantizedMatmul::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + throw std::runtime_error("DualQuantizedMatmul::jvp NYI"); +} + std::pair, std::vector> GatherQMM::vmap( const std::vector& inputs, const std::vector& axes) { diff --git a/mlx/primitives.h b/mlx/primitives.h index 6f2821794a..4b9654ee0c 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1660,7 +1660,7 @@ class DualQuantizedMatmul : public UnaryPrimitive { void eval_gpu(const std::vector& inputs, array& out) override; // DEFINE_VMAP() - // DEFINE_GRADS() + DEFINE_GRADS() DEFINE_NAME(DualQuantizedMatmul) bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override; From 6af6eb3692ad35742ced9d224e06268ba747b7f4 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Tue, 25 Nov 2025 00:00:47 +0100 Subject: [PATCH 13/29] bf16 weights are optional --- mlx/ops.cpp | 20 ++++++++++++-------- mlx/ops.h | 2 +- mlx/primitives.cpp | 14 +++++++------- python/src/ops.cpp | 1 + 4 files changed, 21 insertions(+), 16 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 9f83a2bfb3..7eb14d82dd 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -157,9 +157,9 @@ std::pair extract_quantized_matmul_dims( std::pair extract_qqmm_dims( std::string_view tag, const array& x, - const array& w, const array& w_q, const array& scales_w, + const std::optional& w, bool transpose, int group_size, int bits) { @@ -167,8 +167,10 @@ std::pair extract_qqmm_dims( validate_quantized_input( tag, w_q, scales_w, "weight matrix", "scales_w", group_size, bits); // Calculate the expanded w's dimensions - if (w.shape(-1) != w_q.shape(-1) * 32 / bits || - w.shape(-2) != w_q.shape(-2)) { + + if (w && + (w.shape(-1) != w_q.shape(-1) * 32 / bits || + w.shape(-2) != w_q.shape(-2))) { std::ostringstream msg; msg << "[" << tag << "] The shape of the weight matrix and its " << "quantized version are incompatible. 
Received weight matrix " @@ -4228,9 +4230,9 @@ array quantized_matmul( array qqmm( array x, - array w, array w_q, array scales_w, + std::optional w /* = std::nullopt */, bool transpose /* = true */, std::optional group_size_ /* = std::nullopt */, std::optional bits_ /* = std::nullopt */, @@ -4255,14 +4257,16 @@ array qqmm( auto [group_size, bits] = quantization_params_from_mode(qmode, group_size_, bits_); auto [w_inner_dims, w_outer_dims] = extract_qqmm_dims( - "qqmm", x, w, w_q, scales_w, transpose, group_size, bits); + "qqmm", x, w_q, scales_w, w, transpose, group_size, bits); std::vector inputs = { x, - w, stop_gradient(w_q), - stop_gradient( - scales_w)}; // we don't backprope through qunatized w and scales + stop_gradient(scales_w), + }; // we don't backprope through qunatized w and scales, + if (w) { + inputs.push_back(w); + } if (x.ndim() > 2 && w_q.ndim() > 2) { inputs = broadcast_arrays(inputs, {-2, -1}, s); } diff --git a/mlx/ops.h b/mlx/ops.h index ef5907f9be..002800a523 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1405,9 +1405,9 @@ std::vector quantize( array qqmm( array x, // input activations - array w, array w_q, // quantized weights array w_scales, + std::optional w = std::nullopt, // optional bf16 weights for vjp bool transpose = true, std::optional group_size = std::nullopt, std::optional bits = std::nullopt, diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 45cd4c19e6..e44c28aa1c 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3497,10 +3497,10 @@ std::vector DualQuantizedMatmul::vjp( std::iota(reorder.begin(), reorder.end(), 0); std::iter_swap(reorder.end() - 1, reorder.end() - 2); auto& s = stream(); - // primal[1] -- bf16 weights - // primal[2] -- quantized weights (row wise) - // primal[3] -- scales_w + // primal[1] -- quantized weights (row wise) + // primal[2] -- scales_w // primal[0] -- bf16 activations (M, K) + // primal[3] -- bf16 weights (N, K) // cotan -- bf16 activation grads (M, N) auto qmode = quantization_mode_to_string(mode_); for (auto arg : argnums) { @@ -3508,22 +3508,22 @@ std::vector DualQuantizedMatmul::vjp( // We transpose weights -> quantize along N -> qqmm (cotan quantized in // eval_gpu) auto wtq = quantize( - transpose(primals[1], reorder, s), + transpose(primals[3], reorder, s), group_size_, bits_, qmode, s); // (K, N_packed), scales vjps.push_back(qqmm( cotan, // M X N - primals[1], // bf16 weights (for compatability) wtq[0], // K X N_packed wtq[1], // scales + // primals[3], // bf16 weights (for compatability) true, group_size_, bits_, qmode, s)); - } else if (arg == 1) { // gradient wrt to weights + } else if (arg == 3) { // gradient wrt to weights // it is a bit complicated -- we need to quantize along M but cotan is // (M,N) so we transpose cotan -> quantize along M -> qqmm auto xt = transpose(primals[0], reorder, s); // (K, M) @@ -3531,7 +3531,7 @@ std::vector DualQuantizedMatmul::vjp( s); // (N, M_packed) vjps.push_back(qqmm( transpose(cotan, reorder, s), // (N, M) - xt, // (K, M) + // xt, // (K, M) xtq[0], // (N, M_packed) xtq[1], // scales true, diff --git a/python/src/ops.cpp b/python/src/ops.cpp index c28cf6a517..71eddd8a15 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -5432,4 +5432,5 @@ void init_ops(nb::module_& m) { Returns: array or Sequence[array]: The outputs which depend on dependencies. 
)pbdoc"); + m.def("qqmm", &mx::qqmm, nb::arg(), nb::arg(), nb::kw_only(), ) } From 1e37ef6f82cca73a8f6cca9e5dacf400590d5f08 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Tue, 25 Nov 2025 00:29:16 +0100 Subject: [PATCH 14/29] op in python, typos --- mlx/ops.cpp | 6 +++--- mlx/primitives.cpp | 4 ++-- python/src/ops.cpp | 41 ++++++++++++++++++++++++++++++++++++++++- 3 files changed, 45 insertions(+), 6 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 7eb14d82dd..0248f8e37a 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -169,12 +169,12 @@ std::pair extract_qqmm_dims( // Calculate the expanded w's dimensions if (w && - (w.shape(-1) != w_q.shape(-1) * 32 / bits || - w.shape(-2) != w_q.shape(-2))) { + (w->shape(-1) != w_q.shape(-1) * 32 / bits || + w->shape(-2) != w_q.shape(-2))) { std::ostringstream msg; msg << "[" << tag << "] The shape of the weight matrix and its " << "quantized version are incompatible. Received weight matrix " - << "with shape " << w.shape() << " and quantized weight matrix " + << "with shape " << w->shape() << " and quantized weight matrix " << "with shape " << w_q.shape() << " with bits=" << bits; throw std::invalid_argument(msg.str()); } diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index e44c28aa1c..0ad5ca3f32 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3517,7 +3517,7 @@ std::vector DualQuantizedMatmul::vjp( cotan, // M X N wtq[0], // K X N_packed wtq[1], // scales - // primals[3], // bf16 weights (for compatability) + std::nullopt, true, group_size_, bits_, @@ -3531,9 +3531,9 @@ std::vector DualQuantizedMatmul::vjp( s); // (N, M_packed) vjps.push_back(qqmm( transpose(cotan, reorder, s), // (N, M) - // xt, // (K, M) xtq[0], // (N, M_packed) xtq[1], // scales + std::nullopt, true, group_size_, bits_, diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 71eddd8a15..fb31fefb17 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -5432,5 +5432,44 @@ void init_ops(nb::module_& m) { Returns: array or Sequence[array]: The outputs which depend on dependencies. )pbdoc"); - m.def("qqmm", &mx::qqmm, nb::arg(), nb::arg(), nb::kw_only(), ) + m.def( + "qqmm", + &mx::qqmm, + nb::arg(), // x + nb::arg(), // w_q + "scales"_a, // scales w + "w"_a = nb::none(), // bf16 weights + "transpose"_a = true, + "group_size"_a = nb::none(), + "bits"_a = nb::none(), + "mode"_a = "affine", + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def qqmm(x: array, w_q: array, /, scales: array, w: Optional[array] = None, transpose: bool = True, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'nvfp4', *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Perform the matrix multiplication with the quantized matrix ``w_q`` and x that is + quantized on-the-fly using provided group size, bits and mode which must be the same + as used to quantize ``w_q``. ``w`` must be provided during training for correct + gradient computation, but optional for the inference. + + Args: + x (array): Input array + w (array): Quantized matrix packed in unsigned integers + scales (array): The scales to use per ``group_size`` elements of ``w_q`` + w (array, optional): bf16 or float32 weights used during training for + correct gradient computation. Default: ``None``. + transpose (bool, optional): Defines whether to multiply with the + transposed ``w_q`` or not, namely whether we are performing + ``x @ w_q.T`` or ``x @ w_q``. Default: ``True``. 
+ group_size (int, optional): The size of the group in ``w_q`` that shares a + scale and bias. See supported values and defaults in the + :ref:`table of quantization modes `. Default: ``None``. + bits (int, optional): The number of bits occupied by each element of + ``w_q`` in the quantized array. See supported values and defaults in the + :ref:`table of quantization modes `. Default: ``None``. + mode (str, optional): The quantization mode. Default: ``"affine"``. + Returns: + array: The result of the multiplication of quantized ``x`` with ``w_q``. + )pbdoc"); } From 9c584f81e6c87e5d7137c865872def2e9efa226a Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Tue, 25 Nov 2025 00:37:31 +0100 Subject: [PATCH 15/29] typo --- mlx/backend/cuda/quantized/quantized.cpp | 7 +++---- mlx/ops.cpp | 4 ++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index e0fb147510..b42b8a9504 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -180,11 +180,10 @@ void DualQuantizedMatmul::eval_gpu( auto& s = stream(); auto& encoder = cu::get_command_encoder(s); - assert(inputs.size() == 4); + assert(inputs.size() == 4 || inputs.size() == 3); auto& x = inputs[0]; // activations bf16 - auto& w = inputs[1]; // weights bf16 - auto& w_q = inputs[2]; // quantized weights - auto& scale_w_pre = inputs[3]; + auto& w_q = inputs[1]; // quantized weights + auto& scale_w_pre = inputs[2]; auto quantize_activation = [&](const array& input, cu::CommandEncoder& encoder, const Stream& s) { diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 0248f8e37a..d739cf3aa7 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -4264,8 +4264,8 @@ array qqmm( stop_gradient(w_q), stop_gradient(scales_w), }; // we don't backprope through qunatized w and scales, - if (w) { - inputs.push_back(w); + if (w.has_value()) { + inputs.push_back(*w); } if (x.ndim() > 2 && w_q.ndim() > 2) { inputs = broadcast_arrays(inputs, {-2, -1}, s); From 9a83d3ca6f71cafe284e23ae1110e7c8ff8b446f Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Tue, 25 Nov 2025 23:48:57 +0100 Subject: [PATCH 16/29] batched qqmm --- mlx/backend/cuda/CMakeLists.txt | 9 +++- mlx/backend/cuda/quantized/cublas_qqmm.cpp | 26 ++++----- mlx/backend/cuda/quantized/cublas_qqmm.h | 22 ++++---- .../cuda/quantized/cublas_qqmm_batched.cpp | 53 +++++++++++++++++++ mlx/backend/cuda/quantized/qqmm_utils.cu | 37 +++++++++---- mlx/backend/cuda/quantized/quantized.cpp | 36 ++++++++++++- mlx/ops.cpp | 26 +++++++-- 7 files changed, 169 insertions(+), 40 deletions(-) create mode 100644 mlx/backend/cuda/quantized/cublas_qqmm_batched.cpp diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 8d7a8b03cd..217bb86e1f 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -56,9 +56,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu ${CMAKE_CURRENT_SOURCE_DIR}/quantized/fp_quantize.cu ${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/quantized/qqmm_utils.cu ${CMAKE_CURRENT_SOURCE_DIR}/quantized/convert_fp8.cu - ${CMAKE_CURRENT_SOURCE_DIR}/quantized/cublas_qqmm.cpp ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary) @@ -67,6 +65,13 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unary) # fp4 is not available on < 12.8 if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12.8.0) target_include_directories(mlx PRIVATE 
${CMAKE_CURRENT_SOURCE_DIR}/quantized/) +else() + target_sources(mlx + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/qqmm_utils.cu) + target_sources(mlx + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/cublas_qqmm.cpp) + target_sources( + mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/cublas_qqmm_batched.cpp) endif() if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0) diff --git a/mlx/backend/cuda/quantized/cublas_qqmm.cpp b/mlx/backend/cuda/quantized/cublas_qqmm.cpp index a04cfefc14..d1a6103c9a 100644 --- a/mlx/backend/cuda/quantized/cublas_qqmm.cpp +++ b/mlx/backend/cuda/quantized/cublas_qqmm.cpp @@ -158,18 +158,20 @@ void CublasQQMM::run( const Strides& b_batch_strides, float alpha) { int batch_count = out.size() / (M_ * N_); - // if (batch_count / batch_shape.back() > 1) { - // run_batched( - // encoder, - // out, - // a, - // b, - // batch_shape, - // a_batch_strides, - // b_batch_strides, - // alpha); - // return; - // } + if (batch_count / batch_shape.back() > 1) { + run_batched( + encoder, + out, + a, + b, + a_scale, + b_scale, + batch_shape, + a_batch_strides, + b_batch_strides, + alpha); + return; + } encoder.set_input_array(a); encoder.set_input_array(b); diff --git a/mlx/backend/cuda/quantized/cublas_qqmm.h b/mlx/backend/cuda/quantized/cublas_qqmm.h index a5650c9a17..178ba9d12e 100644 --- a/mlx/backend/cuda/quantized/cublas_qqmm.h +++ b/mlx/backend/cuda/quantized/cublas_qqmm.h @@ -68,16 +68,18 @@ class CublasQQMM : public CublasMatmulBase { // float alpha, // float beta); - // private: - // void run_batched( - // cu::CommandEncoder& encoder, - // array& out, - // const array& a, - // const array& b, - // const Shape& batch_shape, - // const Strides& a_batch_strides, - // const Strides& b_batch_strides, - // float alpha); + private: + void run_batched( + cu::CommandEncoder& encoder, + array& out, + const array& a, + const array& b, + const array& a_scale, + const array& b_scale, + const Shape& batch_shape, + const Strides& a_batch_strides, + const Strides& b_batch_strides, + float alpha); // void run_batched( // cu::CommandEncoder& encoder, diff --git a/mlx/backend/cuda/quantized/cublas_qqmm_batched.cpp b/mlx/backend/cuda/quantized/cublas_qqmm_batched.cpp new file mode 100644 index 0000000000..166248a3ff --- /dev/null +++ b/mlx/backend/cuda/quantized/cublas_qqmm_batched.cpp @@ -0,0 +1,53 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/common/utils.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/quantized/cublas_qqmm.h" + +namespace mlx::core { + +void CublasQQMM::run_batched( + cu::CommandEncoder& encoder, + array& out, + const array& a, + const array& b, + const array& a_scale, + const array& b_scale, + const Shape& batch_shape, + const Strides& a_batch_strides, + const Strides& b_batch_strides, + float alpha) { + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_input_array(a_scale); + encoder.set_input_array(b_scale); + encoder.set_output_array(out); + + auto nbatch = out.size() / (M_ * N_ * batch_shape.back()); + + ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1); + ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); + + // Scales are contiguous, so their batch stride is just the size of one scale + // matrix (?) 
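+  // pad_and_repack_scales keeps the batch dimensions and pads the last two
+  // dims to full 128 rows x 4 cols tiles, so shape(-2) * shape(-1) is already
+  // the padded per-batch block size; stepping by it lands on the next batch.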
+ size_t a_scale_batch_stride = a_scale.shape(-2) * a_scale.shape(-1); + size_t b_scale_batch_stride = b_scale.shape(-2) * b_scale.shape(-1); + + auto concurrent = encoder.concurrent_context(); + for (size_t i = 0; i < nbatch; ++i) { + execute( + encoder, + gpu_ptr(out) + + out.itemsize() * i * batch_shape.back() * M_ * N_, + gpu_ptr(a) + a.itemsize() * a_it.loc, + gpu_ptr(b) + b.itemsize() * b_it.loc, + gpu_ptr(a_scale) + i * a_scale_batch_stride, + gpu_ptr(b_scale) + i * b_scale_batch_stride, + nullptr, + alpha); + a_it.step(); + b_it.step(); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/qqmm_utils.cu b/mlx/backend/cuda/quantized/qqmm_utils.cu index ff19057b08..9ba6d19cf2 100644 --- a/mlx/backend/cuda/quantized/qqmm_utils.cu +++ b/mlx/backend/cuda/quantized/qqmm_utils.cu @@ -94,7 +94,8 @@ __global__ void repack_scales( size_t input_rows, size_t input_cols, size_t output_rows, - size_t output_cols) { + size_t output_cols, + size_t batch_size) { auto block_size = cg::this_thread_block().dim_threads(); auto block_idx = cg::this_thread_block().group_index(); auto idx_in_block = cg::this_thread_block().thread_index(); @@ -105,15 +106,25 @@ __global__ void repack_scales( auto grid_dim_x = cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x; - size_t output_index = tidx + grid_dim_x * size_t(tidy); - size_t output_size = output_rows * output_cols; + size_t global_index = tidx + grid_dim_x * size_t(tidy); + size_t total_output_size = batch_size * output_rows * output_cols; - if (output_index >= output_size) { + if (global_index >= total_output_size) { return; } + // Compute batch strides from shape (scales are contiguous from fp_quantize) + size_t input_batch_stride = input_rows * input_cols; + size_t output_batch_stride = output_rows * output_cols; + + // Determine which batch and position within batch + size_t batch_idx = global_index / output_batch_stride; + size_t output_index = global_index % output_batch_stride; + size_t tiled_offset = scale_tiled_offset(output_index, output_rows, output_cols); + // Add batch offset for output + tiled_offset += batch_idx * output_batch_stride; size_t row = output_index / output_cols; size_t col = output_index % output_cols; @@ -121,7 +132,9 @@ __global__ void repack_scales( // Probably this can be done better with 2 separated paths for valid and // padding if (row < input_rows && col < input_cols) { - size_t input_index = row * input_cols + col; + // Compute input index with batch offset + size_t input_index = + batch_idx * input_batch_stride + row * input_cols + col; scales_tiled[tiled_offset] = scales_linear[input_index]; } else { // Zero-fill padding region @@ -147,11 +160,16 @@ void repack_scales( size_t output_rows = scales_tiled.shape(-2); size_t output_cols = scales_tiled.shape(-1); - size_t output_size = output_rows * output_cols; - bool large = output_size > UINT_MAX; + // Calculate batch size (all dimensions except last 2) + size_t batch_size = scales.size() / (input_rows * input_cols); + + // Total output size across all batches + size_t total_output_size = batch_size * output_rows * output_cols; + + bool large = total_output_size > UINT_MAX; auto [num_blocks, block_dims] = get_launch_args( - output_size, scales_tiled.shape(), scales_tiled.strides(), large); + total_output_size, scales_tiled.shape(), scales_tiled.strides(), large); enc.add_kernel_node( cu::repack_scales, @@ -163,7 +181,8 @@ void repack_scales( input_rows, input_cols, output_rows, - output_cols); + output_cols, + batch_size); } } // 
namespace mlx::core diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index b42b8a9504..122a71101c 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -53,6 +53,18 @@ array pad_and_repack_scales( // Compute padded dimensions for full tiles (128 rows × 4 cols) auto [pad_outer, pad_inner] = get_padded_scale_dims(scale.shape(-2), scale.shape(-1)); + + // Calculate batch size (all dimensions except last 2) + size_t batch_size = scale.size() / (scale.shape(-2) * scale.shape(-1)); + + // Build output shape preserving batch dimensions + Shape out_shape; + for (int i = 0; i < scale.ndim() - 2; ++i) { + out_shape.push_back(scale.shape(i)); + } + out_shape.push_back(pad_outer); + out_shape.push_back(pad_inner); + // cuBLAS requirements for scale factor layout: // 1. Dimensions must be padded to full tiles (128 rows × 4 cols) // 2. Out-of-bounds values must be filled with zeros @@ -60,8 +72,8 @@ array pad_and_repack_scales( // https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout // Note: cu::malloc_async already provides 256-byte alignment array scale_tiled( - cu::malloc_async(pad_outer * pad_inner, encoder), - Shape{pad_outer, pad_inner}, + cu::malloc_async(batch_size * pad_outer * pad_inner, encoder), + out_shape, scale.dtype()); repack_scales(scale, scale_tiled, encoder, s); @@ -216,6 +228,26 @@ void DualQuantizedMatmul::eval_gpu( array scale_x = pad_and_repack_scales(scale_x_pre, encoder, s); array scale_w = pad_and_repack_scales(scale_w_pre, encoder, s); + // Debug: print scale shapes after repacking + std::cerr << "[DEBUG] scale_x_pre shape: ["; + for (int i = 0; i < scale_x_pre.ndim(); ++i) { + std::cerr << scale_x_pre.shape(i) + << (i < scale_x_pre.ndim() - 1 ? ", " : ""); + } + std::cerr << "]" << std::endl; + + std::cerr << "[DEBUG] scale_x (repacked) shape: ["; + for (int i = 0; i < scale_x.ndim(); ++i) { + std::cerr << scale_x.shape(i) << (i < scale_x.ndim() - 1 ? ", " : ""); + } + std::cerr << "], size: " << scale_x.size() << std::endl; + + std::cerr << "[DEBUG] scale_w (repacked) shape: ["; + for (int i = 0; i < scale_w.ndim(); ++i) { + std::cerr << scale_w.shape(i) << (i < scale_w.ndim() - 1 ? ", " : ""); + } + std::cerr << "], size: " << scale_w.size() << std::endl; + bool x_transposed = false; // a is normal (M x K) bool w_transposed = true; // b is transposed (N x K -> K x N) int64_t lda = K; // Leading dimension of a diff --git a/mlx/ops.cpp b/mlx/ops.cpp index d739cf3aa7..f7752bbd2e 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -4238,11 +4238,25 @@ array qqmm( std::optional bits_ /* = std::nullopt */, const std::string& mode /* = "nvfp4" */, StreamOrDevice s /* = {} */) { - // currently only simetric quantization is supported for qqmm +// currently only simetric quantization is supported for qqmm +#if defined(MLX_USE_CUDA) && CUDART_VERSION < 12080 + throw std::runtime_error( + "[qqmm] Requires CUDA >= 12.8.0 for cuBLAS block-scaled matmul support. 
" + "Please upgrade your CUDA toolkit."); +#endif auto qmode = string_to_quantization_mode(mode, "qqmm"); - // here we need to check that w_q, w and scales_w are compatible with - if ((qmode == QuantizationMode::Nvfp4 || qmode == QuantizationMode::Mxfp4) && - !transpose) { + // cuBLAS block scaled matmul only supports nvfp4 and mxfp8 + if (qmode != QuantizationMode::Nvfp4 && qmode != QuantizationMode::Mxfp8) { + std::ostringstream msg; + msg << "[qqmm] only 'nvfp4' and 'mxfp8' quantization modes are supported but '" + << mode << "' was provided."; + throw std::invalid_argument(msg.str()); + } + // for fp4 block scaling the only supported layout is TN + // https://docs.nvidia.com/cutlass/4.2.1/media/docs/cpp/blackwell_functionality.html + // because cublaslt is column major we need the second argument to be + // transposed, not the first one + if ((qmode == QuantizationMode::Nvfp4) && !transpose) { std::ostringstream msg; msg << "[qqmm] transpose must be set to true with " << mode << " quantization but " @@ -4259,11 +4273,13 @@ array qqmm( auto [w_inner_dims, w_outer_dims] = extract_qqmm_dims( "qqmm", x, w_q, scales_w, w, transpose, group_size, bits); + // we don't backprope through qunatized w and scales std::vector inputs = { x, stop_gradient(w_q), stop_gradient(scales_w), - }; // we don't backprope through qunatized w and scales, + }; + // if bf16 w is provided, add it to inputs for vjps if (w.has_value()) { inputs.push_back(*w); } From 88361a34638058536d5a46076f181ceb81ee934c Mon Sep 17 00:00:00 2001 From: root Date: Wed, 26 Nov 2025 00:50:30 +0000 Subject: [PATCH 17/29] delete batching --- mlx/backend/cuda/CMakeLists.txt | 2 - mlx/backend/cuda/quantized/cublas_qqmm.cpp | 20 ----- mlx/backend/cuda/quantized/cublas_qqmm.h | 3 - .../cuda/quantized/cublas_qqmm_batched.cpp | 53 ------------- mlx/backend/cuda/quantized/quantized.cpp | 79 ++++--------------- mlx/ops.cpp | 7 +- 6 files changed, 16 insertions(+), 148 deletions(-) delete mode 100644 mlx/backend/cuda/quantized/cublas_qqmm_batched.cpp diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 217bb86e1f..827bb77529 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -70,8 +70,6 @@ else() PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/qqmm_utils.cu) target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/cublas_qqmm.cpp) - target_sources( - mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/cublas_qqmm_batched.cpp) endif() if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0) diff --git a/mlx/backend/cuda/quantized/cublas_qqmm.cpp b/mlx/backend/cuda/quantized/cublas_qqmm.cpp index d1a6103c9a..67903f4a34 100644 --- a/mlx/backend/cuda/quantized/cublas_qqmm.cpp +++ b/mlx/backend/cuda/quantized/cublas_qqmm.cpp @@ -69,7 +69,6 @@ CublasQQMM::CublasQQMM( cublasComputeType_t gemm_compute_type = CUBLAS_COMPUTE_32F; // always for narrow precision cudaDataType_t data_type = qmode_to_cublas_dtype(qmode); - quantization_mode_ = std::string(qmode); init_base( @@ -153,26 +152,7 @@ void CublasQQMM::run( const array& b, const array& a_scale, const array& b_scale, - const Shape& batch_shape, - const Strides& a_batch_strides, - const Strides& b_batch_strides, float alpha) { - int batch_count = out.size() / (M_ * N_); - if (batch_count / batch_shape.back() > 1) { - run_batched( - encoder, - out, - a, - b, - a_scale, - b_scale, - batch_shape, - a_batch_strides, - b_batch_strides, - alpha); - return; - } - encoder.set_input_array(a); encoder.set_input_array(b); 
encoder.set_input_array(a_scale); diff --git a/mlx/backend/cuda/quantized/cublas_qqmm.h b/mlx/backend/cuda/quantized/cublas_qqmm.h index 178ba9d12e..3137040165 100644 --- a/mlx/backend/cuda/quantized/cublas_qqmm.h +++ b/mlx/backend/cuda/quantized/cublas_qqmm.h @@ -50,9 +50,6 @@ class CublasQQMM : public CublasMatmulBase { const array& b, const array& a_scale, const array& b_scale, - const Shape& batch_shape, - const Strides& a_batch_strides, - const Strides& b_batch_strides, float alpha = 1.0f); // void run( diff --git a/mlx/backend/cuda/quantized/cublas_qqmm_batched.cpp b/mlx/backend/cuda/quantized/cublas_qqmm_batched.cpp deleted file mode 100644 index 166248a3ff..0000000000 --- a/mlx/backend/cuda/quantized/cublas_qqmm_batched.cpp +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright © 2025 Apple Inc. - -#include "mlx/backend/common/utils.h" -#include "mlx/backend/cuda/device.h" -#include "mlx/backend/cuda/quantized/cublas_qqmm.h" - -namespace mlx::core { - -void CublasQQMM::run_batched( - cu::CommandEncoder& encoder, - array& out, - const array& a, - const array& b, - const array& a_scale, - const array& b_scale, - const Shape& batch_shape, - const Strides& a_batch_strides, - const Strides& b_batch_strides, - float alpha) { - encoder.set_input_array(a); - encoder.set_input_array(b); - encoder.set_input_array(a_scale); - encoder.set_input_array(b_scale); - encoder.set_output_array(out); - - auto nbatch = out.size() / (M_ * N_ * batch_shape.back()); - - ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1); - ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); - - // Scales are contiguous, so their batch stride is just the size of one scale - // matrix (?) - size_t a_scale_batch_stride = a_scale.shape(-2) * a_scale.shape(-1); - size_t b_scale_batch_stride = b_scale.shape(-2) * b_scale.shape(-1); - - auto concurrent = encoder.concurrent_context(); - for (size_t i = 0; i < nbatch; ++i) { - execute( - encoder, - gpu_ptr(out) + - out.itemsize() * i * batch_shape.back() * M_ * N_, - gpu_ptr(a) + a.itemsize() * a_it.loc, - gpu_ptr(b) + b.itemsize() * b_it.loc, - gpu_ptr(a_scale) + i * a_scale_batch_stride, - gpu_ptr(b_scale) + i * b_scale_batch_stride, - nullptr, - alpha); - a_it.step(); - b_it.step(); - } -} - -} // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index 122a71101c..785c87384f 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -50,20 +50,14 @@ array pad_and_repack_scales( const array& scale, cu::CommandEncoder& encoder, const Stream& s) { - // Compute padded dimensions for full tiles (128 rows × 4 cols) - auto [pad_outer, pad_inner] = - get_padded_scale_dims(scale.shape(-2), scale.shape(-1)); - // Calculate batch size (all dimensions except last 2) size_t batch_size = scale.size() / (scale.shape(-2) * scale.shape(-1)); + size_t collapsed_outer = batch_size * scale.shape(-2); - // Build output shape preserving batch dimensions - Shape out_shape; - for (int i = 0; i < scale.ndim() - 2; ++i) { - out_shape.push_back(scale.shape(i)); - } - out_shape.push_back(pad_outer); - out_shape.push_back(pad_inner); + auto [pad_outer, pad_inner] = + get_padded_scale_dims(collapsed_outer, scale.shape(-1)); + + Shape out_shape = {pad_outer, pad_inner}; // cuBLAS requirements for scale factor layout: // 1. 
Dimensions must be padded to full tiles (128 rows × 4 cols) @@ -72,7 +66,7 @@ array pad_and_repack_scales( // https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout // Note: cu::malloc_async already provides 256-byte alignment array scale_tiled( - cu::malloc_async(batch_size * pad_outer * pad_inner, encoder), + cu::malloc_async(pad_outer * pad_inner, encoder), out_shape, scale.dtype()); repack_scales(scale, scale_tiled, encoder, s); @@ -139,20 +133,7 @@ void qqmm_impl( QuantizationMode mode, float alpha = 1.0f) { // Invoke CublasQQMM - auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b); - auto batch_count = out.size() / (M * N); - std::string_view qmode = quantization_mode_to_string(mode); - if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 && - a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K && - b_batch_strides.back() == 0) { - M *= batch_shape.back(); - batch_count = 1; - - a_batch_strides = {0}; - b_batch_strides = {0}; - batch_shape = {1}; - } CublasQQMM qqmm( encoder.device(), @@ -164,22 +145,12 @@ void qqmm_impl( K, N, ldb, - batch_shape.back(), - a_batch_strides.back(), - b_batch_strides.back(), + 1, // batch_count + 0, // a_batch_stride + 0, // b_batch_stride qmode); - qqmm.run( - encoder, - out, - a, - b, - a_scale, - b_scale, - batch_shape, - a_batch_strides, - b_batch_strides, - alpha); + qqmm.run(encoder, out, a, b, a_scale, b_scale, alpha); } } // namespace @@ -220,7 +191,7 @@ void DualQuantizedMatmul::eval_gpu( out.set_data(cu::malloc_async(out.nbytes(), encoder)); int M = x_q.shape(-2); - int N = w_q.shape(-2); // b always transposed + int N = transpose_ ? w_q.shape(-2) : w_q.shape(-1); int K_packed = x_q.shape(-1); int K = K_packed * (32 / bits_); @@ -228,30 +199,10 @@ void DualQuantizedMatmul::eval_gpu( array scale_x = pad_and_repack_scales(scale_x_pre, encoder, s); array scale_w = pad_and_repack_scales(scale_w_pre, encoder, s); - // Debug: print scale shapes after repacking - std::cerr << "[DEBUG] scale_x_pre shape: ["; - for (int i = 0; i < scale_x_pre.ndim(); ++i) { - std::cerr << scale_x_pre.shape(i) - << (i < scale_x_pre.ndim() - 1 ? ", " : ""); - } - std::cerr << "]" << std::endl; - - std::cerr << "[DEBUG] scale_x (repacked) shape: ["; - for (int i = 0; i < scale_x.ndim(); ++i) { - std::cerr << scale_x.shape(i) << (i < scale_x.ndim() - 1 ? ", " : ""); - } - std::cerr << "], size: " << scale_x.size() << std::endl; - - std::cerr << "[DEBUG] scale_w (repacked) shape: ["; - for (int i = 0; i < scale_w.ndim(); ++i) { - std::cerr << scale_w.shape(i) << (i < scale_w.ndim() - 1 ? ", " : ""); - } - std::cerr << "], size: " << scale_w.size() << std::endl; - - bool x_transposed = false; // a is normal (M x K) - bool w_transposed = true; // b is transposed (N x K -> K x N) - int64_t lda = K; // Leading dimension of a - int64_t ldb = K; // Leading dimension of b + bool x_transposed = false; + bool w_transposed = transpose_; + int64_t lda = K; + int64_t ldb = transpose_ ? K : N; qqmm_impl( encoder, diff --git a/mlx/ops.cpp b/mlx/ops.cpp index f7752bbd2e..1c7af381f8 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -4238,12 +4238,7 @@ array qqmm( std::optional bits_ /* = std::nullopt */, const std::string& mode /* = "nvfp4" */, StreamOrDevice s /* = {} */) { -// currently only simetric quantization is supported for qqmm -#if defined(MLX_USE_CUDA) && CUDART_VERSION < 12080 - throw std::runtime_error( - "[qqmm] Requires CUDA >= 12.8.0 for cuBLAS block-scaled matmul support. 
" - "Please upgrade your CUDA toolkit."); -#endif + // currently only simetric quantization is supported for qqmm auto qmode = string_to_quantization_mode(mode, "qqmm"); // cuBLAS block scaled matmul only supports nvfp4 and mxfp8 if (qmode != QuantizationMode::Nvfp4 && qmode != QuantizationMode::Mxfp8) { From b9e73ab65ec23a7e8c29f4b26e4ed4e83d813376 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 26 Nov 2025 01:04:47 +0000 Subject: [PATCH 18/29] string instead of stringz-view --- mlx/backend/cuda/quantized/cublas_qqmm.cpp | 10 +++++----- mlx/backend/cuda/quantized/cublas_qqmm.h | 4 ++-- mlx/backend/cuda/quantized/quantized.cpp | 4 ++-- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/mlx/backend/cuda/quantized/cublas_qqmm.cpp b/mlx/backend/cuda/quantized/cublas_qqmm.cpp index 67903f4a34..1ef6d7bc67 100644 --- a/mlx/backend/cuda/quantized/cublas_qqmm.cpp +++ b/mlx/backend/cuda/quantized/cublas_qqmm.cpp @@ -15,7 +15,7 @@ namespace { // Currently cublas supports only mxfp8 and nvfp4 // quantization modes for block scaled quantization -cudaDataType_t qmode_to_cublas_scale_dtype(std::string_view mode) { +cudaDataType_t qmode_to_cublas_scale_dtype(std::string mode) { if (mode == "mxfp8") { return CUDA_R_8F_UE8M0; } else if (mode == "nvfp4") { @@ -26,7 +26,7 @@ cudaDataType_t qmode_to_cublas_scale_dtype(std::string_view mode) { } } -cudaDataType_t qmode_to_cublas_dtype(std::string_view mode) { +cudaDataType_t qmode_to_cublas_dtype(std::string mode) { if (mode == "mxfp8") { return CUDA_R_8F_E4M3; } else if (mode == "nvfp4") { @@ -37,7 +37,7 @@ cudaDataType_t qmode_to_cublas_dtype(std::string_view mode) { } } -cublasLtMatmulMatrixScale_t qmode_to_cublas_scale_mode(std::string_view mode) { +cublasLtMatmulMatrixScale_t qmode_to_cublas_scale_mode(std::string mode) { if (mode == "mxfp8") { return CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0; } else if (mode == "nvfp4") { @@ -63,7 +63,7 @@ CublasQQMM::CublasQQMM( int32_t batch_count, int64_t a_batch_stride, int64_t b_batch_stride, - std::string_view qmode) { + std::string qmode) { cudaDataType_t scale_type = CUDA_R_32F; cudaDataType_t output_type = CUDA_R_16BF; // always output in bf16 cublasComputeType_t gemm_compute_type = @@ -119,7 +119,7 @@ CublasQQMM::CublasQQMM( int64_t a_batch_stride, int64_t b_batch_stride, int64_t c_batch_stride, - std::string_view qmode) + std::string qmode) : CublasQQMM( device, a_transposed, diff --git a/mlx/backend/cuda/quantized/cublas_qqmm.h b/mlx/backend/cuda/quantized/cublas_qqmm.h index 3137040165..5dbaec17fa 100644 --- a/mlx/backend/cuda/quantized/cublas_qqmm.h +++ b/mlx/backend/cuda/quantized/cublas_qqmm.h @@ -24,7 +24,7 @@ class CublasQQMM : public CublasMatmulBase { int32_t batch_count, int64_t a_batch_stride, int64_t b_batch_stride, - std::string_view quantization_mode); + std::string quantization_mode); CublasQQMM( cu::Device& device, @@ -41,7 +41,7 @@ class CublasQQMM : public CublasMatmulBase { int64_t a_batch_stride, int64_t b_batch_stride, int64_t c_batch_stride, - std::string_view quantization_mode); + std::string quantization_mode); void run( cu::CommandEncoder& encoder, diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index 785c87384f..9e090fb4d3 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -133,8 +133,8 @@ void qqmm_impl( QuantizationMode mode, float alpha = 1.0f) { // Invoke CublasQQMM - std::string_view qmode = quantization_mode_to_string(mode); - + std::string qmode = 
quantization_mode_to_string(mode); + CublasQQMM qqmm( encoder.device(), a_transposed, From 95c275b56a95c6007fc89377e3de9eaf4a7b09fb Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Wed, 26 Nov 2025 16:00:52 +0100 Subject: [PATCH 19/29] add 2D input condition --- mlx/backend/cuda/quantized/quantized.cpp | 22 ++++++++-------------- mlx/ops.cpp | 7 +++++++ mlx/primitives.cpp | 2 +- 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index 9e090fb4d3..d2eb72ebf3 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -50,15 +50,9 @@ array pad_and_repack_scales( const array& scale, cu::CommandEncoder& encoder, const Stream& s) { - // Calculate batch size (all dimensions except last 2) - size_t batch_size = scale.size() / (scale.shape(-2) * scale.shape(-1)); - size_t collapsed_outer = batch_size * scale.shape(-2); - + // Compute padded dimensions for full tiles (128 rows × 4 cols) auto [pad_outer, pad_inner] = - get_padded_scale_dims(collapsed_outer, scale.shape(-1)); - - Shape out_shape = {pad_outer, pad_inner}; - + get_padded_scale_dims(scale.shape(-2), scale.shape(-1)); // cuBLAS requirements for scale factor layout: // 1. Dimensions must be padded to full tiles (128 rows × 4 cols) // 2. Out-of-bounds values must be filled with zeros @@ -67,7 +61,7 @@ array pad_and_repack_scales( // Note: cu::malloc_async already provides 256-byte alignment array scale_tiled( cu::malloc_async(pad_outer * pad_inner, encoder), - out_shape, + Shape{pad_outer, pad_inner}, scale.dtype()); repack_scales(scale, scale_tiled, encoder, s); @@ -134,7 +128,10 @@ void qqmm_impl( float alpha = 1.0f) { // Invoke CublasQQMM std::string qmode = quantization_mode_to_string(mode); - + + // Currently only supports non-batched QQMM operations + // that covers all use cases for training, we will just collapse (batch, + // seq_len) into (tokens) CublasQQMM qqmm( encoder.device(), a_transposed, @@ -158,15 +155,12 @@ void DualQuantizedMatmul::eval_gpu( const std::vector& inputs, array& out) { nvtx3::scoped_range r("DualQuantizedMatmul::eval_gpu"); - // for now it is size of 4: bf16 x, bf16 w, w_q, scale_w - // for the inference & vjp auto& s = stream(); auto& encoder = cu::get_command_encoder(s); - assert(inputs.size() == 4 || inputs.size() == 3); auto& x = inputs[0]; // activations bf16 auto& w_q = inputs[1]; // quantized weights - auto& scale_w_pre = inputs[2]; + auto& scale_w_pre = inputs[2]; // weight scales auto quantize_activation = [&](const array& input, cu::CommandEncoder& encoder, const Stream& s) { diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 1c7af381f8..8d7e23f86d 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -4263,6 +4263,13 @@ array qqmm( msg << "[qqmm] Affine quantization is not supported for qqmm."; throw std::invalid_argument(msg.str()); } + if (x.ndim() > 2 || w_q.ndim() > 2) { + std::ostringstream msg; + msg << "[qqmm] Only 2D inputs are supported but " + << "x.ndim() == " << x.ndim() << " and " + << "w_q.ndim() == " << w_q.ndim() << "."; + throw std::invalid_argument(msg.str()); + } auto [group_size, bits] = quantization_params_from_mode(qmode, group_size_, bits_); auto [w_inner_dims, w_outer_dims] = extract_qqmm_dims( diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 0ad5ca3f32..6e29a98ec6 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3508,7 +3508,7 @@ std::vector DualQuantizedMatmul::vjp( // We transpose weights -> quantize along N -> 
qqmm (cotan quantized in // eval_gpu) auto wtq = quantize( - transpose(primals[3], reorder, s), + transpose(primals[3], {1, 0}, s), // we assume that weights are 2D group_size_, bits_, qmode, From 64b8cbe5f9bbc110425829bac3452bde791b78f4 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Wed, 26 Nov 2025 20:14:47 +0100 Subject: [PATCH 20/29] force transpose --- mlx/backend/cuda/quantized/quantized.cpp | 6 ++--- mlx/ops.cpp | 28 ++++++++---------------- mlx/primitives.cpp | 7 ++---- mlx/primitives.h | 5 ++--- 4 files changed, 16 insertions(+), 30 deletions(-) diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index d2eb72ebf3..8ad9111882 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -185,7 +185,7 @@ void DualQuantizedMatmul::eval_gpu( out.set_data(cu::malloc_async(out.nbytes(), encoder)); int M = x_q.shape(-2); - int N = transpose_ ? w_q.shape(-2) : w_q.shape(-1); + int N = w_q.shape(-2); // always transposed int K_packed = x_q.shape(-1); int K = K_packed * (32 / bits_); @@ -194,9 +194,9 @@ void DualQuantizedMatmul::eval_gpu( array scale_w = pad_and_repack_scales(scale_w_pre, encoder, s); bool x_transposed = false; - bool w_transposed = transpose_; + bool w_transposed = true; // always transposed int64_t lda = K; - int64_t ldb = transpose_ ? K : N; + int64_t ldb = K; qqmm_impl( encoder, diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 8d7e23f86d..4dd1db3f79 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -160,7 +160,6 @@ std::pair extract_qqmm_dims( const array& w_q, const array& scales_w, const std::optional& w, - bool transpose, int group_size, int bits) { // Validate w_q and scales_w @@ -178,13 +177,10 @@ std::pair extract_qqmm_dims( << "with shape " << w_q.shape() << " with bits=" << bits; throw std::invalid_argument(msg.str()); } - // For narrow precision types (mxfp4, nvfp4) the only supported layout is TN - // A is MxK, B is NxK (transposed) int x_inner_dims = x.shape(-1) / (32 / bits); // K - // Calculate the expanded w's dimensions - int w_inner_dims = (transpose) ? w_q.shape(-1) : w_q.shape(-2); - int w_outer_dims = (transpose) ? 
w_q.shape(-2) : w_q.shape(-1); + int w_inner_dims = w_q.shape(-1); + int w_outer_dims = w_q.shape(-2); if (w_inner_dims != x_inner_dims) { std::ostringstream msg; @@ -4233,7 +4229,6 @@ array qqmm( array w_q, array scales_w, std::optional w /* = std::nullopt */, - bool transpose /* = true */, std::optional group_size_ /* = std::nullopt */, std::optional bits_ /* = std::nullopt */, const std::string& mode /* = "nvfp4" */, @@ -4249,15 +4244,10 @@ array qqmm( } // for fp4 block scaling the only supported layout is TN // https://docs.nvidia.com/cutlass/4.2.1/media/docs/cpp/blackwell_functionality.html - // because cublaslt is column major we need the second argument to be - // transposed, not the first one - if ((qmode == QuantizationMode::Nvfp4) && !transpose) { - std::ostringstream msg; - msg << "[qqmm] transpose must be set to true with " << mode - << " quantization but " - << "transpose == false was provided."; - throw std::invalid_argument(msg.str()); - } + // because w_q should always be quantized along the reduction dimension + // and we quantize so that the last dim is packed, we enforce transpose = true + // always here + if (qmode == QuantizationMode::Affine) { std::ostringstream msg; msg << "[qqmm] Affine quantization is not supported for qqmm."; @@ -4272,8 +4262,8 @@ array qqmm( } auto [group_size, bits] = quantization_params_from_mode(qmode, group_size_, bits_); - auto [w_inner_dims, w_outer_dims] = extract_qqmm_dims( - "qqmm", x, w_q, scales_w, w, transpose, group_size, bits); + auto [w_inner_dims, w_outer_dims] = + extract_qqmm_dims("qqmm", x, w_q, scales_w, w, group_size, bits); // we don't backprope through qunatized w and scales std::vector inputs = { @@ -4298,7 +4288,7 @@ array qqmm( std::move(out_shape), dtype, std::make_shared( - to_stream(s), group_size, bits, qmode, transpose), + to_stream(s), group_size, bits, qmode), std::move(inputs)); } diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 6e29a98ec6..b419c050f8 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3473,14 +3473,13 @@ bool DualQuantizedMatmul::is_equivalent(const Primitive& other) const { const DualQuantizedMatmul& qm_other = static_cast(other); return group_size_ == qm_other.group_size_ && bits_ == qm_other.bits_ && - mode_ == qm_other.mode_ && transpose_ == qm_other.transpose_; + mode_ == qm_other.mode_; } std::vector DualQuantizedMatmul::output_shapes( const std::vector& inputs) { auto out_shape = inputs[0].shape(); - auto& w_q = inputs[2]; - int w_outer_dims = (transpose_) ? 
w_q.shape(-2) : w_q.shape(-1); + int w_outer_dims = inputs[2].shape(-2); out_shape.back() = w_outer_dims; std::cout << "DualQuantizedMatmul output shape: " << out_shape << std::endl; return {std::move(out_shape)}; @@ -3518,7 +3517,6 @@ std::vector DualQuantizedMatmul::vjp( wtq[0], // K X N_packed wtq[1], // scales std::nullopt, - true, group_size_, bits_, qmode, @@ -3534,7 +3532,6 @@ std::vector DualQuantizedMatmul::vjp( xtq[0], // (N, M_packed) xtq[1], // scales std::nullopt, - true, group_size_, bits_, qmode, diff --git a/mlx/primitives.h b/mlx/primitives.h index 4b9654ee0c..041b376e75 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1653,8 +1653,7 @@ class DualQuantizedMatmul : public UnaryPrimitive { : UnaryPrimitive(stream), group_size_(group_size), bits_(bits), - mode_(mode), - transpose_(transpose) {} + mode_(mode) {} void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; @@ -1665,7 +1664,7 @@ class DualQuantizedMatmul : public UnaryPrimitive { bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override; auto state() const { - return std::make_tuple(group_size_, bits_, mode_, transpose_); + return std::make_tuple(group_size_, bits_, mode_); } private: From 4b685952c4ad5c148bca281599da253e6530145a Mon Sep 17 00:00:00 2001 From: root Date: Wed, 26 Nov 2025 19:34:05 +0000 Subject: [PATCH 21/29] fix transpose --- mlx/ops.cpp | 14 +++++--------- mlx/ops.h | 1 - mlx/primitives.h | 3 +-- 3 files changed, 6 insertions(+), 12 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 4dd1db3f79..28670d1831 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -165,7 +165,6 @@ std::pair extract_qqmm_dims( // Validate w_q and scales_w validate_quantized_input( tag, w_q, scales_w, "weight matrix", "scales_w", group_size, bits); - // Calculate the expanded w's dimensions if (w && (w->shape(-1) != w_q.shape(-1) * 32 / bits || @@ -186,10 +185,8 @@ std::pair extract_qqmm_dims( std::ostringstream msg; msg << "[" << tag << "] Inner dimension of second input with " << "shape (" << w_inner_dims << ", " << w_outer_dims << ")" - << " computed with transpose=" << std::boolalpha << transpose << " does not match the packed inner dimension of the first" - << "input (...," << x_inner_dims << ") computed with bits=" << bits - << " and transpose=" << std::boolalpha << transpose; + << "input (...," << x_inner_dims << ") computed with bits=" << bits; throw std::invalid_argument(msg.str()); } @@ -4245,9 +4242,9 @@ array qqmm( // for fp4 block scaling the only supported layout is TN // https://docs.nvidia.com/cutlass/4.2.1/media/docs/cpp/blackwell_functionality.html // because w_q should always be quantized along the reduction dimension - // and we quantize so that the last dim is packed, we enforce transpose = true - // always here - + // and we quantize so that the last dim is packed, we assume that the last dim + // always the reduction dim so the firat argument in cubals column major is + // (the second argument in qqmm) always transposed if (qmode == QuantizationMode::Affine) { std::ostringstream msg; msg << "[qqmm] Affine quantization is not supported for qqmm."; @@ -4281,9 +4278,8 @@ array qqmm( auto out_shape = inputs[0].shape(); out_shape.back() = w_outer_dims; - - // out dtype can be only bf16 auto dtype = bfloat16; + // out dtype can be only bf16 for now return array( std::move(out_shape), dtype, diff --git a/mlx/ops.h b/mlx/ops.h index 002800a523..df4be8777a 100644 --- 
a/mlx/ops.h +++ b/mlx/ops.h @@ -1408,7 +1408,6 @@ array qqmm( array w_q, // quantized weights array w_scales, std::optional w = std::nullopt, // optional bf16 weights for vjp - bool transpose = true, std::optional group_size = std::nullopt, std::optional bits = std::nullopt, const std::string& mode = "nvfp4", diff --git a/mlx/primitives.h b/mlx/primitives.h index 041b376e75..73e93ef710 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1648,8 +1648,7 @@ class DualQuantizedMatmul : public UnaryPrimitive { Stream stream, int group_size, int bits, - QuantizationMode mode, - bool transpose) + QuantizationMode mode) : UnaryPrimitive(stream), group_size_(group_size), bits_(bits), From ee0ea9f26a73713324846bb16ab7446c12799a50 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 27 Nov 2025 01:15:15 +0000 Subject: [PATCH 22/29] add pythong tests --- mlx/backend/cuda/quantized/qqmm_utils.cu | 37 +++-------- mlx/ops.cpp | 1 + python/src/ops.cpp | 3 +- python/tests/test_quantized.py | 81 ++++++++++++++++++++++++ 4 files changed, 92 insertions(+), 30 deletions(-) diff --git a/mlx/backend/cuda/quantized/qqmm_utils.cu b/mlx/backend/cuda/quantized/qqmm_utils.cu index 9ba6d19cf2..ff19057b08 100644 --- a/mlx/backend/cuda/quantized/qqmm_utils.cu +++ b/mlx/backend/cuda/quantized/qqmm_utils.cu @@ -94,8 +94,7 @@ __global__ void repack_scales( size_t input_rows, size_t input_cols, size_t output_rows, - size_t output_cols, - size_t batch_size) { + size_t output_cols) { auto block_size = cg::this_thread_block().dim_threads(); auto block_idx = cg::this_thread_block().group_index(); auto idx_in_block = cg::this_thread_block().thread_index(); @@ -106,25 +105,15 @@ __global__ void repack_scales( auto grid_dim_x = cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x; - size_t global_index = tidx + grid_dim_x * size_t(tidy); - size_t total_output_size = batch_size * output_rows * output_cols; + size_t output_index = tidx + grid_dim_x * size_t(tidy); + size_t output_size = output_rows * output_cols; - if (global_index >= total_output_size) { + if (output_index >= output_size) { return; } - // Compute batch strides from shape (scales are contiguous from fp_quantize) - size_t input_batch_stride = input_rows * input_cols; - size_t output_batch_stride = output_rows * output_cols; - - // Determine which batch and position within batch - size_t batch_idx = global_index / output_batch_stride; - size_t output_index = global_index % output_batch_stride; - size_t tiled_offset = scale_tiled_offset(output_index, output_rows, output_cols); - // Add batch offset for output - tiled_offset += batch_idx * output_batch_stride; size_t row = output_index / output_cols; size_t col = output_index % output_cols; @@ -132,9 +121,7 @@ __global__ void repack_scales( // Probably this can be done better with 2 separated paths for valid and // padding if (row < input_rows && col < input_cols) { - // Compute input index with batch offset - size_t input_index = - batch_idx * input_batch_stride + row * input_cols + col; + size_t input_index = row * input_cols + col; scales_tiled[tiled_offset] = scales_linear[input_index]; } else { // Zero-fill padding region @@ -160,16 +147,11 @@ void repack_scales( size_t output_rows = scales_tiled.shape(-2); size_t output_cols = scales_tiled.shape(-1); + size_t output_size = output_rows * output_cols; - // Calculate batch size (all dimensions except last 2) - size_t batch_size = scales.size() / (input_rows * input_cols); - - // Total output size across all batches - size_t total_output_size = batch_size * 
output_rows * output_cols; - - bool large = total_output_size > UINT_MAX; + bool large = output_size > UINT_MAX; auto [num_blocks, block_dims] = get_launch_args( - total_output_size, scales_tiled.shape(), scales_tiled.strides(), large); + output_size, scales_tiled.shape(), scales_tiled.strides(), large); enc.add_kernel_node( cu::repack_scales, @@ -181,8 +163,7 @@ void repack_scales( input_rows, input_cols, output_rows, - output_cols, - batch_size); + output_cols); } } // namespace mlx::core diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 28670d1831..213d1220ce 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -163,6 +163,7 @@ std::pair extract_qqmm_dims( int group_size, int bits) { // Validate w_q and scales_w + // https://docs.nvidia.com/cuda/cublas/#d-block-scaling-for-fp8-and-fp4-data-types validate_quantized_input( tag, w_q, scales_w, "weight matrix", "scales_w", group_size, bits); diff --git a/python/src/ops.cpp b/python/src/ops.cpp index fb31fefb17..319d55c4b6 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -5439,7 +5439,6 @@ void init_ops(nb::module_& m) { nb::arg(), // w_q "scales"_a, // scales w "w"_a = nb::none(), // bf16 weights - "transpose"_a = true, "group_size"_a = nb::none(), "bits"_a = nb::none(), "mode"_a = "affine", @@ -5455,7 +5454,7 @@ void init_ops(nb::module_& m) { Args: x (array): Input array - w (array): Quantized matrix packed in unsigned integers + w_q (array): Quantized matrix packed in unsigned integers scales (array): The scales to use per ``group_size`` elements of ``w_q`` w (array, optional): bf16 or float32 weights used during training for correct gradient computation. Default: ``None``. diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index e751063036..449e9c4700 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -975,6 +975,87 @@ def gmm(s, x, wq): ds = mx.grad(gmm)(s, x, wq) + def test_qqmm(self): + key = mx.random.key(0) + k1, k2 = mx.random.split(key) + dtype = mx.bfloat16 + + tests = ( + (16, "nvfp4", 4), + (32, "mxfp8", 8), + ) + shapes = ( + [64, 65, 33, 128, 256, 1024, 1024 * 8], # M + [64, 128, 256, 1024, 1024 * 8], # N + [64, 128, 256, 1024, 1024 * 8], # K + ) + for group_size, mode, bits in tests: + for M, N, K in product(*shapes): + with self.subTest( + shape=(M, N, K), + group_size=group_size, + bits=bits, + mode=mode, + ): + x = mx.random.normal(shape=(M, K), key=k1, dtype=dtype) + w = mx.random.normal(shape=(N, K), key=k2, dtype=dtype) + w_q, scales_w = mx.quantize(w, group_size, bits, mode=mode) + w_dq = mx.dequantize( + w_q, scales_w, group_size=group_size, bits=bits, mode=mode + ) + y_q = mx.qqmm( + x, w_q, scales_w, group_size=group_size, bits=bits, mode=mode + ) + x_q, scales_x = mx.quantize( + x, group_size=group_size, bits=bits, mode=mode + ) + x_dq = mx.dequantize( + x_q, scales_x, group_size=group_size, bits=bits, mode=mode + ) + y_hat = mx.matmul(x_dq, mx.transpose(w_dq)) + self.assertEqual(y_q.shape, y_hat.shape) + self.assertLess((y_q - y_hat).abs().max(), 1e-3) + + def test_qqmm_vjp(self): + key = mx.random.key(0) + k1, k2 = mx.random.split(key) + dtype = mx.bfloat16 + M = 64 + N = 1024 + K = 512 + tests = ( + (16, "nvfp4", 4), + (32, "mxfp8", 8), + ) + x = mx.random.normal(shape=(M, K), key=k1, dtype=dtype) + c = mx.ones(shape=(M, N), dtype=dtype) + + for group_size, mode, bits in tests: + with self.subTest( + shape=(M, N, K), + group_size=group_size, + bits=bits, + mode=mode, + ): + w = mx.random.normal(shape=(N, K), key=k2, dtype=dtype) + w_q, scales_w = 
mx.quantize( + w, group_size=group_size, bits=bits, mode=mode + ) + + def fn(x): + return mx.qqmm( + x, w_q, scales_w, w, group_size=group_size, bits=bits, mode=mode + ) + + _, vjp_out = mx.vjp(fn, primals=(x,), cotangents=(c,)) + w_tq, scales_wt = mx.quantize( + mx.transpose(w), group_size=group_size, bits=bits, mode=mode + ) + expected_out = mx.qqmm( + c, w_tq, scales_wt, group_size=group_size, bits=bits, mode=mode + ) + self.assertTrue(mx.allclose(vjp_out[0], expected_out)) + if __name__ == "__main__": mlx_tests.MLXTestRunner() From cc0333e088d0a47a76dc64c0876cee324a66d686 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 28 Nov 2025 00:57:48 +0000 Subject: [PATCH 23/29] added qq linear --- python/mlx/nn/layers/__init__.py | 7 +- python/mlx/nn/layers/quantized.py | 103 ++++++++++++++++++++++++++++++ python/tests/test_nn.py | 12 ++++ 3 files changed, 121 insertions(+), 1 deletion(-) diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py index ea2d3029d8..c2fba58347 100644 --- a/python/mlx/nn/layers/__init__.py +++ b/python/mlx/nn/layers/__init__.py @@ -87,7 +87,12 @@ MaxPool3d, ) from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding -from mlx.nn.layers.quantized import QuantizedEmbedding, QuantizedLinear, quantize +from mlx.nn.layers.quantized import ( + QQLinear, + QuantizedEmbedding, + QuantizedLinear, + quantize, +) from mlx.nn.layers.recurrent import GRU, LSTM, RNN from mlx.nn.layers.transformer import ( MultiHeadAttention, diff --git a/python/mlx/nn/layers/quantized.py b/python/mlx/nn/layers/quantized.py index c308e884ba..32399896a4 100644 --- a/python/mlx/nn/layers/quantized.py +++ b/python/mlx/nn/layers/quantized.py @@ -268,3 +268,106 @@ def from_linear( ql.bias = linear_layer.bias return ql + + +class QQLinear(Module): + """Quantizes the input and applies an affine transformation using a quantized weight matrix. + + Currently, bias is not implemented for this layer. + It is equivalent to :class:`mlx.nn.QuantizedLinear` but both input and weight are quantized. + Note: QQLinear supported only `nvfp4` and `mxfp8` quantization modes. + For training, parameter `train` should be set to `True`. + + :obj:`QQLinear` also provides a classmethod :meth:`from_linear` to + convert linear layers to :obj:`QQLinear` layers. + + Args: + input_dims (int): The dimensionality of the input features. + output_dims (int): The dimensionality of the output features. + group_size (int, optional): The group size to use for the quantized + weight. See :func:`~mlx.core.quantize`. Default: ``16``. + bits (int, optional): The bit width to use for the quantized weight. + See :func:`~mlx.core.quantize`. Default: ``4``. + mode (str): The quantization method to use (see + :func:`mlx.core.quantize`). Default: ``"nvfp4"``. 
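+
+    A minimal usage sketch, mirroring the unit tests (the layer quantizes the
+    input on the fly and calls :func:`mlx.core.qqmm` internally)::
+
+        layer = QQLinear(512, 256, group_size=16, bits=4, mode="nvfp4")
+        x = mx.random.normal(shape=(128, 512))
+        y = layer(x)  # (128, 256)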
+ """ + + def __init__( + self, + input_dims: int, + output_dims: int, + bias: bool = False, + group_size: int = 16, + bits: int = 4, + mode: str = "nvfp4", + train: bool = False, + ): + super().__init__() + if bias: + raise NotImplementedError("Bias not implemented for QQLinear yet.") + # Quantization config + self.group_size = group_size + self.bits = bits + self.mode = mode + + # Initialize the quantized weight + scale = math.sqrt(1 / input_dims) + weight = mx.random.uniform( + low=-scale, + high=scale, + shape=(output_dims, input_dims), + ) + # if vjp: + if train: + self.weight = weight + self.weight_quantized, self.scales = mx.quantize( + weight, group_size, bits, mode=mode + ) + + def _extra_repr(self): + out_dims, in_dims = self.weight_quantized.shape + in_dims *= 32 // self.bits + return ( + f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, " + f"group_size={self.group_size}, bits={self.bits}, mode={self.mode}" + ) + + def __call__(self, x): + x = mx.qqmm( + x, + self["weight_quantized"], + scales=self["scales"], + w=self.get("weight"), + group_size=self.group_size, + bits=self.bits, + mode=self.mode, + ) + return x + + @classmethod + def from_linear( + cls, + linear_layer: Module, + group_size: int = 16, + bits: int = 4, + mode: str = "nvfp4", + train: bool = False, + ): + """Create a :obj:`QQLinear` layer from a :obj:`Linear` layer.""" + output_dims, input_dims = linear_layer.weight.shape + ql = cls( + input_dims, output_dims, False, group_size, bits, mode=mode, train=train + ) + ql.weight_quantized, ql.scales = mx.quantize( + linear_layer.weight, + group_size, + bits, + mode=mode, + ) + if train: + ql.weight = linear_layer.weight + + if "bias" in linear_layer: + ql.bias = linear_layer.bias + + return ql diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 6ded372278..aea12fa336 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -1948,6 +1948,18 @@ def test_attention(self): out = attn(x, x, x) self.assertEqual(out.shape, x.shape) + def test_qqlinear(self): + model = nn.Sequential( + nn.QQLinear(512, 256, bits=4, group_size=16, mode="nvfp4"), + nn.ReLU(), + nn.QQLinear(256, 128, bits=4, group_size=16, mode="nvfp4"), + nn.ReLU(), + nn.QQLinear(128, 256, bits=4, group_size=16, mode="nvfp4"), + ) + x = mx.random.normal(shape=(128, 512)) + out = model(x) + self.assertEqual(out.shape, (128, 256)) + if __name__ == "__main__": mlx_tests.MLXTestRunner() From 34f42fba76ad4d1bbbc6487d2e7da51c728c13bc Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Sat, 29 Nov 2025 02:14:08 +0100 Subject: [PATCH 24/29] added tests --- python/tests/test_quantized.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index 449e9c4700..327d17ccde 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -7,6 +7,14 @@ import mlx_tests +def ulp_bf16_at(x): + ax = mx.abs(x) + min_normal = mx.array(2.0**-126) + ax = mx.where(ax < min_normal, min_normal, ax) + e = mx.floor(mx.log2(ax)) + return mx.power(2.0, e - 7.0) + + class TestQuantized(mlx_tests.MLXTestCase): def test_quantize_dequantize(self): w = mx.random.normal(shape=(128, 512)) @@ -976,6 +984,14 @@ def gmm(s, x, wq): ds = mx.grad(gmm)(s, x, wq) def test_qqmm(self): + # for mxfp8 mode the results does not match exactly + # for less then 1 percent of elements in the output + # this is not systematic error + # the error can be larger than 1 ULP for very small 
elements + # and always less than 1 ULP for large elements + # for nvfp4 results match precisely + # therefore I suspect that potential cause of the difference + # is in the implementation of mxfp8 matmul in cuBLASlt key = mx.random.key(0) k1, k2 = mx.random.split(key) dtype = mx.bfloat16 @@ -1013,8 +1029,10 @@ def test_qqmm(self): x_q, scales_x, group_size=group_size, bits=bits, mode=mode ) y_hat = mx.matmul(x_dq, mx.transpose(w_dq)) + ulp = ulp_bf16_at(y_hat) + error = (y_q - y_hat).abs() self.assertEqual(y_q.shape, y_hat.shape) - self.assertLess((y_q - y_hat).abs().max(), 1e-3) + self.assertTrue(mx.logical_or(error < 1e-3, error <= ulp).all()) def test_qqmm_vjp(self): key = mx.random.key(0) @@ -1054,7 +1072,9 @@ def fn(x): expected_out = mx.qqmm( c, w_tq, scales_wt, group_size=group_size, bits=bits, mode=mode ) - self.assertTrue(mx.allclose(vjp_out[0], expected_out)) + ulp = ulp_bf16_at(expected_out) + error = (vjp_out[0] - expected_out).abs() + self.assertTrue(mx.logical_or(error < 1e-3, error <= ulp).all()) if __name__ == "__main__": From 9184f9a48e29cbf9993a51a6693b0b50a1fe759d Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Sat, 29 Nov 2025 02:24:06 +0100 Subject: [PATCH 25/29] docs correctlion --- python/src/ops.cpp | 19 ++++++++----------- python/tests/test_quantized.py | 16 ++++++++-------- 2 files changed, 16 insertions(+), 19 deletions(-) diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 319d55c4b6..c95535433e 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -5441,33 +5441,30 @@ void init_ops(nb::module_& m) { "w"_a = nb::none(), // bf16 weights "group_size"_a = nb::none(), "bits"_a = nb::none(), - "mode"_a = "affine", + "mode"_a = "nvfp4", nb::kw_only(), "stream"_a = nb::none(), nb::sig( "def qqmm(x: array, w_q: array, /, scales: array, w: Optional[array] = None, transpose: bool = True, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'nvfp4', *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - Perform the matrix multiplication with the quantized matrix ``w_q`` and x that is - quantized on-the-fly using provided group size, bits and mode which must be the same - as used to quantize ``w_q``. ``w`` must be provided during training for correct - gradient computation, but optional for the inference. + Perform the matrix multiplication with the quantized matrix ``w_q`` and ``x`` that is + quantized on-the-fly using provided group size, bits and mode. Group size, bits and mode + must match those used to quantize ``w_q``. High precision ``w`` must be provided + for gradient computation and should match ``w_q`` before quantization. Args: x (array): Input array w_q (array): Quantized matrix packed in unsigned integers scales (array): The scales to use per ``group_size`` elements of ``w_q`` w (array, optional): bf16 or float32 weights used during training for - correct gradient computation. Default: ``None``. - transpose (bool, optional): Defines whether to multiply with the - transposed ``w_q`` or not, namely whether we are performing - ``x @ w_q.T`` or ``x @ w_q``. Default: ``True``. + gradient computation. Must be provided for vjp. Default: ``None``. group_size (int, optional): The size of the group in ``w_q`` that shares a - scale and bias. See supported values and defaults in the + scale. See supported values and defaults in the :ref:`table of quantization modes `. Default: ``None``. bits (int, optional): The number of bits occupied by each element of ``w_q`` in the quantized array. 
See supported values and defaults in the :ref:`table of quantization modes `. Default: ``None``. - mode (str, optional): The quantization mode. Default: ``"affine"``. + mode (str, optional): The quantization mode. Default: ``"nvfp4"``. Returns: array: The result of the multiplication of quantized ``x`` with ``w_q``. )pbdoc"); diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index 327d17ccde..225167b87b 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -984,14 +984,14 @@ def gmm(s, x, wq): ds = mx.grad(gmm)(s, x, wq) def test_qqmm(self): - # for mxfp8 mode the results does not match exactly - # for less then 1 percent of elements in the output - # this is not systematic error - # the error can be larger than 1 ULP for very small elements - # and always less than 1 ULP for large elements - # for nvfp4 results match precisely - # therefore I suspect that potential cause of the difference - # is in the implementation of mxfp8 matmul in cuBLASlt + # In mxfp8 mode, the results do not match exactly: + # fewer than 1% of output elements differ. + # This does not appear to be a systematic error. + # The error can exceed 1 ULP for very small values, + # and is always below 1 ULP for larger values. + # For nvfp4, the results match exactly. + # therefore I suspect that the discrepancy comes from + # the mxfp8 matmul implementation in cuBLASLt.. key = mx.random.key(0) k1, k2 = mx.random.split(key) dtype = mx.bfloat16 From 110848fdf3f3abf720d41ba01b9b3c9b84833966 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Sat, 29 Nov 2025 19:52:52 +0100 Subject: [PATCH 26/29] small fixes --- mlx/ops.cpp | 2 +- mlx/primitives.cpp | 1 - python/src/ops.cpp | 11 ++++++----- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 213d1220ce..38c88c4b1f 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -163,7 +163,6 @@ std::pair extract_qqmm_dims( int group_size, int bits) { // Validate w_q and scales_w - // https://docs.nvidia.com/cuda/cublas/#d-block-scaling-for-fp8-and-fp4-data-types validate_quantized_input( tag, w_q, scales_w, "weight matrix", "scales_w", group_size, bits); @@ -4208,6 +4207,7 @@ array quantized_matmul( } else { inputs = {x, w, scales}; } + if (x.ndim() > 2 && w.ndim() > 2) { inputs = broadcast_arrays(inputs, {-2, -1}, s); } diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index b419c050f8..122e57c37b 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -10,7 +10,6 @@ #include #include -#include #include "mlx/backend/common/utils.h" #include "mlx/fft.h" #include "mlx/linalg.h" diff --git a/python/src/ops.cpp b/python/src/ops.cpp index c95535433e..a157647263 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -5447,17 +5447,18 @@ void init_ops(nb::module_& m) { nb::sig( "def qqmm(x: array, w_q: array, /, scales: array, w: Optional[array] = None, transpose: bool = True, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'nvfp4', *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - Perform the matrix multiplication with the quantized matrix ``w_q`` and ``x`` that is - quantized on-the-fly using provided group size, bits and mode. Group size, bits and mode - must match those used to quantize ``w_q``. High precision ``w`` must be provided - for gradient computation and should match ``w_q`` before quantization. 
+ Perform matrix multiplication using the quantized weight matrix ``w_q`` and the input ``x``, + which is quantized on the fly using the provided group size, bit width, and mode. + The group size, bit width, and mode must match those used to quantize ``w_q``. + The high-precision weight matrix ``w`` must be provided for gradient computation + and must match ``w_q`` before quantization. Args: x (array): Input array w_q (array): Quantized matrix packed in unsigned integers scales (array): The scales to use per ``group_size`` elements of ``w_q`` w (array, optional): bf16 or float32 weights used during training for - gradient computation. Must be provided for vjp. Default: ``None``. + gradient computation. Default: ``None``. group_size (int, optional): The size of the group in ``w_q`` that shares a scale. See supported values and defaults in the :ref:`table of quantization modes `. Default: ``None``. From 6633c4bbeeef9b74a2908e0a0588992213428f90 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Sat, 29 Nov 2025 20:08:20 +0100 Subject: [PATCH 27/29] deleted qqlinear for now --- examples/cpp/qqmm.cu | 49 ----------- mlx/backend/cuda/cublas_utils.h | 4 - mlx/backend/cuda/gemms/cublas_gemm.h | 5 -- mlx/backend/cuda/quantized/cublas_qqmm.h | 26 ------ mlx/backend/cuda/quantized/quantized.cpp | 4 +- mlx/ops.cpp | 3 +- python/mlx/nn/layers/__init__.py | 7 +- python/mlx/nn/layers/quantized.py | 103 ----------------------- python/tests/test_nn.py | 12 --- 9 files changed, 4 insertions(+), 209 deletions(-) delete mode 100644 examples/cpp/qqmm.cu diff --git a/examples/cpp/qqmm.cu b/examples/cpp/qqmm.cu deleted file mode 100644 index 7598b9eb13..0000000000 --- a/examples/cpp/qqmm.cu +++ /dev/null @@ -1,49 +0,0 @@ -#include -#include "mlx/backend/cuda/device.h" -#include "mlx/backend/cuda/quantized/cublas_qqmm.h" -#include "mlx/mlx.h" -#include "mlx/stream.h" - -namespace mx = mlx::core; - -int main() { - int group_size = 16; - int bits = 4; - int M = 128; - int N = 128; - int K = 256; - std::string quantization_mode = "nvfp4"; - - mx::Device device(mx::Device::gpu, 0); - auto s = mx::default_stream(device); - auto& encoder = mx::cu::get_command_encoder(s); - - mx::array a = mx::random::uniform({M, K}, mx::bfloat16); // (M, K) - mx::array b = mx::random::uniform({N, K}, mx::bfloat16); // (N, K) - - auto scaled_b = mx::quantize(b, group_size, bits, quantization_mode); - mx::array b_quantized = scaled_b[0]; - mx::array b_scale = scaled_b[1]; - - mx::array out = mx::qqmm( - a, - b_quantized, - b_scale, - true, - group_size, - bits, - quantization_mode); - - auto aq = mx::quantize(a, group_size, bits, quantization_mode); - mx::array a_dequantized = - mx::dequantize(aq[0], aq[1], {}, group_size, bits, quantization_mode); - mx::array b_dequantized = - mx::dequantize(b_quantized, b_scale, {}, group_size, bits, quantization_mode); - - mx::array reference_deq = - mx::matmul(a_dequantized, mx::transpose(b_dequantized)); - mx::array isclose = mx::allclose(out, reference_deq, 1e-1f); - - std::cout << isclose << std::endl; - return 0; -} \ No newline at end of file diff --git a/mlx/backend/cuda/cublas_utils.h b/mlx/backend/cuda/cublas_utils.h index ebd422454f..c20e3857dd 100644 --- a/mlx/backend/cuda/cublas_utils.h +++ b/mlx/backend/cuda/cublas_utils.h @@ -29,10 +29,6 @@ class CublasMatmulBase { public: virtual ~CublasMatmulBase(); - cublasLtMatmulDesc_t matmul_desc() const { - return matmul_desc_; - } - void set_bias(cu::CommandEncoder& encoder, const array& bias); protected: diff --git 
a/mlx/backend/cuda/gemms/cublas_gemm.h b/mlx/backend/cuda/gemms/cublas_gemm.h index aca4f54dc0..1fad45ed24 100644 --- a/mlx/backend/cuda/gemms/cublas_gemm.h +++ b/mlx/backend/cuda/gemms/cublas_gemm.h @@ -77,11 +77,6 @@ class CublasGemm : public CublasMatmulBase { float alpha, float beta); - // Get the matmul descriptor for setting attributes like bias - cublasLtMatmulDesc_t matmul_desc() const { - return matmul_desc_; - } - private: void run_batched( cu::CommandEncoder& encoder, diff --git a/mlx/backend/cuda/quantized/cublas_qqmm.h b/mlx/backend/cuda/quantized/cublas_qqmm.h index 5dbaec17fa..8ad9748dcf 100644 --- a/mlx/backend/cuda/quantized/cublas_qqmm.h +++ b/mlx/backend/cuda/quantized/cublas_qqmm.h @@ -52,19 +52,6 @@ class CublasQQMM : public CublasMatmulBase { const array& b_scale, float alpha = 1.0f); - // void run( - // cu::CommandEncoder& encoder, - // array& out, - // const array& a, - // const array& b, - // const array& c, - // const Shape& batch_shape, - // const Strides& a_batch_strides, - // const Strides& b_batch_strides, - // const Strides& c_batch_strides, - // float alpha, - // float beta); - private: void run_batched( cu::CommandEncoder& encoder, @@ -78,19 +65,6 @@ class CublasQQMM : public CublasMatmulBase { const Strides& b_batch_strides, float alpha); - // void run_batched( - // cu::CommandEncoder& encoder, - // array& out, - // const array& a, - // const array& b, - // const array& c, - // const Shape& batch_shape, - // const Strides& a_batch_strides, - // const Strides& b_batch_strides, - // const Strides& c_batch_strides, - // float alpha, - // float beta); - void execute( cu::CommandEncoder& encoder, void* out, diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index 8ad9111882..a4c9acfc9f 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -1,14 +1,14 @@ // Copyright © 2025 Apple Inc. 
#include "mlx/backend/cuda/quantized/quantized.h" -#include -#include "mlx/backend/common/matmul.h" #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/quantized/cublas_qqmm.h" #include "mlx/backend/cuda/quantized/qqmm_utils.h" #include "mlx/backend/gpu/copy.h" #include "mlx/fast_primitives.h" +#include + namespace mlx::core { namespace { diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 38c88c4b1f..bde76ac944 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -4207,7 +4207,7 @@ array quantized_matmul( } else { inputs = {x, w, scales}; } - + if (x.ndim() > 2 && w.ndim() > 2) { inputs = broadcast_arrays(inputs, {-2, -1}, s); } @@ -5979,5 +5979,4 @@ array contiguous( std::make_shared(to_stream(s), allow_col_major), {a}); } - } // namespace mlx::core \ No newline at end of file diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py index c2fba58347..ea2d3029d8 100644 --- a/python/mlx/nn/layers/__init__.py +++ b/python/mlx/nn/layers/__init__.py @@ -87,12 +87,7 @@ MaxPool3d, ) from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding -from mlx.nn.layers.quantized import ( - QQLinear, - QuantizedEmbedding, - QuantizedLinear, - quantize, -) +from mlx.nn.layers.quantized import QuantizedEmbedding, QuantizedLinear, quantize from mlx.nn.layers.recurrent import GRU, LSTM, RNN from mlx.nn.layers.transformer import ( MultiHeadAttention, diff --git a/python/mlx/nn/layers/quantized.py b/python/mlx/nn/layers/quantized.py index 32399896a4..c308e884ba 100644 --- a/python/mlx/nn/layers/quantized.py +++ b/python/mlx/nn/layers/quantized.py @@ -268,106 +268,3 @@ def from_linear( ql.bias = linear_layer.bias return ql - - -class QQLinear(Module): - """Quantizes the input and applies an affine transformation using a quantized weight matrix. - - Currently, bias is not implemented for this layer. - It is equivalent to :class:`mlx.nn.QuantizedLinear` but both input and weight are quantized. - Note: QQLinear supported only `nvfp4` and `mxfp8` quantization modes. - For training, parameter `train` should be set to `True`. - - :obj:`QQLinear` also provides a classmethod :meth:`from_linear` to - convert linear layers to :obj:`QQLinear` layers. - - Args: - input_dims (int): The dimensionality of the input features. - output_dims (int): The dimensionality of the output features. - group_size (int, optional): The group size to use for the quantized - weight. See :func:`~mlx.core.quantize`. Default: ``16``. - bits (int, optional): The bit width to use for the quantized weight. - See :func:`~mlx.core.quantize`. Default: ``4``. - mode (str): The quantization method to use (see - :func:`mlx.core.quantize`). Default: ``"nvfp4"``. 
- """ - - def __init__( - self, - input_dims: int, - output_dims: int, - bias: bool = False, - group_size: int = 16, - bits: int = 4, - mode: str = "nvfp4", - train: bool = False, - ): - super().__init__() - if bias: - raise NotImplementedError("Bias not implemented for QQLinear yet.") - # Quantization config - self.group_size = group_size - self.bits = bits - self.mode = mode - - # Initialize the quantized weight - scale = math.sqrt(1 / input_dims) - weight = mx.random.uniform( - low=-scale, - high=scale, - shape=(output_dims, input_dims), - ) - # if vjp: - if train: - self.weight = weight - self.weight_quantized, self.scales = mx.quantize( - weight, group_size, bits, mode=mode - ) - - def _extra_repr(self): - out_dims, in_dims = self.weight_quantized.shape - in_dims *= 32 // self.bits - return ( - f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, " - f"group_size={self.group_size}, bits={self.bits}, mode={self.mode}" - ) - - def __call__(self, x): - x = mx.qqmm( - x, - self["weight_quantized"], - scales=self["scales"], - w=self.get("weight"), - group_size=self.group_size, - bits=self.bits, - mode=self.mode, - ) - return x - - @classmethod - def from_linear( - cls, - linear_layer: Module, - group_size: int = 16, - bits: int = 4, - mode: str = "nvfp4", - train: bool = False, - ): - """Create a :obj:`QQLinear` layer from a :obj:`Linear` layer.""" - output_dims, input_dims = linear_layer.weight.shape - ql = cls( - input_dims, output_dims, False, group_size, bits, mode=mode, train=train - ) - ql.weight_quantized, ql.scales = mx.quantize( - linear_layer.weight, - group_size, - bits, - mode=mode, - ) - if train: - ql.weight = linear_layer.weight - - if "bias" in linear_layer: - ql.bias = linear_layer.bias - - return ql diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index aea12fa336..6ded372278 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -1948,18 +1948,6 @@ def test_attention(self): out = attn(x, x, x) self.assertEqual(out.shape, x.shape) - def test_qqlinear(self): - model = nn.Sequential( - nn.QQLinear(512, 256, bits=4, group_size=16, mode="nvfp4"), - nn.ReLU(), - nn.QQLinear(256, 128, bits=4, group_size=16, mode="nvfp4"), - nn.ReLU(), - nn.QQLinear(128, 256, bits=4, group_size=16, mode="nvfp4"), - ) - x = mx.random.normal(shape=(128, 512)) - out = model(x) - self.assertEqual(out.shape, (128, 256)) - if __name__ == "__main__": mlx_tests.MLXTestRunner() From a71e43675351af39ee1d6fa0d32f32e5db61aaca Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Sat, 29 Nov 2025 20:10:13 +0100 Subject: [PATCH 28/29] deleted unused header --- mlx/backend/cuda/matmul.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp index 12862f0c39..392a13ad82 100644 --- a/mlx/backend/cuda/matmul.cpp +++ b/mlx/backend/cuda/matmul.cpp @@ -1,7 +1,6 @@ // Copyright © 2025 Apple Inc. 
#include "mlx/backend/common/matmul.h" -#include "mlx/backend/cuda/cublas_utils.h" #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/gemms/cublas_gemm.h" #include "mlx/backend/cuda/gemms/gemv.h" From b36c6d74b76fabcb030fd8ac6429991e8d28b392 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Sat, 29 Nov 2025 23:28:36 +0100 Subject: [PATCH 29/29] delete debuging print --- mlx/primitives.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 140e9a143f..a05143e2e2 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3480,7 +3480,6 @@ std::vector DualQuantizedMatmul::output_shapes( auto out_shape = inputs[0].shape(); int w_outer_dims = inputs[2].shape(-2); out_shape.back() = w_outer_dims; - std::cout << "DualQuantizedMatmul output shape: " << out_shape << std::endl; return {std::move(out_shape)}; }