
Check numerical equivalence / closeness between different kernel preferences #2651


Merged

jerryzh168 merged 1 commit into main from jerryzh168/stack/15 on Aug 7, 2025

Conversation

jerryzh168 (Contributor) commented Aug 1, 2025

Stacked PRs:

- #2651: Check numerical equivalence / closeness between different kernel preferences

Summary:
This PR checks that the different kernel preferences for Float8Tensor (AUTO, TORCH and FBGEMM) produce similar numerics.

The triton implementation and the torchao implementation are actually a bit different right now; we need to decide whether or not to fix that.

1. Difference in the quantize op (see the sketch below the list)

The main difference seems to be the order of operations: the triton implementation computes an inverse scale and multiplies,

```
a_inv_scale = MAX_FP8 / max_abs
a_fp8 = a * a_inv_scale
a_scale = 1.0 / a_inv_scale  # the scale that gets stored for dequantization
```

while torch computes the scale directly and divides:

```
a_scale = max_abs / MAX_FP8
a_fp8 = a / a_scale
```

Also, the hp_value_lb and hp_value_ub settings are slightly different.

triton choose-scale and quantize code: https://github.com/pytorch/FBGEMM/blob/a4286c01ef01dad435b2ec8798605127d3032cd8/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py#L2382-L2392

torchao choose-scale and quantize code:

_choose_scale_float8: https://github.com/pytorch/ao/blob/3c466f844684af0fb80014094f2ca8663881eb33/torchao/quantization/quant_primitives.py#L2183
_quantize_affine_float8: https://github.com/pytorch/ao/blob/3c466f844684af0fb80014094f2ca8663881eb33/torchao/quantization/quant_primitives.py#L2283

2. (Potential) difference in the matrix multiplication ops

TORCH and AUTO/FBGEMM use different quantized mm ops.

Added a reverse option to bring the SQNR closer:

```
granularity: PerTensor()  sizes: ((128,), 256, 128)  kp: KernelPreference.AUTO tensor(inf, device='cuda:0', dtype=torch.bfloat16)
granularity: PerTensor()  sizes: ((128,), 256, 128)  kp: KernelPreference.FBGEMM tensor(inf, device='cuda:0', dtype=torch.bfloat16)
.granularity: PerTensor()  sizes: ((32, 128), 64, 256)  kp: KernelPreference.AUTO tensor(inf, device='cuda:0', dtype=torch.bfloat16)
granularity: PerTensor()  sizes: ((32, 128), 64, 256)  kp: KernelPreference.FBGEMM tensor(inf, device='cuda:0', dtype=torch.bfloat16)
.granularity: PerRow()  sizes: ((128,), 256, 128)  kp: KernelPreference.AUTO tensor(inf, device='cuda:0', dtype=torch.bfloat16)
granularity: PerRow()  sizes: ((128,), 256, 128)  kp: KernelPreference.FBGEMM tensor(inf, device='cuda:0', dtype=torch.bfloat16)
.granularity: PerRow()  sizes: ((32, 128), 64, 256)  kp: KernelPreference.AUTO tensor(64.5000, device='cuda:0', dtype=torch.bfloat16)
granularity: PerRow()  sizes: ((32, 128), 64, 256)  kp: KernelPreference.FBGEMM tensor(68., device='cuda:0', dtype=torch.bfloat16)
```

Test Plan:
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_kernel_preference_numerical_equivalence

Reviewers:

Subscribers:

Tasks:

Tags:


pytorch-bot bot commented Aug 1, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2651

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 1 Pending

As of commit 1506c0d with merge base d2e791b:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed label Aug 1, 2025
@jerryzh168 jerryzh168 added the topic: not user facing label Aug 1, 2025
```
for i in range(1, len(kp_and_res)):
    kp, res = kp_and_res[i]
    self.assertTrue(
        compute_error(res, kp_and_res[0][1]) > 28,
```
jerryzh168 (Contributor, Author) commented:
cc @vkuzo: we don't have equivalence yet due to some differences in implementation. Do you think we should match the torchao quant primitives (choose_scale_float8 + quantize_float8) with the triton ones?

Contributor commented:

Do we know what the differences are?

IMO we should also choose either TORCH or FBGEMM (but not AUTO) as the reference, and match the others to the reference.

jerryzh168 (Contributor, Author) replied:

Yeah, see the PR summary for the differences.

I can update the test to use TORCH as the reference.

@jerryzh168 jerryzh168 requested a review from vkuzo August 1, 2025 04:43
jerryzh168 added a commit that referenced this pull request Aug 7, 2025

```
# comparing numerics between different kernel preferences, using TORCH as the standard
kp_and_res = list(quantized_outputs.items())
for i in range(1, len(kp_and_res)):
```
Contributor commented:

Can you explicitly peel off the first iteration, so it's very obvious that it is the reference, and just iterate over the rest of the keys?

jerryzh168 (Contributor, Author) replied Aug 7, 2025:

OK, updated to run the reference separately.

@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/15 branch from 873b120 to 1506c0d Compare August 7, 2025 20:56
@jerryzh168 jerryzh168 merged commit 1114ca0 into main Aug 7, 2025
19 checks passed