Add logic for block-scaled tensors with GEMM swizzled scales #2486
base: main
Conversation
void nvte_set_tensor_param_v2(NVTETensor tensor, NVTETensorParam param, const void *buf,
Why do we need a v2 here?
I don't want to break the existing APIs. That said, this PR isn't fully backward-compatible because the GEMM no longer secretly assumes that MXFP8 scales are swizzled.
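For illustration, here is a minimal usage sketch of the v2 setter for the new swizzled-scales flag. This is a hypothetical sketch only: the trailing size argument is an assumption based on the visible signature fragment, and kNVTEWithGEMMSwizzledScales is assumed to be the corresponding NVTETensorParam value.

// Hypothetical sketch, not the final API: assumes the v2 setter takes
// (tensor, param, buf, size) and that kNVTEWithGEMMSwizzledScales is a valid NVTETensorParam.
#include <cstdint>
#include <transformer_engine/transformer_engine.h>

void mark_scales_as_gemm_swizzled(NVTETensor tensor) {
  const uint8_t swizzled = 1;  // uint8_t rather than bool, per the fixed-width-type discussion below
  nvte_set_tensor_param_v2(tensor, kNVTEWithGEMMSwizzledScales, &swizzled, sizeof(swizzled));
}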
Update copyright years. Tweak comments. Fix various complaints from @greptile-apps. Signed-off-by: Tim Moon <[email protected]>
@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("eps", [0], ids=["eps_0"])
@pytest.mark.parametrize("pow_2_scales", [True], ids=["pow2scales"])
def test_quantization_1D_block_tiling_with_compact_data_and_scales(
Why don't we need this test anymore?
FP8 block-scaling doesn't require a compact format anymore. Now it's always GEMM-ready.
as when chaining multiple modules it is hard to validate
numerical accuracy.
"""
"""LayerNorm/RMSNorm + Linear + SwiGLU + Linear"""
Not sure how this change is relevant here to be honest.
Before these changes, we didn't have a test for the SwiGLU kernel with quantized, GEMM-ready output. We test the unquantized SwiGLU kernel here:
TransformerEngine/tests/pytorch/test_numerics.py, line 1633 in 69636a0:
def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, return_bias, bias):
We test the quantized SwiGLU kernel with non-GEMM-ready output here:
def test_activation(
It's also a good sanity-check for te.Sequential.
Float8BlockScaleTensorFormat::COMPACT);
rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT
                                 : FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
Why are you always choosing the GEMM-ready version here?
Shouldn't it have similar logic to check with_gemm_swizzled_scales?
The way FP8 block-scaling was implemented, the only advantage of the compact format over the GEMM-ready format was support for all-gathers. However, all-gather support only required a simple change to use the swap-first-dims kernel. Instead of propagating this PR's changes throughout the FP8 block-scaling logic, I found it simpler to just remove the compact format entirely.
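As a rough illustration of the swap-first-dims idea (an assumption about what that kernel does, not TE's actual implementation): transposing the two leading dimensions puts the gathered dimension outermost, so an all-gather becomes a plain concatenation of contiguous shards.

// Minimal host-side sketch: treat the tensor as [dim0, dim1, inner] and write it
// out as [dim1, dim0, inner]. The real kernel presumably does this on the GPU.
#include <cstddef>

template <typename T>
void swap_first_dims(const T *in, T *out, size_t dim0, size_t dim1, size_t inner) {
  for (size_t i = 0; i < dim0; ++i) {
    for (size_t j = 0; j < dim1; ++j) {
      const T *src = in + (i * dim1 + j) * inner;
      T *dst = out + (j * dim0 + i) * inner;
      for (size_t k = 0; k < inner; ++k) {
        dst[k] = src[k];  // copy one innermost slice to its transposed position
      }
    }
  }
}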
const size_t shmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT;

// Zero out swizzled scales if padding is needed
Don't we already do this when we create the scaling-factor tensors? This seems like a pessimization: we will now do it every time instead of once. We should instead note, in the docs of the quantization call, that the scaling factors must be zeroed out before quantization.
From what I can see, we just allocate scales with at::empty:
TransformerEngine/transformer_engine/pytorch/csrc/quantizer.cpp, lines 886 to 897 in 28d08a7:
if (rowwise_usage) {
  const std::vector<int64_t> scale_inv_shape_int64(rowwise_scale_inv_shape.begin(),
                                                   rowwise_scale_inv_shape.end());
  rowwise_data_tensor = at::empty(shape_int64, uint8_tensor_opts);
  rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, uint8_tensor_opts);
}
if (columnwise_usage) {
  const std::vector<int64_t> scale_inv_shape_int64(columnwise_scale_inv_shape.begin(),
                                                   columnwise_scale_inv_shape.end());
  columnwise_data_tensor = at::empty(shape_int64, uint8_tensor_opts);
  columnwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, uint8_tensor_opts);
}
Requiring zeroing out in this case is unintuitive, especially since it's not needed in the unpadded case. I figure that if the model is small enough for the padding to be relevant, then it's probably too small to see the full perf benefit of MXFP8 anyways.
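For context, a rough sketch of where the padding comes from, assuming a cuBLAS-style swizzled-scale layout that rounds the MXFP8 scale tensor up to multiples of 128 rows and 4 columns (these constants are assumptions for illustration, not taken from this PR):

// Illustrative only: computes the padded shape of a swizzled MXFP8 scale tensor,
// with one scale per 32 elements along the scaled dimension. The padding region
// holds no real scales, which is why it has to be zeroed somewhere.
#include <cstddef>
#include <utility>

constexpr size_t round_up(size_t value, size_t multiple) {
  return ((value + multiple - 1) / multiple) * multiple;
}

std::pair<size_t, size_t> swizzled_scale_shape(size_t rows, size_t cols, size_t block_size = 32) {
  const size_t scale_cols = (cols + block_size - 1) / block_size;
  return {round_up(rows, 128), round_up(scale_cols, 4)};  // assumed padding multiples
}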
const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT;

// Zero out swizzled scales if padding is needed
Same comment as for the gated kernel.
transformer_engine/common/common.h (outdated)
sizeof(NVTEBasicTensor),  // kNVTERowwiseScaleInv
sizeof(NVTEBasicTensor),  // kNVTEColumnwiseScaleInv
sizeof(NVTEBasicTensor),  // kNVTEColumnwiseAmax
sizeof(bool)              // kNVTEWithGEMMSwizzledScales
This is implementation-defined, so we should not rely on it for our API; please use a type with a defined size here.
Replaced bool with uint8_t and int with int32_t.
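A minimal sketch of what that fix amounts to (illustrative names, not the exact common.h or gemm/config.h code): fixed-width types make the byte sizes recorded in the attribute-size table well-defined across compilers.

// sizeof(bool) and sizeof(int) are implementation-defined; uint8_t and int32_t are not.
#include <cstddef>
#include <cstdint>

struct GemmConfigSketch {
  uint8_t with_gemm_swizzled_scales;  // was bool
  int32_t sm_count;                   // was int
};

constexpr size_t attr_sizes_sketch[] = {
    sizeof(uint8_t),   // with_gemm_swizzled_scales
    sizeof(int32_t),   // sm_count
};

static_assert(sizeof(GemmConfigSketch::with_gemm_swizzled_scales) == attr_sizes_sketch[0],
              "struct field and size-table entry must agree");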
namespace {

void reset_tensor_data(transformer_engine::TensorWrapper &tensor, bool rowwise, bool columnwise) {
Why do we need that? We cache the tensors anyway, so there should not be overhead of just creating a new TensorWrapper.
}  // namespace

std::tuple<std::optional<at::Tensor>, std::optional<at::Tensor>> swizzle_scales_for_gemm(
We could just output a new TensorWrapper here with those new scaling factors and then we would not need that reset thing I think?
  return ((value + multiple - 1) / multiple) * multiple;
}

size_t ceildiv(size_t numer, size_t denom) { return (numer + denom - 1) / denom; }
Don't we have a divup function already?
I prefer if the PyTorch extensions minimize their reliance on internal headers from the core lib (mainly common.h), but I don't have a strong opinion.
Miscellaneous review suggestions from @ptrendx. Signed-off-by: Tim Moon <[email protected]>
Additional Comments (1)
- transformer_engine/common/gemm/config.h, lines 17-30: logic: Type mismatch between struct field declarations and the attr_sizes array. The struct declares bool fields but attr_sizes specifies sizeof(uint8_t), and sm_count is int but the array uses sizeof(int32_t). This will cause incorrect serialization/memory calculations if the attr_sizes array is used for byte layout.
67 files reviewed, 6 comments
if (is_fp8_dtype(output->dtype()) && output->has_columnwise_data()) {
  // FP8 kernel only populates row-wise data, so perform
  // transpose separately if needed
  Tensor transpose_in, transpose_out, dummy;
  transpose_in.scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
  transpose_in.data.dptr = output->data.dptr;
  transpose_in.data.shape = {output->flat_first_dim(), output->flat_last_dim()};
  transpose_in.data.dtype = output->data.dtype;
  transpose_out.scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
  transpose_out.data.dptr = output->columnwise_data.dptr;
  transpose_out.data.shape = {output->flat_last_dim(), output->flat_first_dim()};
  transpose_out.data.dtype = output->data.dtype;
  detail::transpose(transpose_in, /*noop=*/dummy, &transpose_out, stream);
}
style: This transpose logic is duplicated identically in the backward pass (lines 147-160). Consider extracting it into a helper function to reduce code duplication and improve maintainability.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
The extra layer of abstraction would hurt maintainability more than just duplicating <20 LoC.
const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise;

const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols;
const bool rowwise_scale_is_within_bounds = SCALE_DIM_X * scales_offset_X_rowwise < cols;
logic: Fixed potential out-of-bounds access by multiplying SCALE_DIM_X * scales_offset_X_rowwise instead of just scales_offset_X_rowwise when checking column bounds.
Changing the kernel implementation is outside of scope for this PR. The tests already pass, so this is probably already correct, albeit with unclear variable names.
/te-ci L1
jberchtold-nvidia left a comment
LGTM from a TE/JAX perspective, thanks! I can't see if TE/JAX CI passed since a new pipeline was recently launched, but looks like you're running JAX tests as part of your launch so all good
/te-ci pytorch
Description
All of the supported block-scaled tensor formats (MXFP8, NVFP4, DSv3 FP8) have two ways of ordering their scaling factors: a compact ordering and a swizzled, GEMM-ready ordering.
The core infrastructure handles this in an ad hoc way, blindly assuming that the "right" scale ordering is used for each operation. The PyTorch infrastructure only supports MXFP8 and NVFP4 scales in compact order, although DSv3 FP8 does distinguish between "compact" and "GEMM-ready" formats. This situation makes it hard to implement fused kernels that can bypass the swizzle kernel.
This PR adds a with_gemm_swizzled_scales field in the C++ tensor class so that the core infrastructure can distinguish between the different scale orderings. It also adds this field to the PyTorch quantized tensor classes and exposes an optimize_for_gemm option in the quantizer so that we can create GEMM-ready tensors when they do not need communication or checkpointing. Finally, it rips out all the DSv3 FP8 infrastructure for the compact format, which is no longer necessary.
Progress
- Add option to pre-swizzle weights

Closes #2446.
Type of change
Changes
Please list the changes introduced in this PR:
- optimize_for_gemm option in PyTorch quantizer

Checklist: