diff --git a/mlx/backend/cpu/quantized.cpp b/mlx/backend/cpu/quantized.cpp index 75a8e62337..d3964a28b1 100644 --- a/mlx/backend/cpu/quantized.cpp +++ b/mlx/backend/cpu/quantized.cpp @@ -1145,4 +1145,9 @@ void fast::ConvertFP8::eval_cpu( }); } +void DualQuantizedMatmul::eval_cpu( + const std::vector& inputs, + array& out) { + throw std::runtime_error("DualQuantizedMatmul not implemented on CPU."); +} } // namespace mlx::core diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 7986c09d88..5b2f504108 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -18,6 +18,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp ${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_conv.cu ${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_grouped_conv.cu + ${CMAKE_CURRENT_SOURCE_DIR}/cublas_utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp @@ -64,6 +65,11 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unary) # fp4 is not available on < 12.8 if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12.8.0) target_include_directories(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/) +else() + target_sources(mlx + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/qqmm_utils.cu) + target_sources(mlx + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/cublas_qqmm.cpp) endif() if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0) diff --git a/mlx/backend/cuda/cublas_utils.cpp b/mlx/backend/cuda/cublas_utils.cpp new file mode 100644 index 0000000000..9a2717fdeb --- /dev/null +++ b/mlx/backend/cuda/cublas_utils.cpp @@ -0,0 +1,238 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/cublas_utils.h" +#include "mlx/backend/cuda/cuda.h" +#include "mlx/utils.h" + +namespace mlx::core { +namespace cublas_utils { + +namespace { + +struct CublasPreference { + CublasPreference(cu::Device& device) { + // The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB + // for Hopper+: + // https://docs.nvidia.com/cuda/cublas/#cublassetworkspace + uint64_t MiB = 1024 * 1024; + uint64_t workspace_size = + device.compute_capability_major() >= 9 ? 
32 * MiB : 4 * MiB; + + CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_)); + CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute( + pref_, + CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &workspace_size, + sizeof(uint64_t))); + } + + ~CublasPreference() { + CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceDestroy(pref_)); + } + + cublasLtMatmulPreference_t pref_{nullptr}; +}; + +} // namespace + +cublasLtMatmulPreference_t get_preference(cu::Device& device) { + static CublasPreference pref(device); + return pref.pref_; +} + +void* allocate_workspace(cu::CommandEncoder& encoder, size_t workspace_size) { + if (workspace_size == 0) { + return nullptr; + } + + // Ensure workspace is 256-byte aligned + int nbytes = cuda::ceil_div(workspace_size, 256) * 256; + array workspace( + cu::malloc_async(nbytes, encoder), + {static_cast(workspace_size)}, + int8); + encoder.add_temporary(workspace); + return gpu_ptr(workspace); +} + +cublasLtMatrixLayout_t create_matrix_layout( + cudaDataType_t type, + uint64_t rows, + uint64_t cols, + bool transposed, + int64_t ld, + int32_t batch_count, + int64_t batch_stride) { + cublasLtMatrixLayout_t desc; + if (transposed) { + std::swap(rows, cols); + } + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld)); + if (batch_count > 1) { + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( + desc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batch_count, + sizeof(int32_t))); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( + desc, + CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, + &batch_stride, + sizeof(int64_t))); + } + return desc; +} + +} // namespace cublas_utils + +CublasMatmulBase::~CublasMatmulBase() { + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_)); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_)); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_)); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_)); + CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_)); +} + +void CublasMatmulBase::init_base( + cu::Device& device, + cudaDataType_t scale_type, + cublasComputeType_t compute_type, + cudaDataType_t data_type, + cudaDataType_t output_type, + bool a_transposed, + uint64_t a_rows, + uint64_t a_cols, + int64_t lda, + bool b_transposed, + uint64_t b_rows, + uint64_t b_cols, + int64_t ldb, + int32_t batch_count, + int64_t a_batch_stride, + int64_t b_batch_stride) { + M_ = a_rows; + N_ = b_cols; + scale_type_ = scale_type; + handle_ = device.lt_handle(); + pref_ = cublas_utils::get_preference(device); + heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED; + + CHECK_CUBLAS_ERROR( + cublasLtMatmulDescCreate(&matmul_desc_, compute_type, scale_type)); + + int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST; + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + matmul_desc_, + CUBLASLT_MATMUL_DESC_POINTER_MODE, + &pointer_mode, + sizeof(int32_t))); + + // In cublasLt matrices use column-major layout, while it is possible to use + // the CUBLASLT_ORDER_ROW option to switch to row-major layout, the bias + // epilogue does not work with the option. So instead we swap A and B to make + // cublasLt return the row-major result, which works because: + // - the data of a matrix in row-major layout is identical to its transpose in + // column-major layout + // - C^T = (A @ B)^T = B^T @ A^T + cublasOperation_t a_op = b_transposed ? 
CUBLAS_OP_T : CUBLAS_OP_N; + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + matmul_desc_, + CUBLASLT_MATMUL_DESC_TRANSA, + &a_op, + sizeof(cublasOperation_t))); + cublasOperation_t b_op = a_transposed ? CUBLAS_OP_T : CUBLAS_OP_N; + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + matmul_desc_, + CUBLASLT_MATMUL_DESC_TRANSB, + &b_op, + sizeof(cublasOperation_t))); + + a_desc_ = cublas_utils::create_matrix_layout( + data_type, + b_cols, + b_rows, + b_transposed, + ldb, + batch_count, + b_batch_stride); + b_desc_ = cublas_utils::create_matrix_layout( + data_type, + a_cols, + a_rows, + a_transposed, + lda, + batch_count, + a_batch_stride); + out_desc_ = cublas_utils::create_matrix_layout( + output_type, b_cols, a_rows, false, b_cols, batch_count, b_cols * a_rows); +} + +void CublasMatmulBase::execute_matmul( + cu::CommandEncoder& encoder, + void* out, + const void* a, + const void* b, + const void* c, + const void* alpha_ptr, + const void* beta_ptr) { + if (heuristic_.state != CUBLAS_STATUS_SUCCESS) { + int ret = 0; + CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic( + handle_, + matmul_desc_, + a_desc_, + b_desc_, + c ? c_desc_ : out_desc_, + out_desc_, + pref_, + 1, + &heuristic_, + &ret)); + if (ret == 0) { + throw std::runtime_error("Can not find algorithm for matmul."); + } + } + + void* workspace_ptr = + cublas_utils::allocate_workspace(encoder, heuristic_.workspaceSize); + + // Execute matmul + auto capture = encoder.capture_context(); + CHECK_CUBLAS_ERROR(cublasLtMatmul( + handle_, + matmul_desc_, + alpha_ptr, + b, // a and b are swapped for row-major layout + a_desc_, + a, + b_desc_, + beta_ptr, + c ? c : out, + c ? c_desc_ : out_desc_, + out, + out_desc_, + &heuristic_.algo, + workspace_ptr, + heuristic_.workspaceSize, + encoder.stream())); +} + +void CublasMatmulBase::set_bias( + cu::CommandEncoder& encoder, + const array& bias) { + encoder.set_input_array(bias); + cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS; + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + matmul_desc_, + CUBLASLT_MATMUL_DESC_EPILOGUE, + &epilogue, + sizeof(epilogue))); + auto* bias_ptr = gpu_ptr(bias); + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + matmul_desc_, + CUBLASLT_MATMUL_DESC_BIAS_POINTER, + &bias_ptr, + sizeof(bias_ptr))); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/cublas_utils.h b/mlx/backend/cuda/cublas_utils.h new file mode 100644 index 0000000000..c20e3857dd --- /dev/null +++ b/mlx/backend/cuda/cublas_utils.h @@ -0,0 +1,78 @@ +// Copyright © 2025 Apple Inc. 
+#pragma once + +#include "mlx/array.h" +#include "mlx/backend/cuda/device.h" + +#include + +namespace mlx::core { +namespace cublas_utils { + +// Get the shared cublas preference for a device +cublasLtMatmulPreference_t get_preference(cu::Device& device); + +void* allocate_workspace(cu::CommandEncoder& encoder, size_t workspace_size); + +cublasLtMatrixLayout_t create_matrix_layout( + cudaDataType_t type, + uint64_t rows, + uint64_t cols, + bool transposed, + int64_t ld, + int32_t batch_count, + int64_t batch_stride); + +} // namespace cublas_utils + +class CublasMatmulBase { + public: + virtual ~CublasMatmulBase(); + + void set_bias(cu::CommandEncoder& encoder, const array& bias); + + protected: + CublasMatmulBase() = default; + + // Common member variables shared by all matmul types + uint64_t M_; + uint64_t N_; + cudaDataType_t scale_type_; + cublasLtMatmulPreference_t pref_{nullptr}; + cublasLtHandle_t handle_{nullptr}; + cublasLtMatmulDesc_t matmul_desc_{nullptr}; + cublasLtMatrixLayout_t a_desc_{nullptr}; + cublasLtMatrixLayout_t b_desc_{nullptr}; + cublasLtMatrixLayout_t c_desc_{nullptr}; + cublasLtMatrixLayout_t out_desc_{nullptr}; + cublasLtMatmulHeuristicResult_t heuristic_; + + void init_base( + cu::Device& device, + cudaDataType_t scale_type, + cublasComputeType_t compute_type, + cudaDataType_t data_type, + cudaDataType_t output_type, + bool a_transposed, + uint64_t a_rows, + uint64_t a_cols, + int64_t lda, + bool b_transposed, + uint64_t b_rows, + uint64_t b_cols, + int64_t ldb, + int32_t batch_count, + int64_t a_batch_stride, + int64_t b_batch_stride); + + void execute_matmul( + cu::CommandEncoder& encoder, + void* out, + const void* a, + const void* b, + const void* c, + const void* alpha_ptr, + const void* beta_ptr); +}; + +} // namespace mlx::core diff --git a/mlx/backend/cuda/gemms/cublas_gemm.cpp b/mlx/backend/cuda/gemms/cublas_gemm.cpp index 0d4e25f5ae..05d1fecbc2 100644 --- a/mlx/backend/cuda/gemms/cublas_gemm.cpp +++ b/mlx/backend/cuda/gemms/cublas_gemm.cpp @@ -1,6 +1,7 @@ // Copyright © 2025 Apple Inc. #include "mlx/backend/cuda/gemms/cublas_gemm.h" +#include "mlx/backend/cuda/cublas_utils.h" #include "mlx/backend/cuda/device.h" #include "mlx/dtype_utils.h" #include "mlx/utils.h" @@ -11,35 +12,6 @@ namespace mlx::core { namespace { -struct CublasPreference { - CublasPreference(cu::Device& device) { - // The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB - // for Hopper+: - // https://docs.nvidia.com/cuda/cublas/#cublassetworkspace - uint64_t MiB = 1024 * 1024; - uint64_t workspace_size = - device.compute_capability_major() >= 9 ? 
32 * MiB : 4 * MiB; - - CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_)); - CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute( - pref_, - CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, - &workspace_size, - sizeof(uint64_t))); - } - - ~CublasPreference() { - CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceDestroy(pref_)); - } - - cublasLtMatmulPreference_t pref_{nullptr}; -}; - -cublasLtMatmulPreference_t cublas_preference(cu::Device& device) { - static CublasPreference pref(device); - return pref.pref_; -} - cublasComputeType_t dtype_to_compute_type(Dtype dtype) { switch (dtype) { case float16: @@ -78,34 +50,6 @@ cudaDataType_t dtype_to_cublas_type(Dtype dtype) { } } -cublasLtMatrixLayout_t create_matrix_layout( - cudaDataType_t type, - uint64_t rows, - uint64_t cols, - bool transposed, - int64_t ld, - int32_t batch_count, - int64_t batch_stride) { - cublasLtMatrixLayout_t desc; - if (transposed) { - std::swap(rows, cols); - } - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld)); - if (batch_count > 1) { - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( - desc, - CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, - &batch_count, - sizeof(int32_t))); - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( - desc, - CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, - &batch_stride, - sizeof(int64_t))); - } - return desc; -} - } // namespace CublasGemm::CublasGemm( @@ -121,54 +65,28 @@ CublasGemm::CublasGemm( int64_t ldb, int32_t batch_count, int64_t a_batch_stride, - int64_t b_batch_stride) - : handle_(device.lt_handle()), - pref_(cublas_preference(device)), - M_(a_rows), - N_(b_cols) { - heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED; - + int64_t b_batch_stride) { scale_type_ = dtype_to_cublas_type(dtype); if (dtype == bfloat16 || dtype == float16) { scale_type_ = CUDA_R_32F; } - - CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate( - &matmul_desc_, dtype_to_compute_type(dtype), scale_type_)); - int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST; - CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( - matmul_desc_, - CUBLASLT_MATMUL_DESC_POINTER_MODE, - &pointer_mode, - sizeof(int32_t))); - - // In cublasLt matrices use column-major layout, while it is possible to use - // the CUBLASLT_ORDER_ROW option to switch to row-major layout, the bias - // epilogue does not work with the option. So instead we swap A and B to make - // cublasLt return the row-major result, which works because: - // - the data of a matrix in row-major layout is identical to its transpose in - // column-major layout - // - C^T = (A @ B)^T = B^T @ A^T - cublasOperation_t a_op = b_transposed ? CUBLAS_OP_T : CUBLAS_OP_N; - CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( - matmul_desc_, - CUBLASLT_MATMUL_DESC_TRANSA, - &a_op, - sizeof(cublasOperation_t))); - cublasOperation_t b_op = a_transposed ? 
CUBLAS_OP_T : CUBLAS_OP_N; - CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( - matmul_desc_, - CUBLASLT_MATMUL_DESC_TRANSB, - &b_op, - sizeof(cublasOperation_t))); - - auto type = dtype_to_cublas_type(dtype); - a_desc_ = create_matrix_layout( - type, b_cols, b_rows, b_transposed, ldb, batch_count, b_batch_stride); - b_desc_ = create_matrix_layout( - type, a_cols, a_rows, a_transposed, lda, batch_count, a_batch_stride); - out_desc_ = create_matrix_layout( - type, b_cols, a_rows, false, b_cols, batch_count, a_rows * b_cols); + init_base( + device, + scale_type_, + dtype_to_compute_type(dtype), + dtype_to_cublas_type(dtype), + dtype_to_cublas_type(dtype), + a_transposed, + a_rows, + a_cols, + lda, + b_transposed, + b_rows, + b_cols, + ldb, + batch_count, + a_batch_stride, + b_batch_stride); } CublasGemm::CublasGemm( @@ -202,18 +120,10 @@ CublasGemm::CublasGemm( a_batch_stride, b_batch_stride) { auto type = dtype_to_cublas_type(dtype); - c_desc_ = create_matrix_layout( + c_desc_ = cublas_utils::create_matrix_layout( type, b_cols, a_rows, false, ldc, batch_count, c_batch_stride); } -CublasGemm::~CublasGemm() { - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_)); - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_)); - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_)); - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_)); - CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_)); -} - void CublasGemm::set_out( Dtype dtype, bool transposed, @@ -223,7 +133,7 @@ void CublasGemm::set_out( int32_t batch_count, int64_t batch_stride) { CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_)); - out_desc_ = create_matrix_layout( + out_desc_ = cublas_utils::create_matrix_layout( dtype_to_cublas_type(dtype), cols, rows, @@ -233,22 +143,6 @@ void CublasGemm::set_out( batch_stride); } -void CublasGemm::set_bias(cu::CommandEncoder& encoder, const array& bias) { - encoder.set_input_array(bias); - cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS; - CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( - matmul_desc_, - CUBLASLT_MATMUL_DESC_EPILOGUE, - &epilogue, - sizeof(epilogue))); - auto* bias_ptr = gpu_ptr(bias); - CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( - matmul_desc_, - CUBLASLT_MATMUL_DESC_BIAS_POINTER, - &bias_ptr, - sizeof(bias_ptr))); -} - void CublasGemm::run( cu::CommandEncoder& encoder, array& out, @@ -337,24 +231,6 @@ void CublasGemm::execute( const void* c, float alpha /* = 1 */, float beta /* = 0 */) { - if (heuristic_.state != CUBLAS_STATUS_SUCCESS) { - int ret = 0; - CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic( - handle_, - matmul_desc_, - a_desc_, - b_desc_, - c ? c_desc_ : out_desc_, - out_desc_, - pref_, - 1, - &heuristic_, - &ret)); - if (ret == 0) { - throw std::runtime_error("Can not find algorithm for matmul."); - } - } - const void* alpha_ptr = α const void* beta_ptr = β complex64_t alpha_c, beta_c; @@ -365,36 +241,7 @@ void CublasGemm::execute( beta_ptr = &beta_c; } - void* workspace_ptr = nullptr; - if (heuristic_.workspaceSize > 0) { - // Ensure workspace is 256-byte aligned - int nbytes = cuda::ceil_div(heuristic_.workspaceSize, 256) * 256; - array workspace( - cu::malloc_async(nbytes, encoder), - {static_cast(heuristic_.workspaceSize)}, - int8); - encoder.add_temporary(workspace); - workspace_ptr = gpu_ptr(workspace); - } - - auto capture = encoder.capture_context(); - CHECK_CUBLAS_ERROR(cublasLtMatmul( - handle_, - matmul_desc_, - alpha_ptr, - b, // a and b are swapped - a_desc_, - a, - b_desc_, - beta_ptr, - c ? 
c : out, - c ? c_desc_ : out_desc_, - out, - out_desc_, - &heuristic_.algo, - workspace_ptr, - heuristic_.workspaceSize, - encoder.stream())); + execute_matmul(encoder, out, a, b, c, alpha_ptr, beta_ptr); } } // namespace mlx::core diff --git a/mlx/backend/cuda/gemms/cublas_gemm.h b/mlx/backend/cuda/gemms/cublas_gemm.h index d6d2189b95..1fad45ed24 100644 --- a/mlx/backend/cuda/gemms/cublas_gemm.h +++ b/mlx/backend/cuda/gemms/cublas_gemm.h @@ -2,13 +2,14 @@ #pragma once #include "mlx/array.h" +#include "mlx/backend/cuda/cublas_utils.h" #include "mlx/backend/cuda/device.h" #include namespace mlx::core { -class CublasGemm { +class CublasGemm : public CublasMatmulBase { public: CublasGemm( cu::Device& device, @@ -42,8 +43,6 @@ class CublasGemm { int64_t b_batch_stride, int64_t c_batch_stride); - ~CublasGemm(); - // The output's descriptor is inferred from inputs by default, use this method // for unusual output. void set_out( @@ -55,8 +54,6 @@ class CublasGemm { int32_t batch_count, int64_t batch_stride); - void set_bias(cu::CommandEncoder& encoder, const array& bias); - void run( cu::CommandEncoder& encoder, array& out, @@ -112,18 +109,6 @@ class CublasGemm { const void* c, float alpha = 1, float beta = 0); - - uint64_t M_; - uint64_t N_; - cudaDataType_t scale_type_; - cublasLtMatmulPreference_t pref_{nullptr}; - cublasLtHandle_t handle_{nullptr}; - cublasLtMatmulDesc_t matmul_desc_{nullptr}; - cublasLtMatrixLayout_t a_desc_{nullptr}; - cublasLtMatrixLayout_t b_desc_{nullptr}; - cublasLtMatrixLayout_t c_desc_{nullptr}; - cublasLtMatrixLayout_t out_desc_{nullptr}; - cublasLtMatmulHeuristicResult_t heuristic_; }; } // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/cublas_qqmm.cpp b/mlx/backend/cuda/quantized/cublas_qqmm.cpp new file mode 100644 index 0000000000..1ef6d7bc67 --- /dev/null +++ b/mlx/backend/cuda/quantized/cublas_qqmm.cpp @@ -0,0 +1,200 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/cuda/quantized/cublas_qqmm.h" + +#include +#include "mlx/backend/cuda/cublas_utils.h" + +#include "mlx/backend/cuda/device.h" +#include "mlx/dtype_utils.h" +#include "mlx/utils.h" + +namespace mlx::core { + +namespace { + +// Currently cublas supports only mxfp8 and nvfp4 +// quantization modes for block scaled quantization +cudaDataType_t qmode_to_cublas_scale_dtype(std::string mode) { + if (mode == "mxfp8") { + return CUDA_R_8F_UE8M0; + } else if (mode == "nvfp4") { + return CUDA_R_8F_UE4M3; + } else { + throw std::runtime_error( + fmt::format("Unsupported quantization mode in CublasQQMM: {}.", mode)); + } +} + +cudaDataType_t qmode_to_cublas_dtype(std::string mode) { + if (mode == "mxfp8") { + return CUDA_R_8F_E4M3; + } else if (mode == "nvfp4") { + return CUDA_R_4F_E2M1; + } else { + throw std::runtime_error( + fmt::format("Unsupported quantization mode in CublasQQMM: {}.", mode)); + } +} + +cublasLtMatmulMatrixScale_t qmode_to_cublas_scale_mode(std::string mode) { + if (mode == "mxfp8") { + return CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0; + } else if (mode == "nvfp4") { + return CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3; + } else { + throw std::runtime_error( + fmt::format("Unsupported quantization mode in CublasQQMM: {}.", mode)); + } +} + +} // namespace + +CublasQQMM::CublasQQMM( + cu::Device& device, + bool a_transposed, + uint64_t a_rows, + uint64_t a_cols, + int64_t lda, + bool b_transposed, + uint64_t b_rows, + uint64_t b_cols, + int64_t ldb, + int32_t batch_count, + int64_t a_batch_stride, + int64_t b_batch_stride, + std::string qmode) { + cudaDataType_t scale_type = CUDA_R_32F; + cudaDataType_t output_type = CUDA_R_16BF; // always output in bf16 + cublasComputeType_t gemm_compute_type = + CUBLAS_COMPUTE_32F; // always for narrow precision + cudaDataType_t data_type = qmode_to_cublas_dtype(qmode); + quantization_mode_ = std::string(qmode); + + init_base( + device, + scale_type, + gemm_compute_type, + data_type, + output_type, + a_transposed, + a_rows, + a_cols, + lda, + b_transposed, + b_rows, + b_cols, + ldb, + batch_count, + a_batch_stride, + b_batch_stride); + + a_scale_mode_ = qmode_to_cublas_scale_mode(qmode); + b_scale_mode_ = qmode_to_cublas_scale_mode(qmode); + + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + matmul_desc_, + CUBLASLT_MATMUL_DESC_B_SCALE_MODE, + &a_scale_mode_, + sizeof(a_scale_mode_))); + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + matmul_desc_, + CUBLASLT_MATMUL_DESC_A_SCALE_MODE, + &b_scale_mode_, + sizeof(b_scale_mode_))); +} + +CublasQQMM::CublasQQMM( + cu::Device& device, + bool a_transposed, + uint64_t a_rows, + uint64_t a_cols, + int64_t lda, + bool b_transposed, + uint64_t b_rows, + uint64_t b_cols, + int64_t ldb, + int64_t ldc, + int32_t batch_count, + int64_t a_batch_stride, + int64_t b_batch_stride, + int64_t c_batch_stride, + std::string qmode) + : CublasQQMM( + device, + a_transposed, + a_rows, + a_cols, + lda, + b_transposed, + b_rows, + b_cols, + ldb, + batch_count, + a_batch_stride, + b_batch_stride, + qmode) { + auto type = CUDA_R_16BF; // always c in bf16 + c_desc_ = cublas_utils::create_matrix_layout( + type, + b_transposed ? b_rows : b_cols, + a_transposed ? 
a_cols : a_rows, + false, + ldc, + batch_count, + c_batch_stride); +} + +void CublasQQMM::run( + cu::CommandEncoder& encoder, + array& out, + const array& a, + const array& b, + const array& a_scale, + const array& b_scale, + float alpha) { + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_input_array(a_scale); + encoder.set_input_array(b_scale); + encoder.set_output_array(out); + + execute( + encoder, + gpu_ptr(out), + gpu_ptr(a), + gpu_ptr(b), + gpu_ptr(a_scale), + gpu_ptr(b_scale), + nullptr, + alpha); +} + +void CublasQQMM::execute( + cu::CommandEncoder& encoder, + void* out, + const void* a, + const void* b, + const void* a_scale, + const void* b_scale, + const void* c, + float alpha /* = 1 */, + float beta /* = 0 */) { + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + matmul_desc_, + CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, + &b_scale, + sizeof(b_scale))); + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + matmul_desc_, + CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, + &a_scale, + sizeof(a_scale))); + + const void* alpha_ptr = α + const void* beta_ptr = β + + execute_matmul(encoder, out, a, b, c, alpha_ptr, beta_ptr); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/cublas_qqmm.h b/mlx/backend/cuda/quantized/cublas_qqmm.h new file mode 100644 index 0000000000..8ad9748dcf --- /dev/null +++ b/mlx/backend/cuda/quantized/cublas_qqmm.h @@ -0,0 +1,86 @@ +// Copyright © 2025 Apple Inc. +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/cuda/cublas_utils.h" +#include "mlx/backend/cuda/device.h" + +#include + +namespace mlx::core { + +class CublasQQMM : public CublasMatmulBase { + public: + CublasQQMM( + cu::Device& device, + bool a_transposed, + uint64_t a_rows, + uint64_t a_cols, + int64_t lda, + bool b_transposed, + uint64_t b_rows, + uint64_t b_cols, + int64_t ldb, + int32_t batch_count, + int64_t a_batch_stride, + int64_t b_batch_stride, + std::string quantization_mode); + + CublasQQMM( + cu::Device& device, + bool a_transposed, + uint64_t a_rows, + uint64_t a_cols, + int64_t lda, + bool b_transposed, + uint64_t b_rows, + uint64_t b_cols, + int64_t ldb, + int64_t ldc, + int32_t batch_count, + int64_t a_batch_stride, + int64_t b_batch_stride, + int64_t c_batch_stride, + std::string quantization_mode); + + void run( + cu::CommandEncoder& encoder, + array& out, + const array& a, + const array& b, + const array& a_scale, + const array& b_scale, + float alpha = 1.0f); + + private: + void run_batched( + cu::CommandEncoder& encoder, + array& out, + const array& a, + const array& b, + const array& a_scale, + const array& b_scale, + const Shape& batch_shape, + const Strides& a_batch_strides, + const Strides& b_batch_strides, + float alpha); + + void execute( + cu::CommandEncoder& encoder, + void* out, + const void* a, + const void* b, + const void* a_scale, + const void* b_scale, + const void* c, + float alpha = 1, + float beta = 0); + + std::string quantization_mode_; + cublasLtMatmulMatrixScale_t a_scale_mode_; + cublasLtMatmulMatrixScale_t b_scale_mode_; + cublasLtMatmulMatrixScale_t c_scale_mode_; + cublasLtMatmulMatrixScale_t out_scale_mode_; +}; + +} // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/qqmm_utils.cu b/mlx/backend/cuda/quantized/qqmm_utils.cu new file mode 100644 index 0000000000..ff19057b08 --- /dev/null +++ b/mlx/backend/cuda/quantized/qqmm_utils.cu @@ -0,0 +1,169 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/cuda/quantized/qqmm_utils.h" + +#include + +namespace mlx::core { + +namespace cg = cooperative_groups; + +// To pass scales to tensor cores, they need to be repacked into a tiled layout +// https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout +// Tiled layout for scale factors is very well described in CUTLASS +// documentation: +// https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/blackwell_functionality.md#scale-factor-layouts +// Conceptually, it should be like this: +// q_w = mx.zeros(shape=(M, N)) <-- zeros just for an example +// s.shape = (M, N // 16) -- packed in row contigous order, group_size = 16 +// cbg_cnt = N // 16 // 4 +// rb_cnt = M // 128 +// tmp = x.reshape(rb_cnt, 4, 32, cbg_cnt, 4) +// repacked_scales = tmp.transpose(0, 3, 2, 1, 4) +// example: indecis of intial tile 128 x 4 of scales (packed in row major tensor +// (M, K // 16), where M = 128, K = 64): array([[0, 1, 2, 3], +// [4, 5, 6, 7], +// [8, 9, 10, 11], +// ..., +// [500, 501, 502, 503], +// [504, 505, 506, 507], +// [508, 509, 510, 511]] +// packed scales within tile 128 x 4: +// array([[[[[0, 1, 2, 3], <-- s_0,0..s_0,3 scales +// [128, 129, 130, 131], <-- s_32,0..s_32,3 scales +// [256, 257, 258, 259], <-- s_64,0..s_64,3 scales +// [384, 385, 386, 387]], <-- s_96,0..s_96,3 scales +// [[4, 5, 6, 7], <-- s_1,0..s_1,3 scales +// [132, 133, 134, 135], ... +// [260, 261, 262, 263], +// [388, 389, 390, 391]], +// [[124, 125, 126, 127], +// [252, 253, 254, 255], +// [380, 381, 382, 383], +// [508, 509, 510, 511]]]]], +__device__ size_t +scale_tiled_offset(size_t scale_index, size_t num_rows, size_t num_scale_cols) { + // Compute the tiled layout offset for scale factors used in tensor cores + // This function maps from a linear scale index to the tiled layout expected + // by tensor cores (and cublaslt). + // + // Input: linear scale index (e.g., for a matrix M x K with group_size, + // scale_index ranges from 0 to (M * K/group_size - 1)) + // + // The tiled layout organizes scales into tiles of 128 rows x 4 columns, + // where each tile is subdivided into 4 sub-blocks of 32 rows x 4 columns. 
+ size_t row = scale_index / num_scale_cols; + size_t col = scale_index % num_scale_cols; + + constexpr size_t rows_per_tile = 128; + constexpr size_t rows_per_sub_block = 32; + constexpr size_t cols_per_sub_block = 4; + constexpr size_t sub_blocks_per_tile = 4; // Vertically stacked + + // Decompose row position + size_t tile_row = row / rows_per_tile; // Which tile row + size_t row_in_tile = row % rows_per_tile; // Row within tile + size_t sub_block_row = + row_in_tile / rows_per_sub_block; // Sub-block within tile + size_t row_in_sub_block = + row_in_tile % rows_per_sub_block; // Row in sub-block + + // Decompose column position + size_t col_tile = col / cols_per_sub_block; // Which column tile + size_t col_in_sub_block = col % cols_per_sub_block; // Column within sub-block + + // Compute tile index and offset within tile + size_t num_col_tiles = cuda::ceil_div(num_scale_cols, cols_per_sub_block); + size_t tile_idx = tile_row * num_col_tiles + col_tile; + + size_t offset_in_tile = + (row_in_sub_block * sub_blocks_per_tile * cols_per_sub_block) + + (sub_block_row * cols_per_sub_block) + col_in_sub_block; + + constexpr size_t tile_size = rows_per_tile * cols_per_sub_block; + return tile_idx * tile_size + offset_in_tile; +} + +namespace cu { + +__global__ void repack_scales( + const uint8_t* scales_linear, + uint8_t* scales_tiled, + size_t input_rows, + size_t input_cols, + size_t output_rows, + size_t output_cols) { + auto block_size = cg::this_thread_block().dim_threads(); + auto block_idx = cg::this_thread_block().group_index(); + auto idx_in_block = cg::this_thread_block().thread_index(); + + auto tidx = block_idx.x * block_size.x + idx_in_block.x; + auto tidy = block_idx.y * block_size.y + idx_in_block.y; + + auto grid_dim_x = + cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x; + + size_t output_index = tidx + grid_dim_x * size_t(tidy); + size_t output_size = output_rows * output_cols; + + if (output_index >= output_size) { + return; + } + + size_t tiled_offset = + scale_tiled_offset(output_index, output_rows, output_cols); + + size_t row = output_index / output_cols; + size_t col = output_index % output_cols; + + // Probably this can be done better with 2 separated paths for valid and + // padding + if (row < input_rows && col < input_cols) { + size_t input_index = row * input_cols + col; + scales_tiled[tiled_offset] = scales_linear[input_index]; + } else { + // Zero-fill padding region + scales_tiled[tiled_offset] = 0; + } +} + +} // namespace cu + +void repack_scales( + const array& scales, + array& scales_tiled, + cu::CommandEncoder& enc, + const Stream& s) { + enc.set_input_array(scales); + enc.set_output_array(scales_tiled); + + // Note: scales_tiled is padded to full tiles so if num_rows or num_cols + // are not multiples of tile sizes, the extra space is filled with zeros + + size_t input_rows = scales.shape(-2); + size_t input_cols = scales.shape(-1); + + size_t output_rows = scales_tiled.shape(-2); + size_t output_cols = scales_tiled.shape(-1); + size_t output_size = output_rows * output_cols; + + bool large = output_size > UINT_MAX; + auto [num_blocks, block_dims] = get_launch_args( + output_size, scales_tiled.shape(), scales_tiled.strides(), large); + + enc.add_kernel_node( + cu::repack_scales, + num_blocks, + block_dims, + 0, + gpu_ptr(scales), + gpu_ptr(scales_tiled), + input_rows, + input_cols, + output_rows, + output_cols); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/qqmm_utils.h b/mlx/backend/cuda/quantized/qqmm_utils.h new 
file mode 100644 index 0000000000..126cc298b2 --- /dev/null +++ b/mlx/backend/cuda/quantized/qqmm_utils.h @@ -0,0 +1,30 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/cuda/device.h" + +namespace mlx::core { + +// Compute padded dimensions for tiled layout +// Tiles are 128 rows × 4 columns, must allocate full tiles +inline std::pair get_padded_scale_dims(int num_rows, int num_cols) { + constexpr int rows_per_tile = 128; + constexpr int cols_per_tile = 4; + + int padded_rows = + ((num_rows + rows_per_tile - 1) / rows_per_tile) * rows_per_tile; + int padded_cols = + ((num_cols + cols_per_tile - 1) / cols_per_tile) * cols_per_tile; + + return {padded_rows, padded_cols}; +} + +void repack_scales( + const array& scales, + array& scales_tiled, + cu::CommandEncoder& enc, + const Stream& s); + +} // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index a185523edc..a4c9acfc9f 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -2,6 +2,8 @@ #include "mlx/backend/cuda/quantized/quantized.h" #include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/quantized/cublas_qqmm.h" +#include "mlx/backend/cuda/quantized/qqmm_utils.h" #include "mlx/backend/gpu/copy.h" #include "mlx/fast_primitives.h" @@ -44,6 +46,29 @@ inline array ensure_row_contiguous_matrix( return x_copy; } +array pad_and_repack_scales( + const array& scale, + cu::CommandEncoder& encoder, + const Stream& s) { + // Compute padded dimensions for full tiles (128 rows × 4 cols) + auto [pad_outer, pad_inner] = + get_padded_scale_dims(scale.shape(-2), scale.shape(-1)); + // cuBLAS requirements for scale factor layout: + // 1. Dimensions must be padded to full tiles (128 rows × 4 cols) + // 2. Out-of-bounds values must be filled with zeros + // 3. 
Starting addresses must be 16-byte aligned + // https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + // Note: cu::malloc_async already provides 256-byte alignment + array scale_tiled( + cu::malloc_async(pad_outer * pad_inner, encoder), + Shape{pad_outer, pad_inner}, + scale.dtype()); + repack_scales(scale, scale_tiled, encoder, s); + + encoder.add_temporary(scale_tiled); + return scale_tiled; +} + } // namespace void fast::Quantize::eval_gpu( @@ -84,4 +109,110 @@ void fast::Quantize::eval_gpu( } } +namespace { +void qqmm_impl( + cu::CommandEncoder& encoder, + int M, + int N, + int K, + bool a_transposed, + int64_t lda, + bool b_transposed, + int64_t ldb, + array& out, + const array& a, + const array& b, + const array& a_scale, + const array& b_scale, + QuantizationMode mode, + float alpha = 1.0f) { + // Invoke CublasQQMM + std::string qmode = quantization_mode_to_string(mode); + + // Currently only supports non-batched QQMM operations + // that covers all use cases for training, we will just collapse (batch, + // seq_len) into (tokens) + CublasQQMM qqmm( + encoder.device(), + a_transposed, + M, + K, + lda, + b_transposed, + K, + N, + ldb, + 1, // batch_count + 0, // a_batch_stride + 0, // b_batch_stride + qmode); + + qqmm.run(encoder, out, a, b, a_scale, b_scale, alpha); +} +} // namespace + +void DualQuantizedMatmul::eval_gpu( + const std::vector& inputs, + array& out) { + nvtx3::scoped_range r("DualQuantizedMatmul::eval_gpu"); + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + assert(inputs.size() == 4 || inputs.size() == 3); + auto& x = inputs[0]; // activations bf16 + auto& w_q = inputs[1]; // quantized weights + auto& scale_w_pre = inputs[2]; // weight scales + + auto quantize_activation = + [&](const array& input, cu::CommandEncoder& encoder, const Stream& s) { + auto x = ensure_row_contiguous(input, encoder, s); + auto xq_shape = x.shape(); + xq_shape.back() = x.shape(-1) * bits_ / 32; + auto sshape = x.shape(); + sshape.back() = x.shape(-1) / group_size_; + array x_q( + cu::malloc_async(x.size() * bits_ / 8, encoder), xq_shape, uint32); + array scales_x( + cu::malloc_async(x.size() / group_size_ * sizeof(uint8), encoder), + sshape, + uint8); + fp_quantize(x, x_q, scales_x, group_size_, bits_, encoder, s); + encoder.add_temporary(scales_x); + encoder.add_temporary(x_q); + return std::make_pair(x_q, scales_x); + }; + + auto [x_q, scale_x_pre] = quantize_activation(inputs[0], encoder, s); + out.set_data(cu::malloc_async(out.nbytes(), encoder)); + + int M = x_q.shape(-2); + int N = w_q.shape(-2); // always transposed + int K_packed = x_q.shape(-1); + int K = K_packed * (32 / bits_); + + // Repack scales from linear to tiled layout for tensor cores + array scale_x = pad_and_repack_scales(scale_x_pre, encoder, s); + array scale_w = pad_and_repack_scales(scale_w_pre, encoder, s); + + bool x_transposed = false; + bool w_transposed = true; // always transposed + int64_t lda = K; + int64_t ldb = K; + + qqmm_impl( + encoder, + M, + N, + K, + x_transposed, + lda, + w_transposed, + ldb, + out, + x_q, + w_q, + scale_x, + scale_w, + mode_); +} + } // namespace mlx::core diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp index c2636b614d..8b6a6cb0d2 100644 --- a/mlx/backend/no_cpu/primitives.cpp +++ b/mlx/backend/no_cpu/primitives.cpp @@ -97,6 +97,7 @@ NO_CPU(Partition) NO_CPU(Power) NO_CPU_MULTI(QRF) NO_CPU(QuantizedMatmul) +NO_CPU(DualQuantizedMatmul) NO_CPU(RandomBits) NO_CPU(Real) NO_CPU(Reduce) diff --git 
a/mlx/backend/no_gpu/primitives.cpp b/mlx/backend/no_gpu/primitives.cpp index 40695283c7..99cc90d580 100644 --- a/mlx/backend/no_gpu/primitives.cpp +++ b/mlx/backend/no_gpu/primitives.cpp @@ -124,6 +124,7 @@ NO_GPU(Partition) NO_GPU(Power) NO_GPU_MULTI(QRF) NO_GPU(QuantizedMatmul) +NO_GPU(DualQuantizedMatmul) NO_GPU(RandomBits) NO_GPU(Real) NO_GPU(Reduce) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index fbe6799373..636be40cad 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -70,22 +70,24 @@ array indices_or_default( return reshape(arange(total, uint32, s), std::move(shape), s); } -std::pair extract_quantized_matmul_dims( +void validate_quantized_input( std::string_view tag, - const array& x, - const array& w, + const array& matrix, const array& scales, - const std::optional& biases, - bool transpose, + const std::string& matrix_name, + const std::string& scales_name, int group_size, - int bits) { - if (w.dtype() != uint32) { + int bits, + const std::optional& biases = std::nullopt) { + // If matrix is quantized + if (matrix.dtype() != uint32) { std::ostringstream msg; - msg << "[" << tag << "] The weight matrix should be uint32 " - << "but received " << w.dtype(); + msg << "[" << tag << "] The " << matrix_name << " should be uint32 " + << "but received " << matrix.dtype(); throw std::invalid_argument(msg.str()); } + // If biases and scales have same shape if biases provided if (biases && scales.shape() != biases->shape()) { std::ostringstream msg; msg << "[" << tag << "] Scales and biases should have the same shape. " @@ -94,24 +96,43 @@ std::pair extract_quantized_matmul_dims( throw std::invalid_argument(msg.str()); } + // Batch shapes match if (!std::equal( - w.shape().begin(), w.shape().end() - 2, scales.shape().begin())) { + matrix.shape().begin(), + matrix.shape().end() - 2, + scales.shape().begin())) { std::ostringstream msg; - msg << "[" << tag - << "] Weight and scales should have the same batch shape. " - << "Received weight with shape " << w.shape() << ", scales with " - << scales.shape() << "."; + msg << "[" << tag << "] " << matrix_name << " and " << scales_name + << " should have the same batch shape. " + << "Received " << matrix_name << " with shape " << matrix.shape() + << ", " << scales_name << " with " << scales.shape() << "."; throw std::invalid_argument(msg.str()); } - if (w.shape(-1) * 32 / bits != scales.shape(-1) * group_size) { + // Shape compatibility based on bits and group_size + if (matrix.shape(-1) * 32 / bits != scales.shape(-1) * group_size) { std::ostringstream msg; - msg << "[" << tag << "] The shapes of the weight and scales are " - << "incompatible based on bits and group_size. w.shape() == " - << w.shape() << " and scales.shape() == " << scales.shape() - << " with group_size=" << group_size << " and bits=" << bits; + msg << "[" << tag << "] The shapes of the " << matrix_name << " and " + << scales_name << " are " + << "incompatible based on bits and group_size. 
" << matrix_name + << ".shape() == " << matrix.shape() << " and " << scales_name + << ".shape() == " << scales.shape() << " with group_size=" << group_size + << " and bits=" << bits; throw std::invalid_argument(msg.str()); } +} + +std::pair extract_quantized_matmul_dims( + std::string_view tag, + const array& x, + const array& w, + const array& scales, + const std::optional& biases, + bool transpose, + int group_size, + int bits) { + validate_quantized_input( + tag, w, scales, "weight", "scales", group_size, bits, biases); int x_inner_dims = x.shape(-1); @@ -133,6 +154,46 @@ std::pair extract_quantized_matmul_dims( return {w_inner_dims, w_outer_dims}; } +std::pair extract_qqmm_dims( + std::string_view tag, + const array& x, + const array& w_q, + const array& scales_w, + const std::optional& w, + int group_size, + int bits) { + // Validate w_q and scales_w + validate_quantized_input( + tag, w_q, scales_w, "weight matrix", "scales_w", group_size, bits); + + if (w && + (w->shape(-1) != w_q.shape(-1) * 32 / bits || + w->shape(-2) != w_q.shape(-2))) { + std::ostringstream msg; + msg << "[" << tag << "] The shape of the weight matrix and its " + << "quantized version are incompatible. Received weight matrix " + << "with shape " << w->shape() << " and quantized weight matrix " + << "with shape " << w_q.shape() << " with bits=" << bits; + throw std::invalid_argument(msg.str()); + } + int x_inner_dims = x.shape(-1) / (32 / bits); // K + + int w_inner_dims = w_q.shape(-1); + int w_outer_dims = w_q.shape(-2); + + if (w_inner_dims != x_inner_dims) { + std::ostringstream msg; + msg << "[" << tag << "] Inner dimension of second input with " + << "shape (" << w_inner_dims << ", " << w_outer_dims << ")" + << " does not match the packed inner dimension of the first" + << "input (...," << x_inner_dims << ") computed with bits=" << bits; + + throw std::invalid_argument(msg.str()); + } + + return {w_inner_dims, w_outer_dims}; +} + } // namespace array arange( @@ -4246,6 +4307,73 @@ array quantized_matmul( std::move(inputs)); } +array qqmm( + array x, + array w_q, + array scales_w, + std::optional w /* = std::nullopt */, + std::optional group_size_ /* = std::nullopt */, + std::optional bits_ /* = std::nullopt */, + const std::string& mode /* = "nvfp4" */, + StreamOrDevice s /* = {} */) { + // currently only simetric quantization is supported for qqmm + auto qmode = string_to_quantization_mode(mode, "qqmm"); + // cuBLAS block scaled matmul only supports nvfp4 and mxfp8 + if (qmode != QuantizationMode::Nvfp4 && qmode != QuantizationMode::Mxfp8) { + std::ostringstream msg; + msg << "[qqmm] only 'nvfp4' and 'mxfp8' quantization modes are supported but '" + << mode << "' was provided."; + throw std::invalid_argument(msg.str()); + } + // for fp4 block scaling the only supported layout is TN + // https://docs.nvidia.com/cutlass/4.2.1/media/docs/cpp/blackwell_functionality.html + // because w_q should always be quantized along the reduction dimension + // and we quantize so that the last dim is packed, we assume that the last dim + // always the reduction dim so the firat argument in cubals column major is + // (the second argument in qqmm) always transposed + if (qmode == QuantizationMode::Affine) { + std::ostringstream msg; + msg << "[qqmm] Affine quantization is not supported for qqmm."; + throw std::invalid_argument(msg.str()); + } + if (x.ndim() > 2 || w_q.ndim() > 2) { + std::ostringstream msg; + msg << "[qqmm] Only 2D inputs are supported but " + << "x.ndim() == " << x.ndim() << " and " + << "w_q.ndim() == " << 
w_q.ndim() << "."; + throw std::invalid_argument(msg.str()); + } + auto [group_size, bits] = + quantization_params_from_mode(qmode, group_size_, bits_); + auto [w_inner_dims, w_outer_dims] = + extract_qqmm_dims("qqmm", x, w_q, scales_w, w, group_size, bits); + + // we don't backprope through qunatized w and scales + std::vector inputs = { + x, + stop_gradient(w_q), + stop_gradient(scales_w), + }; + // if bf16 w is provided, add it to inputs for vjps + if (w.has_value()) { + inputs.push_back(*w); + } + if (x.ndim() > 2 && w_q.ndim() > 2) { + inputs = broadcast_arrays(inputs, {-2, -1}, s); + } + + auto out_shape = inputs[0].shape(); + out_shape.back() = w_outer_dims; + auto dtype = bfloat16; + // out dtype can be only bf16 for now + return array( + std::move(out_shape), + dtype, + std::make_shared( + to_stream(s), group_size, bits, qmode), + std::move(inputs)); +} + array pack_and_quantize( array& packed_w, const array& scales, @@ -5936,5 +6064,4 @@ array contiguous( std::make_shared(to_stream(s), allow_col_major), {a}); } - -} // namespace mlx::core +} // namespace mlx::core \ No newline at end of file diff --git a/mlx/ops.h b/mlx/ops.h index ff77baf3e8..95d0718268 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1409,6 +1409,16 @@ std::vector quantize( const std::string& mode = "affine", StreamOrDevice s = {}); +array qqmm( + array x, // input activations + array w_q, // quantized weights + array w_scales, + std::optional w = std::nullopt, // optional bf16 weights for vjp + std::optional group_size = std::nullopt, + std::optional bits = std::nullopt, + const std::string& mode = "nvfp4", + StreamOrDevice s = {}); + /** Dequantize a matrix produced by quantize() */ array dequantize( const array& w, diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 5d96301bd7..a05143e2e2 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3468,6 +3468,84 @@ std::vector QuantizedMatmul::output_shapes( return {std::move(out_shape)}; } +bool DualQuantizedMatmul::is_equivalent(const Primitive& other) const { + const DualQuantizedMatmul& qm_other = + static_cast(other); + return group_size_ == qm_other.group_size_ && bits_ == qm_other.bits_ && + mode_ == qm_other.mode_; +} + +std::vector DualQuantizedMatmul::output_shapes( + const std::vector& inputs) { + auto out_shape = inputs[0].shape(); + int w_outer_dims = inputs[2].shape(-2); + out_shape.back() = w_outer_dims; + return {std::move(out_shape)}; +} + +std::vector DualQuantizedMatmul::vjp( + const std::vector& primals, // bf16 x, quantized weights + const std::vector& cotangents, // bf16 grads + const std::vector& argnums, + const std::vector&) { + std::vector vjps; + auto& cotan = cotangents[0]; + std::vector reorder(cotan.ndim()); + std::iota(reorder.begin(), reorder.end(), 0); + std::iter_swap(reorder.end() - 1, reorder.end() - 2); + auto& s = stream(); + // primal[1] -- quantized weights (row wise) + // primal[2] -- scales_w + // primal[0] -- bf16 activations (M, K) + // primal[3] -- bf16 weights (N, K) + // cotan -- bf16 activation grads (M, N) + auto qmode = quantization_mode_to_string(mode_); + for (auto arg : argnums) { + if (arg == 0) { // gradient wrt to x + // We transpose weights -> quantize along N -> qqmm (cotan quantized in + // eval_gpu) + auto wtq = quantize( + transpose(primals[3], {1, 0}, s), // we assume that weights are 2D + group_size_, + bits_, + qmode, + s); // (K, N_packed), scales + vjps.push_back(qqmm( + cotan, // M X N + wtq[0], // K X N_packed + wtq[1], // scales + std::nullopt, + group_size_, + bits_, + qmode, + s)); + } 
else if (arg == 3) { // gradient wrt to weights + // it is a bit complicated -- we need to quantize along M but cotan is + // (M,N) so we transpose cotan -> quantize along M -> qqmm + auto xt = transpose(primals[0], reorder, s); // (K, M) + auto xtq = quantize(xt, group_size_, bits_, qmode, + s); // (N, M_packed) + vjps.push_back(qqmm( + transpose(cotan, reorder, s), // (N, M) + xtq[0], // (N, M_packed) + xtq[1], // scales + std::nullopt, + group_size_, + bits_, + qmode, + s)); // (K, M), (N, M_packed) + } + } + return vjps; +} + +std::vector DualQuantizedMatmul::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + throw std::runtime_error("DualQuantizedMatmul::jvp NYI"); +} + std::pair, std::vector> GatherQMM::vmap( const std::vector& inputs, const std::vector& axes) { diff --git a/mlx/primitives.h b/mlx/primitives.h index 87aad33be5..a84b858d30 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1642,6 +1642,37 @@ class QuantizedMatmul : public UnaryPrimitive { bool transpose_; }; +class DualQuantizedMatmul : public UnaryPrimitive { + public: + explicit DualQuantizedMatmul( + Stream stream, + int group_size, + int bits, + QuantizationMode mode) + : UnaryPrimitive(stream), + group_size_(group_size), + bits_(bits), + mode_(mode) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + // DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(DualQuantizedMatmul) + bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; + auto state() const { + return std::make_tuple(group_size_, bits_, mode_); + } + + private: + int group_size_; + int bits_; + QuantizationMode mode_; + bool transpose_; +}; + class GatherQMM : public UnaryPrimitive { public: explicit GatherQMM( diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 7c00ad8a1d..704529272a 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -5433,4 +5433,41 @@ void init_ops(nb::module_& m) { Returns: array or Sequence[array]: The outputs which depend on dependencies. )pbdoc"); + m.def( + "qqmm", + &mx::qqmm, + nb::arg(), // x + nb::arg(), // w_q + "scales"_a, // scales w + "w"_a = nb::none(), // bf16 weights + "group_size"_a = nb::none(), + "bits"_a = nb::none(), + "mode"_a = "nvfp4", + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def qqmm(x: array, w_q: array, /, scales: array, w: Optional[array] = None, transpose: bool = True, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'nvfp4', *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Perform matrix multiplication using the quantized weight matrix ``w_q`` and the input ``x``, + which is quantized on the fly using the provided group size, bit width, and mode. + The group size, bit width, and mode must match those used to quantize ``w_q``. + The high-precision weight matrix ``w`` must be provided for gradient computation + and must match ``w_q`` before quantization. + + Args: + x (array): Input array + w_q (array): Quantized matrix packed in unsigned integers + scales (array): The scales to use per ``group_size`` elements of ``w_q`` + w (array, optional): bf16 or float32 weights used during training for + gradient computation. Default: ``None``. + group_size (int, optional): The size of the group in ``w_q`` that shares a + scale. See supported values and defaults in the + :ref:`table of quantization modes `. Default: ``None``. 
+ bits (int, optional): The number of bits occupied by each element of + ``w_q`` in the quantized array. See supported values and defaults in the + :ref:`table of quantization modes `. Default: ``None``. + mode (str, optional): The quantization mode. Default: ``"nvfp4"``. + Returns: + array: The result of the multiplication of quantized ``x`` with ``w_q``. + )pbdoc"); } diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index ce63544b47..cf22b3b976 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -7,6 +7,14 @@ import mlx_tests +def ulp_bf16_at(x): + ax = mx.abs(x) + min_normal = mx.array(2.0**-126) + ax = mx.where(ax < min_normal, min_normal, ax) + e = mx.floor(mx.log2(ax)) + return mx.power(2.0, e - 7.0) + + class TestQuantized(mlx_tests.MLXTestCase): def test_quantize_dequantize(self): w = mx.random.normal(shape=(128, 512)) @@ -1015,6 +1023,99 @@ def gmm(s, x, wq): ds = mx.grad(gmm)(s, x, wq) + def test_qqmm(self): + # In mxfp8 mode, the results do not match exactly: + # fewer than 1% of output elements differ. + # This does not appear to be a systematic error. + # The error can exceed 1 ULP for very small values, + # and is always below 1 ULP for larger values. + # For nvfp4, the results match exactly. + # therefore I suspect that the discrepancy comes from + # the mxfp8 matmul implementation in cuBLASLt.. + key = mx.random.key(0) + k1, k2 = mx.random.split(key) + dtype = mx.bfloat16 + + tests = ( + (16, "nvfp4", 4), + (32, "mxfp8", 8), + ) + shapes = ( + [64, 65, 33, 128, 256, 1024, 1024 * 8], # M + [64, 128, 256, 1024, 1024 * 8], # N + [64, 128, 256, 1024, 1024 * 8], # K + ) + for group_size, mode, bits in tests: + for M, N, K in product(*shapes): + with self.subTest( + shape=(M, N, K), + group_size=group_size, + bits=bits, + mode=mode, + ): + x = mx.random.normal(shape=(M, K), key=k1, dtype=dtype) + w = mx.random.normal(shape=(N, K), key=k2, dtype=dtype) + w_q, scales_w = mx.quantize(w, group_size, bits, mode=mode) + w_dq = mx.dequantize( + w_q, scales_w, group_size=group_size, bits=bits, mode=mode + ) + y_q = mx.qqmm( + x, w_q, scales_w, group_size=group_size, bits=bits, mode=mode + ) + x_q, scales_x = mx.quantize( + x, group_size=group_size, bits=bits, mode=mode + ) + x_dq = mx.dequantize( + x_q, scales_x, group_size=group_size, bits=bits, mode=mode + ) + y_hat = mx.matmul(x_dq, mx.transpose(w_dq)) + ulp = ulp_bf16_at(y_hat) + error = (y_q - y_hat).abs() + self.assertEqual(y_q.shape, y_hat.shape) + self.assertTrue(mx.logical_or(error < 1e-3, error <= ulp).all()) + + def test_qqmm_vjp(self): + key = mx.random.key(0) + k1, k2 = mx.random.split(key) + dtype = mx.bfloat16 + M = 64 + N = 1024 + K = 512 + tests = ( + (16, "nvfp4", 4), + (32, "mxfp8", 8), + ) + x = mx.random.normal(shape=(M, K), key=k1, dtype=dtype) + c = mx.ones(shape=(M, N), dtype=dtype) + + for group_size, mode, bits in tests: + with self.subTest( + shape=(M, N, K), + group_size=group_size, + bits=bits, + mode=mode, + ): + w = mx.random.normal(shape=(N, K), key=k2, dtype=dtype) + w_q, scales_w = mx.quantize( + w, group_size=group_size, bits=bits, mode=mode + ) + + def fn(x): + return mx.qqmm( + x, w_q, scales_w, w, group_size=group_size, bits=bits, mode=mode + ) + + _, vjp_out = mx.vjp(fn, primals=(x,), cotangents=(c,)) + w_tq, scales_wt = mx.quantize( + mx.transpose(w), group_size=group_size, bits=bits, mode=mode + ) + expected_out = mx.qqmm( + c, w_tq, scales_wt, group_size=group_size, bits=bits, mode=mode + ) + ulp = ulp_bf16_at(expected_out) + error = 
(vjp_out[0] - expected_out).abs() + self.assertTrue(mx.logical_or(error < 1e-3, error <= ulp).all()) + if __name__ == "__main__": mlx_tests.MLXTestRunner()
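For reference, a minimal end-to-end sketch of the new op from Python, mirroring the calls exercised in the tests above. It assumes a CUDA build where the cuBLASLt block-scaled path is compiled in (CUDA 12.8+) and hardware that supports nvfp4; the shapes and the loss helper are illustrative only.

import mlx.core as mx

M, N, K = 64, 1024, 512
x = mx.random.normal(shape=(M, K), dtype=mx.bfloat16)
w = mx.random.normal(shape=(N, K), dtype=mx.bfloat16)

# Quantize the weights once; nvfp4 uses group_size=16, bits=4.
w_q, scales_w = mx.quantize(w, group_size=16, bits=4, mode="nvfp4")

# Inference: x is quantized on the fly inside the op, the output is bf16.
y = mx.qqmm(x, w_q, scales_w, group_size=16, bits=4, mode="nvfp4")

# Training: also pass the high-precision weights so gradients can be computed.
def loss(x):
    return mx.qqmm(x, w_q, scales_w, w, group_size=16, bits=4, mode="nvfp4").sum()

dx = mx.grad(loss)(x)
print(y.shape, dx.shape)  # (64, 1024) (64, 512)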