[mxfp8 moe training] add triton kernel for mxfp8 dequantization #3195
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3195
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 1 Unrelated Failure as of commit 82ded0b with merge base b644211.
NEW FAILURE - The following job has failed:
BROKEN TRUNK - The following job failed but was present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
stack-info: PR: #3195, branch: danielvegamyhre/stack/78
Force-pushed from 61a6f60 to 8ce6f0d
stack-info: PR: #3195, branch: danielvegamyhre/stack/78
Force-pushed from 8ce6f0d to 357d20f
stack-info: PR: #3195, branch: danielvegamyhre/stack/78
Force-pushed from 357d20f to b0e5061
    torch.bfloat16,
)
hp_t = triton_mxfp8_dequant_dim0(x_data, x_scales, torch.bfloat16, block_size)
torch.testing.assert_close(hp_t, hp_ref, rtol=0, atol=0)
lgtm, didn't look at the rest too closely
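For context on what the test above checks: mxfp8 dequantization multiplies each contiguous block of fp8 e4m3 values by its shared e8m0 (power-of-two) scale and casts the result to the target dtype. Below is a minimal, non-authoritative reference sketch in plain PyTorch, assuming row-major data, scales stored as torch.float8_e8m0fnu, and K divisible by block_size; the function and argument names are hypothetical and not the kernel added in this PR:

```python
import torch

def mxfp8_dequant_dim0_reference(
    data: torch.Tensor,      # (M, K) fp8 e4m3 quantized values
    scales: torch.Tensor,    # (M, K // block_size) e8m0 per-block scales
    out_dtype: torch.dtype,
    block_size: int = 32,
) -> torch.Tensor:
    # e8m0 encodes only an exponent, so converting to fp32 yields the
    # power-of-two scale directly.
    scale_fp32 = scales.to(torch.float32)
    # Broadcast each per-block scale across its block along dim 1.
    scale_fp32 = scale_fp32.repeat_interleave(block_size, dim=1)
    # Dequantize in fp32, then cast to the requested high-precision dtype.
    return (data.to(torch.float32) * scale_fp32).to(out_dtype)
```

A Triton kernel doing the same computation can load one scale per block inside the kernel instead of materializing the broadcast scale tensor, which is presumably where the benefit over a pure PyTorch reference comes from.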
stack-info: PR: #3195, branch: danielvegamyhre/stack/78
Force-pushed from b0e5061 to ba81844
stack-info: PR: #3195, branch: danielvegamyhre/stack/78
Force-pushed from ba81844 to 82ded0b
Looks like there is some PTX for going to bf16.
Sweet, where is this from? I actually tried looking in TE for PTX examples for this, but all I could find was casting fp32 -> e8m0 (for computing the scale): https://github.com/NVIDIA/TransformerEngine/blob/dd9433e7ad28c12f27da9770be54c9c584e85fa0/transformer_engine/common/util/ptx.cuh#L134. Will try it out later.
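For reference, the fp32 -> e8m0 cast mentioned above is essentially extracting a power-of-two exponent from a block's amax so the block's values fit the e4m3 range. A rough Python sketch of that scale computation follows; the real kernels do this with PTX/bit manipulation, and the exact rounding and clamping behavior shown here is an assumption rather than what TE or this PR implements:

```python
import torch

E4M3_MAX = 448.0  # largest finite value representable in fp8 e4m3

def compute_e8m0_scale(block_amax: torch.Tensor) -> torch.Tensor:
    # Scale each block so its amax maps near the top of the e4m3 range,
    # then round the scale to a power of two (all that e8m0 can encode).
    descale = (block_amax / E4M3_MAX).clamp(min=2.0 ** -127)
    exponent = torch.floor(torch.log2(descale))
    return torch.exp2(exponent)  # power-of-two scale; cast to e8m0 for storage
```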
Stacked PRs:
[mxfp8 moe training] add triton kernel for mxfp8 dequantization
Summary
Test plan
Benchmarks