
Conversation

@timmoon10 (Collaborator) commented Dec 6, 2025

Description

All of the supported block-scaled tensor formats (MXFP8, NVFP4, DSv3 FP8) have two ways of ordering their scaling factors:

  • "Compact" ordering for quantization, dequantization, and communication
  • "Swizzled" ordering for GEMM

The core infrastructure handles this in an ad hoc way, blindly assuming that the "right" scale ordering is used for each operation. The PyTorch infrastructure only supports MXFP8 and NVFP4 scales in compact order, although the DSv3 FP8 implementation is aware of both "compact" and "GEMM-ready" formats. This situation makes it hard to implement fused kernels that bypass the swizzle kernel.

This PR adds a with_gemm_swizzled_scales field in the C++ tensor class so that the core infrastructure can distinguish between the different scale orderings. It also adds this field to the PyTorch quantized tensor classes and exposes an optimize_for_gemm option in the quantizer, so that tensors which do not need communication or checkpointing can be created directly with GEMM-swizzled scales. Finally, it rips out all the DSv3 FP8 infrastructure for the compact format, which is no longer necessary.
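As a rough illustration (not code from this PR), the intended usage from the PyTorch side might look like the sketch below. The optimize_for_gemm and with_gemm_swizzled_scales names come from the description above; the quantizer class, the dtype enum, the exact way the option is passed, and the attribute name are assumptions and may differ from the real API.

import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer

x = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)

# Hypothetical: ask the quantizer for a tensor whose scales are already in the
# swizzled GEMM ordering, so no separate swizzle kernel is needed before the GEMM.
quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3, optimize_for_gemm=True)
x_mxfp8 = quantizer(x)

# Hypothetical attribute mirroring the new C++ field; weights and other tensors
# that never need communication or checkpointing could be kept in this form.
assert x_mxfp8._with_gemm_swizzled_scales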

Progress

  • MXFP8
  • DSv3 FP8
  • NVFP4
  • Add option to pre-swizzle weights
  • Pre-swizzle activations
  • Fused MXFP8 quantize + swizzle

Closes #2446.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Support GEMM swizzled scales in C++ tensor class
  • Support GEMM swizzled scales in PyTorch quantized tensor classes
  • Support optimize_for_gemm option in PyTorch quantizer
  • Expose PyTorch function to swizzle scales
  • Support MXFP8 quantization with pre-swizzled scales
  • Enable fused quantize+swizzle kernels in linear module and related
  • Remove DSv3 FP8 compact data format. It was used to avoid all-gather interleaving, which we can now fix with the swap-first-dims kernel (see the sketch below).
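The last bullet deserves a short illustration. The snippet below is plain PyTorch, not TE code: it shows why all-gathering column-wise (transposed) shards leaves the rank dimension in the wrong place, and why swapping the two leading dimensions of the gathered buffer is enough to fix it, which is what makes a separate compact format unnecessary.

import torch

world, m_shard, n = 4, 8, 16
full = torch.arange(world * m_shard * n).reshape(world * m_shard, n)

# Each rank holds a transposed copy of its row shard: shape [n, m_shard].
shards_t = [full[r * m_shard:(r + 1) * m_shard].t().contiguous() for r in range(world)]

# All-gather concatenates along dim 0, giving [world * n, m_shard], which is
# not the transpose of the full tensor: the rank dimension sits in front.
gathered = torch.cat(shards_t, dim=0)

# Swap the two leading dims ([world, n, m_shard] -> [n, world, m_shard]) and
# flatten to recover the transpose of the full tensor.
fixed = gathered.reshape(world, n, m_shard).transpose(0, 1).reshape(n, world * m_shard)
assert torch.equal(fixed, full.t())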

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@timmoon10 force-pushed the tmoon/pre-swizzled-scales branch from d274220 to 52ce3a4 on December 6, 2025 02:53
@timmoon10 added the enhancement and refactor labels on Dec 6, 2025
@timmoon10 force-pushed the tmoon/pre-swizzled-scales branch from 4925b63 to 1de4b5e on December 10, 2025 07:19

}
}

void nvte_set_tensor_param_v2(NVTETensor tensor, NVTETensorParam param, const void *buf,
Collaborator:

Why do we need a v2 here?

Collaborator Author:

I don't want to break the existing APIs. That said, this PR isn't fully backward-compatible because the GEMM no longer secretly assumes that MXFP8 scales are swizzled.

Update copyright years. Tweak comments. Fix various complaints from @greptile-apps.

Signed-off-by: Tim Moon <[email protected]>

@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("eps", [0], ids=["eps_0"])
@pytest.mark.parametrize("pow_2_scales", [True], ids=["pow2scales"])
def test_quantization_1D_block_tiling_with_compact_data_and_scales(
Member:

Why don't we need this test anymore?

Collaborator Author:

FP8 block-scaling doesn't require a compact format anymore. Now it's always GEMM-ready.

as when chaining multiple modules it is hard to validate
numerical accuracy.
"""
"""LayerNorm/RMSNorm + Linear + SwiGLU + Linear"""
Member:

Not sure how this change is relevant here to be honest.

Collaborator Author:

Before these changes, we didn't have a test for the SwiGLU kernel with quantized, GEMM-ready output. We test the unquantized SwiGLU kernel here:

def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, return_bias, bias):

We test the quantized SwiGLU kernel with non-GEMM-ready output here:

It's also a good sanity-check for te.Sequential.

Float8BlockScaleTensorFormat::COMPACT);
rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT
: FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
Member:

Why are you always choosing the GEMM-ready version here?

Member:

Shouldn't it have similar logic to check with_gemm_swizzled_scales?

Collaborator Author (@timmoon10, Jan 14, 2026):

The way FP8 block-scaling was implemented, the only advantage of the compact format over the GEMM-ready format was support for all-gathers. However, all-gather support only required a simple change to use the swap-first-dims kernel. Instead of propagating this PR's changes throughout the FP8 block-scaling logic, I found it simpler to just remove the compact format entirely.


const size_t shmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT;

// Zero out swizzled scales if padding is needed
Member:

Don't we already do this though when we create the scaling factor tensors? This seems like a pessimization, as we will do this now every time instead of once - we should instead note the requirement for the scaling factors to be zeroed out before the quantization in the docs of the quantization call.

Collaborator Author (@timmoon10, Jan 15, 2026):

From what I can see, we just allocate scales with at::empty:

if (rowwise_usage) {
  const std::vector<int64_t> scale_inv_shape_int64(rowwise_scale_inv_shape.begin(),
                                                   rowwise_scale_inv_shape.end());
  rowwise_data_tensor = at::empty(shape_int64, uint8_tensor_opts);
  rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, uint8_tensor_opts);
}
if (columnwise_usage) {
  const std::vector<int64_t> scale_inv_shape_int64(columnwise_scale_inv_shape.begin(),
                                                   columnwise_scale_inv_shape.end());
  columnwise_data_tensor = at::empty(shape_int64, uint8_tensor_opts);
  columnwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, uint8_tensor_opts);
}

Requiring zeroing out in this case is unintuitive, especially since it's not needed in the unpadded case. I figure that if the model is small enough for the padding to be relevant, then it's probably too small to see the full perf benefit of MXFP8 anyways.
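For context, here is a rough sketch (not TE code) of the buffer sizes involved. It assumes the usual MXFP8 convention of one scale per 32 elements and a swizzled scale matrix padded to 128-row by 4-column tiles for the GEMM; only the padded fringe is never written by quantization, so that is the region the fused kernel zeroes when the buffer comes from at::empty.

# Rough sketch of swizzled MXFP8 scale-buffer sizes, assuming one scale per 32
# elements and padding of the scale matrix to 128 x 4 tiles for the GEMM.
def round_up(value: int, multiple: int) -> int:
    return ((value + multiple - 1) // multiple) * multiple

def swizzled_scale_shape(rows: int, cols: int, block_size: int = 32):
    return round_up(rows, 128), round_up(cols // block_size, 4)

# e.g. a 1000 x 2048 tensor: the quantization writes 1000 x 64 scale entries
# inside a 1024 x 64 buffer, leaving 24 rows of padding to be zeroed.
print(swizzled_scale_shape(1000, 2048))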


const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT;

// Zero out swizzled scales if padding is needed
Member:

Same comment as for the gated kernel.

sizeof(NVTEBasicTensor), // kNVTERowwiseScaleInv
sizeof(NVTEBasicTensor), // kNVTEColumnwiseScaleInv
sizeof(NVTEBasicTensor), // kNVTEColumnwiseAmax
sizeof(bool) // kNVTEWithGEMMSwizzledScales
Member:

This is implementation-defined, so we should not rely on it for our API - please use a type with a defined size for this.

Collaborator Author:

Replaced bool with uint8_t and int with int32_t.


namespace {

void reset_tensor_data(transformer_engine::TensorWrapper &tensor, bool rowwise, bool columnwise) {
Member:

Why do we need that? We cache the tensors anyway, so there should not be any overhead from just creating a new TensorWrapper.


} // namespace

std::tuple<std::optional<at::Tensor>, std::optional<at::Tensor>> swizzle_scales_for_gemm(
Member:

We could just output a new TensorWrapper here with the new scaling factors, and then we would not need that reset, I think?

return ((value + multiple - 1) / multiple) * multiple;
}

size_t ceildiv(size_t numer, size_t denom) { return (numer + denom - 1) / denom; }
Member:

Don't we have a divup function already?

Collaborator Author:

I prefer if the PyTorch extensions minimize their reliance on internal headers from the core lib (mainly common.h), but I don't have a strong opinion.


@greptile-apps[bot] (Contributor) left a comment:


Additional Comments (1)

  1. transformer_engine/common/gemm/config.h, lines 17-30

    logic: Type mismatch between struct field declarations and attr_sizes array. The struct declares bool fields but attr_sizes specifies sizeof(uint8_t), and sm_count is int but array uses sizeof(int32_t). This will cause incorrect serialization/memory calculations if the attr_sizes array is used for byte layout.

67 files reviewed, 6 comments


Comment on lines +57 to +70
if (is_fp8_dtype(output->dtype()) && output->has_columnwise_data()) {
  // FP8 kernel only populates row-wise data, so perform
  // transpose separately if needed
  Tensor transpose_in, transpose_out, dummy;
  transpose_in.scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
  transpose_in.data.dptr = output->data.dptr;
  transpose_in.data.shape = {output->flat_first_dim(), output->flat_last_dim()};
  transpose_in.data.dtype = output->data.dtype;
  transpose_out.scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
  transpose_out.data.dptr = output->columnwise_data.dptr;
  transpose_out.data.shape = {output->flat_last_dim(), output->flat_first_dim()};
  transpose_out.data.dtype = output->data.dtype;
  detail::transpose(transpose_in, /*noop=*/dummy, &transpose_out, stream);
}
Contributor:

style: This transpose logic is duplicated identically in the backward pass (lines 147-160). Consider extracting it into a helper function to reduce code duplication and improve maintainability.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Collaborator Author (@timmoon10, Jan 15, 2026):

The extra layer of abstraction would hurt maintainability more than just duplicating <20 LoC.

const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise;

const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols;
const bool rowwise_scale_is_within_bounds = SCALE_DIM_X * scales_offset_X_rowwise < cols;
Contributor:

logic: Fixed potential out-of-bounds access by multiplying SCALE_DIM_X * scales_offset_X_rowwise instead of just scales_offset_X_rowwise when checking column bounds.

Collaborator Author:

Changing the kernel implementation is outside of scope for this PR. The tests already pass, so this is probably already correct, albeit with unclear variable names.

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Tim Moon <[email protected]>
@timmoon10 (Collaborator Author):

/te-ci L1

@jberchtold-nvidia (Collaborator) left a comment:

LGTM from a TE/JAX perspective, thanks! I can't see if TE/JAX CI passed since a new pipeline was recently launched, but looks like you're running JAX tests as part of your launch so all good

@timmoon10 (Collaborator Author):

/te-ci pytorch


Labels

2.12.0, enhancement (New feature or request), MoE, performance (Performance issues), refactor


Development

Successfully merging this pull request may close these issues.

Support MXFP8/NVFP4 tensors with pre-swizzled scales

4 participants