qqmm #2789
Conversation
bool transpose_;
};

class DualQuantizedMatmul : public UnaryPrimitive {
A bit of a nit, but I think it makes sense to rename this to QuantizedQuantizedMatmul or QQMatmul to better match the name of the op. "Dual" is also kind of an overloaded term.
Yes, I agree. I think QQMatmul is better, because then the primitive name and the op name are aligned.
bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
auto state() const {
  return std::make_tuple(group_size_, bits_, mode_);
`transpose_` should be part of the state here.
Yeah, this is a bit unclear and probably should be changed. transpose is not a member variable; `qqmm` is always executed in TN layout (transpose = True). I did it this way because, at the moment, quantization always produces a row-major tensor with the last dimension packed, and TN is the only layout supported for `mxfp4` and `nvfp4` on B200.
ds = mx.grad(gmm)(s, x, wq)

def test_qqmm(self):
These tests should only be run for now if mx.cuda.is_available().
And in fact I'm not sure what the behavior is on older hardware and CUDA toolkits. Do you know what the minimum requirements are?
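For illustration, a minimal sketch of one way to gate these tests, assuming the `mx.cuda.is_available()` check above is the right predicate; the decorator placement and class name are illustrative:

```python
import unittest

import mlx.core as mx


# Hypothetical guard: skip the whole test case unless the CUDA backend is available.
@unittest.skipIf(not mx.cuda.is_available(), "qqmm requires the CUDA backend")
class TestQQMM(unittest.TestCase):
    def test_qqmm(self):
        ...  # the qqmm tests from python/tests/test_quantized.py would run here
```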
This PR adds a new operation `mx.qqmm`. The current structure is probably neither optimal nor final.

**General comment**

- The op is `qqmm` (quantized weights, bf16 activations).
- For `nvfp4`, we need to transpose and quantize again along a different dimension.
- For `mxfp8`, the recommended recipe is to quantize with 1D blocks and keep two views of the weights (normal and transposed).

Therefore, `mx.qqmm` takes bf16 activations `x`, quantized weights `w_q` and their scales, and optionally bf16 weights plus `group_size`, `mode`, and `bits`.

In the current implementation, it is the user's responsibility to ensure that `group_size`, `bits`, and `mode` match those used to quantize `w_q`. This is probably not ideal, and we may want to improve this in the future.
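A minimal usage sketch under some assumptions: that `mx.quantize` accepts a `mode` argument and returns `(w_q, scales)` for these float formats, and that `mx.qqmm` has the signature described above; the parameter values are illustrative.

```python
import mlx.core as mx

# Assumed shapes: activations (M, K), weights (N, K), so the reduction dim is last.
x = mx.random.normal((128, 256)).astype(mx.bfloat16)
w = mx.random.normal((512, 256)).astype(mx.bfloat16)

# The group_size/bits/mode passed to qqmm must match how w was quantized.
w_q, scales = mx.quantize(w, group_size=16, bits=4, mode="nvfp4")

# Activations are quantized on the fly inside qqmm; output is bf16 with shape (128, 512).
y = mx.qqmm(x, w_q, scales, group_size=16, bits=4, mode="nvfp4")
```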
**Very important details**

- `scales` are repacked on every call for both weights and activations. In the future, we probably want to handle this packing in `fp_quantize`.
- Batched `qqmm` is currently not supported; inputs must be 2D (a possible user-side workaround is sketched after this list). For now it is implemented this way because:
  - `CUBLASLT_BATCH_MODE_STRIDED` is not supported for scales.
  - `CUBLASLT_BATCH_MODE_POINTER_ARRAY` is not supported for arrays with block scaling.

  We almost certainly want to add batching in the future, but for simplicity `batch_count = 1` for now.
- `qqmm` is always executed in TN layout (transpose = True). There are several reasons for this, but mainly we always quantize along the reduction dimension, which currently ends up being the last dimension. I am happy to change this if you think it is useful to support all layouts for `mxfp8`, for example. Also, on B200 only the TN layout is supported for `nvfp4` and `mxfp4`.
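Since the op is 2D-only for now, a user-side workaround is to flatten the leading batch dimensions of the activations before the call. This helper is hypothetical and not part of the PR:

```python
import mlx.core as mx


def qqmm_batched(x, w_q, scales, **kwargs):
    # Hypothetical helper: collapse leading batch dims, run the 2D qqmm, restore the batch shape.
    *batch, k = x.shape
    y = mx.qqmm(x.reshape(-1, k), w_q, scales, **kwargs)
    return y.reshape(*batch, y.shape[-1])
```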
**Notes**

- `cublas_gemm.cpp`: I grouped all common cuBLAS-related functions into a separate helper class in `cublas_utils.cpp`.
- `mxfp8` `qqmm` behaves slightly differently from `nvfp4`: sometimes, for <<1% of the output elements, the result differs from the dequantized reference by exactly 1 ULP in bf16 (see `python/tests/test_quantized.py`, line 1027). I do not think this is a bug because for `nvfp4` the output matches exactly for every tested shape. Therefore, I attribute this to differences in accumulation on tensor cores or other numerical details we do not control.
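For context on the 1 ULP observation, here is one way such a check could be written; this helper is an illustration, not the check used in the tests. It relies on a bf16 value keeping its bit pattern in the top 16 bits of its float32 representation, and assumes the compared values are finite and have the same sign:

```python
import numpy as np

import mlx.core as mx


def bf16_ulp_diff(a, b):
    # Upper 16 bits of the float32 representation == the bf16 bit pattern.
    a_bits = (np.array(a.astype(mx.float32)).view(np.uint32) >> 16).astype(np.int64)
    b_bits = (np.array(b.astype(mx.float32)).view(np.uint32) >> 16).astype(np.int64)
    # For same-sign, finite values this integer distance is the ULP distance in bf16.
    return np.abs(a_bits - b_bits)
```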
**What this PR lacks** (these are left out for now because I first want to make sure the rest of the API looks reasonable):

- `addmm` -- basically `c` is always `nullptr`
- `nn.QQLinear`, analogous to `nn.Linear` (a rough illustration follows this list)
- `.to_qqlinear` - or a similar method to cast to `nn.QQLinear` (naming is questionable)
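A rough, purely illustrative sketch of what `nn.QQLinear` could look like; the class, its constructor, and the assumed `(w_q, scales)` return of `mx.quantize` are not part of this PR:

```python
import mlx.core as mx
import mlx.nn as nn


class QQLinear(nn.Module):
    """Illustrative module: weights are quantized once up front,
    activations are quantized on the fly by mx.qqmm at call time."""

    def __init__(self, input_dims, output_dims, group_size=16, bits=4, mode="nvfp4"):
        super().__init__()
        w = mx.random.normal((output_dims, input_dims)).astype(mx.bfloat16)
        # Assumed: mx.quantize returns (w_q, scales) for these float modes.
        self.w_q, self.scales = mx.quantize(w, group_size=group_size, bits=bits, mode=mode)
        self.group_size, self.bits, self.mode = group_size, bits, mode

    def __call__(self, x):
        # The caller is responsible for keeping these parameters in sync with the quantization above.
        return mx.qqmm(
            x, self.w_q, self.scales,
            group_size=self.group_size, bits=self.bits, mode=self.mode,
        )
```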
Examples are in `python/tests/test_quantized.py`.

Happy to iterate and change anything here!