
Conversation

danielvegamyhre (Contributor) commented Oct 16, 2025

Stacked PRs:

  • [mxfp8 moe training] add triton kernel for mxfp8 dequantization

Summary

  • Traces show the dequantization kernel in the mxfp8 all-to-all (a2a) is slow. This PR adds a Triton kernel for dequantization that is much faster at large "M" (local_batch_size * seq_len), which is the regime that matters for MoE training. A reference sketch of the dequantization math follows below.
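
For context, mxfp8 dequantization is conceptually simple: each contiguous block of 32 fp8 (e4m3) values shares one e8m0 scale, a biased power-of-two exponent, and dequantizing multiplies the block by 2^(exponent - 127). Below is a minimal reference sketch in plain PyTorch, assuming row-major data with one raw uint8 exponent byte per 32-element block along the last dimension; the actual torchao layout, scale dtype, and the Triton kernel added in this PR may differ.

import torch

def mxfp8_dequant_reference(x_data, x_scales, out_dtype=torch.bfloat16, block_size=32):
    # x_data:   (M, K) tensor of torch.float8_e4m3fn quantized values
    # x_scales: (M, K // block_size) tensor of uint8 e8m0 biased exponents (bias 127)
    scales = torch.exp2(x_scales.to(torch.float32) - 127.0)
    # Broadcast each scale over its block of 32 contiguous elements.
    scales = scales.repeat_interleave(block_size, dim=-1)
    return (x_data.to(torch.float32) * scales).to(out_dtype)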

Test plan

  • pytest test/prototype/mx_formats/test_kernels.py -k mxfp8_dequant

Benchmarks

input_shape        torch_us    triton_us    torch_gbps    triton_gbps  triton_speedup
---------------  ----------  -----------  ------------  -------------  ----------------
(1, 8192, 7168)      36.864       39.968       4828.44        4453.46  0.922x
(2, 8192, 7168)     287.712       78.88        1237.32        4513.08  3.647x
(4, 8192, 7168)     560.32       150.56        1270.67        4728.9   3.722x
(8, 8192, 7168)    1110.9        297.984       1281.82        4778.67  3.728x
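
A hedged note on how the GB/s columns are likely derived (the formula below reproduces the table values): total bytes moved by the dequant, counting the fp8 input, one e8m0 scale byte per 32-element block, and the bf16 output, divided by the measured kernel time.

def dequant_gbps(input_shape, time_us, block_size=32):
    # Bytes read: 1 byte per fp8 element plus 1 scale byte per block of 32 elements.
    # Bytes written: 2 bytes per bf16 output element.
    n_elem = 1
    for d in input_shape:
        n_elem *= d
    total_bytes = n_elem + n_elem // block_size + 2 * n_elem
    return total_bytes / (time_us * 1e-6) / 1e9

# e.g. dequant_gbps((2, 8192, 7168), 78.88) ≈ 4513 GB/s, matching the triton_gbps column.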

pytorch-bot bot commented Oct 16, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3195

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Unrelated Failure

As of commit 82ded0b with merge base b644211:

NEW FAILURE - The following job has failed:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

danielvegamyhre added a commit that referenced this pull request Oct 16, 2025
stack-info: PR: #3195, branch: danielvegamyhre/stack/78
meta-cla bot added the CLA Signed label on Oct 16, 2025
danielvegamyhre force-pushed the danielvegamyhre/stack/78 branch from 61a6f60 to 8ce6f0d on October 16, 2025 23:39
danielvegamyhre added the mx, topic: not user facing, and moe labels on Oct 17, 2025
danielvegamyhre added a commit that referenced this pull request Oct 17, 2025
stack-info: PR: #3195, branch: danielvegamyhre/stack/78
danielvegamyhre force-pushed the danielvegamyhre/stack/78 branch from 8ce6f0d to 357d20f on October 17, 2025 00:13
danielvegamyhre added a commit that referenced this pull request Oct 17, 2025
stack-info: PR: #3195, branch: danielvegamyhre/stack/78
danielvegamyhre force-pushed the danielvegamyhre/stack/78 branch from 357d20f to b0e5061 on October 17, 2025 00:33
torch.bfloat16,
)
hp_t = triton_mxfp8_dequant_dim0(x_data, x_scales, torch.bfloat16, block_size)
torch.testing.assert_close(hp_t, hp_ref, rtol=0, atol=0)
A reviewer (Contributor) commented on the test excerpt above:

lgtm, didn't look at the rest too closely

danielvegamyhre added a commit that referenced this pull request Oct 17, 2025
stack-info: PR: #3195, branch: danielvegamyhre/stack/78
danielvegamyhre force-pushed the danielvegamyhre/stack/78 branch from b0e5061 to ba81844 on October 17, 2025 00:34
stack-info: PR: #3195, branch: danielvegamyhre/stack/78
danielvegamyhre force-pushed the danielvegamyhre/stack/78 branch from ba81844 to 82ded0b on October 17, 2025 00:39
drisspg (Contributor) commented Oct 17, 2025

Looks like there is some ptx for going to bf16

__CUDA_HOSTDEVICE_FP8_DECL__
__nv_bfloat16_raw __nv_cvt_e8m0_to_bf16raw(const __nv_fp8_storage_t x)
{
    __nv_bfloat16_raw res;

#if (__CUDA_FP8_INTERNAL_CAN_RELY_ON_PTX_FOR_SHORTTYPESCVT__)
    unsigned short in = (unsigned short)x;
    unsigned hr = 0U;
    asm("{cvt.rn.bf16x2.ue8m0x2 %0, %1;}\n"
                : "=r"(hr)
                : "h"(in));

    res.x = (unsigned short)hr;
#else
    res.x = __internal_e8m0_to_bf16(x);
#endif

    return res;
}
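
A hedged aside on what this conversion does: e8m0 is a bare biased exponent with the same bias (127) as bf16, so for in-range values the bf16 result is just that exponent placed in bf16's exponent field, with sign and mantissa zero. A simplified sketch of the bit trick in PyTorch, assuming raw uint8 scale bytes and ignoring the two edge cases the intrinsic handles (exponent 0, which needs the bf16 subnormal 2^-127, and 0xFF, which encodes NaN):

import torch

def e8m0_to_bf16_simplified(scale_bytes):
    # Place the 8-bit biased exponent into bf16's exponent field (bits [14:7]);
    # sign and mantissa stay zero, giving exactly 2^(exponent - 127).
    return (scale_bytes.to(torch.int16) << 7).view(torch.bfloat16)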

danielvegamyhre (Contributor, Author) commented Oct 17, 2025

> Looks like there is some ptx for going to bf16

Sweet, where is this from? I actually tried looking in TE for PTX examples for this, but all I could find was casting fp32 -> e8m0 (for computing the scale): https://github.com/NVIDIA/TransformerEngine/blob/dd9433e7ad28c12f27da9770be54c9c584e85fa0/transformer_engine/common/util/ptx.cuh#L134

Will try it out later
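
A hedged aside on that fp32 -> e8m0 direction (used when computing scales): one simple way to cast is truncation toward zero, which for a positive, normal fp32 value amounts to extracting its 8 biased exponent bits. A minimal sketch, ignoring zero, subnormals, saturation, and whatever rounding mode the linked TE PTX actually uses:

import torch

def fp32_to_e8m0_truncate(x):
    # For positive normal fp32 inputs, the largest power of two <= x is 2^(exp - 127),
    # so the e8m0 encoding is just the biased exponent field of the fp32 bit pattern.
    return ((x.view(torch.int32) >> 23) & 0xFF).to(torch.uint8)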


Labels

  • CLA Signed (This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.)
  • moe
  • mx
  • topic: not user facing (Use this tag if you don't want this PR to show up in release notes)

3 participants