
Conversation

@nastya236 (Contributor) commented Nov 18, 2025

This PR adds a new operation mx.qqmm. The current structure is probably neither optimal nor final.

General comment

  1. For inference we want to support: qqmm(quantized weights, bf16 activations).
  2. For training (vjp) we unfortunately still need bf16 weights for two reasons:
    • We currently do not have 2D scaling for nvfp4, so we need to transpose and quantize again along a different dimension.
    • For mxfp8, the recommended recipe is to quantize with 1D blocks and keep two views of the weights (normal and transposed).
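
For point 2, a minimal sketch of the two-views recipe for mxfp8 (the mode-aware quantize call and its return values are assumptions here, not the final API):

```python
import mlx.core as mx

w = mx.random.normal((512, 256)).astype(mx.bfloat16)  # bf16 weights

# Assumed: a mode-aware quantize that returns (w_q, scales) for mxfp8.
# 1D blocks run along the last axis, so each view is quantized along a
# different reduction dimension.
w_q, w_scales = mx.quantize(w, group_size=32, bits=8, mode="mxfp8")      # forward
wt_q, wt_scales = mx.quantize(w.T, group_size=32, bits=8, mode="mxfp8")  # for the vjp
```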

Therefore, mx.qqmm takes bf16 activations x, quantized weights w_q and their scales, and optionally bf16 weights plus group_size, mode, and bits.

In the current implementation, it is the user’s responsibility to ensure that group_size, bits, and mode match those used to quantize w_q. This is probably not ideal, and we may want to improve this in the future.
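
A minimal usage sketch under the signature described above (keyword names, argument order, and the quantize helper are all illustrative assumptions):

```python
import mlx.core as mx

M, K, N = 128, 256, 512
x = mx.random.normal((M, K)).astype(mx.bfloat16)  # bf16 activations, reduction dim last
w = mx.random.normal((N, K)).astype(mx.bfloat16)  # bf16 weights, reduction dim last

# Assumed: a mode-aware quantize returning (w_q, scales).
w_q, scales = mx.quantize(w, group_size=32, bits=8, mode="mxfp8")

# x is quantized on the fly inside qqmm; group_size/bits/mode must match
# whatever was used to produce w_q.
y = mx.qqmm(x, w_q, scales, w, group_size=32, bits=8, mode="mxfp8")
assert y.shape == (M, N)
```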

Very important details

  1. scales are repacked on every call for both weights and activations. In the future, we probably want to:

    • Avoid repacking weight scales for inference.
    • Fuse quantization and repacking, and directly pack into swizzled layout in fp_quantize.
  2. Batched qqmm is currently not supported; inputs must be 2D. For now it is implemented this way because:

    • CUBLASLT_BATCH_MODE_STRIDED is not supported for scales.
    • CUBLASLT_BATCH_MODE_POINTER_ARRAY is not supported for arrays with block scaling.

We almost certainly want to add batching in the future, but for simplicity batch_count = 1 for now (see the flattening sketch after this list).

  3. qqmm is always executed in TN layout (transpose = True).
    There are several reasons for this, but mainly we always quantize along the reduction dimension, which currently ends up being the last dimension. I am happy to change this if you think it is useful to support all layouts for mxfp8, for example. Also, on B200 only the TN layout is supported for nvfp4 and mxfp4.
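
A sketch of the flattening workaround referenced above, for batched activations with a single 2D weight; the shapes also show the TN convention, with the reduction dimension K last on both operands (same assumed API as the earlier sketch):

```python
import mlx.core as mx

B, L, K, N = 4, 32, 256, 512
x3 = mx.random.normal((B, L, K)).astype(mx.bfloat16)  # batched bf16 activations
w = mx.random.normal((N, K)).astype(mx.bfloat16)       # weights, reduction dim K last
w_q, scales = mx.quantize(w, group_size=32, bits=8, mode="mxfp8")  # assumed helper

# Flatten the leading dims so qqmm sees 2D inputs, then restore them.
y = mx.qqmm(x3.reshape(B * L, K), w_q, scales, group_size=32, bits=8, mode="mxfp8")
y = y.reshape(B, L, N)
```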

Notes

  1. There are some changes to cublas_gemm.cpp: I grouped all common cuBLAS-related functions into a separate helper class in cublas_utils.cpp.
  2. mxfp8 qqmm behaves slightly differently from nvfp4: sometimes, for <<1% of the output elements, the result differs from the dequantized reference by exactly 1 ULP in bf16 (see python/tests/test_quantized.py, line 1027). I do not think this is a bug because:
  • For nvfp4 the output matches exactly for every tested shape.
  • The difference is not structured: there is no clear pattern, and the indices of the affected elements change with the seed.
  • The mismatch is always exactly 1 ULP.

Therefore, I attribute this to differences in accumulation on tensor cores or other numerical details we do not control.
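
For reference, this is the kind of 1-ULP comparison meant above (a sketch; `bf16_ulp_diff`, `out_f32`, and `ref_f32` are hypothetical names, and both inputs are assumed to be bf16 results cast to float32 and converted to numpy):

```python
import numpy as np

def bf16_ulp_diff(a_f32, b_f32):
    # A bf16 value cast to float32 has zero low mantissa bits, so the top
    # 16 bits of the float32 representation are exactly the bf16 pattern.
    a16 = np.ascontiguousarray(a_f32, dtype=np.float32).view(np.uint32) >> 16
    b16 = np.ascontiguousarray(b_f32, dtype=np.float32).view(np.uint32) >> 16
    # For same-sign finite values, adjacent bit patterns are 1 ULP apart.
    return np.abs(a16.astype(np.int32) - b16.astype(np.int32))

# e.g. assert bf16_ulp_diff(out_f32, ref_f32).max() <= 1
```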

What this PR lacks (I left these out because I first want to make sure the rest of the API looks reasonable)

  1. addmm -- basically c is always nullptr
  2. nn.QQLinear
  3. nn.Linear.to_qqlinear - or similar method to cast to nn.QQLinear (naming is questionable)
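
Purely for illustration, an nn.QQLinear might end up looking roughly like this (nothing below exists in this PR; every name and signature is an assumption):

```python
import mlx.core as mx
import mlx.nn as nn

class QQLinear(nn.Module):
    def __init__(self, linear: nn.Linear, group_size: int = 32,
                 bits: int = 8, mode: str = "mxfp8"):
        super().__init__()
        self.group_size, self.bits, self.mode = group_size, bits, mode
        # Keep the bf16 weight around for training (vjp), as discussed above.
        self.weight = linear.weight.astype(mx.bfloat16)
        # Assumed: a mode-aware quantize returning (w_q, scales).
        self.weight_q, self.scales = mx.quantize(
            self.weight, group_size=group_size, bits=bits, mode=mode)
        # No bias handling yet, since addmm is not wired up (c is always nullptr).

    def __call__(self, x):
        return mx.qqmm(x, self.weight_q, self.scales, self.weight,
                       group_size=self.group_size, bits=self.bits,
                       mode=self.mode)
```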

Examples are in python/tests/test_quantized.py.
Happy to iterate and change anything here!

@nastya236 marked this pull request as draft on November 18, 2025 at 18:43
@nastya236 changed the title from qqmm to [WIP] qqmm on Nov 18, 2025
@nastya236 marked this pull request as ready for review on November 29, 2025 at 20:20
@nastya236 changed the title from [WIP] qqmm to qqmm on Nov 29, 2025
bool transpose_;
};

class DualQuantizedMatmul : public UnaryPrimitive {
Member:

A bit of a nit but I think it makes sense to rename this to QuantizedQuantizedMatmul or QQMatmul to better match the name of the op. Dual is also kind of an overloaded term.

Contributor Author (@nastya236):

Yes, I agree. I think QQMatmul is better, because then the primitive name and the op name are aligned.

bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
auto state() const {
  return std::make_tuple(group_size_, bits_, mode_);
Member:

transpose_ should be part of the state here.

Contributor Author (@nastya236), Dec 2, 2025:

Yeah, this is a bit unclear and probably should be changed. transpose is not a member variable; qqmm is always executed in TN layout (transpose = True). I did it this way because, at the moment, quantization always produces a row-major tensor with the last dimension packed, and TN is the only layout supported for mxfp4 and nvfp4 on B200.


ds = mx.grad(gmm)(s, x, wq)

def test_qqmm(self):
Member:

These tests should only be run for now if mx.cuda.is_available().

And in fact I'm not sure what the behavior is on older hardware and CUDA toolkits. Do you know what the minimum requirements there are?
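
Something like the following gate would be enough for now (sketch only; the actual test class in python/tests/test_quantized.py differs):

```python
import unittest
import mlx.core as mx

@unittest.skipUnless(mx.cuda.is_available(), "qqmm requires CUDA")
class TestQQMM(unittest.TestCase):
    def test_qqmm(self):
        ...
```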
