
[bc-breaking] Generalize QAT configs beyond intx quantization #2608


Closed
andrewor14 wants to merge 1 commit into pytorch/ao from generalize-qat-config

Conversation

andrewor14
Contributor

**Summary:** Current QAT APIs are highly tailored to intx quantization, but we plan to extend QAT to other dtypes such as fp8 and nvfp4 in the near future. This commit is the refactor needed before that work.

In particular, we are moving references to integer quantization from the main `IntXQuantizationAwareTrainingConfig` to the inner `FakeQuantizeConfig`s:

```
IntXQuantizationAwareTrainingConfig -> QuantizationAwareTrainingConfig
FakeQuantizeConfig -> IntxFakeQuantizeConfig
```

In the future, we will have other types of `FakeQuantizeConfig` for float dtypes that users can pass in instead of the existing Intx one.
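As a rough illustration of where this could go (hypothetical; no float config is added in this PR), a future float fake-quantize config might plug into the same generalized QAT config, mirroring the intx example further below:

```
# Hypothetical sketch: a future float fake-quantize config passed into the
# same generalized QAT config. Float8FakeQuantizeConfig is not part of this
# PR; the name is illustrative only.
activation_config = Float8FakeQuantizeConfig(torch.float8_e4m3fn)
weight_config = Float8FakeQuantizeConfig(torch.float8_e4m3fn)
quantize_(model, QuantizationAwareTrainingConfig(activation_config, weight_config))
```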

**BC-breaking notes:** For BC, we keep the old names around as references to the new ones. However, this commit is still BC-breaking in the sense that a few APIs now accept the abstract `FakeQuantizeConfigBase` instead. For the most part, this abstract class will be hidden from the user.
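A minimal sketch of how this layout might look (illustrative only, not the literal torchao implementation): the concrete config subclasses the abstract base, and the old names are kept as aliases that reference the new ones:

```
# Sketch only: illustrates the described hierarchy and BC aliases.
class FakeQuantizeConfigBase:
    # Abstract base that the generalized QAT APIs now accept
    pass

class IntxFakeQuantizeConfig(FakeQuantizeConfigBase):
    # Integer fake-quantize settings (dtype, granularity, symmetry, ...)
    pass

class QuantizationAwareTrainingConfig:
    # Takes FakeQuantizeConfigBase instances for activations and weights
    def __init__(self, activation_config=None, weight_config=None):
        self.activation_config = activation_config
        self.weight_config = weight_config

# Old names kept around for BC, referencing the new ones
FakeQuantizeConfig = IntxFakeQuantizeConfig
IntXQuantizationAwareTrainingConfig = QuantizationAwareTrainingConfig
```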

Before:

```
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config)
quantize_(model, qat_config)
train(model)
quantize_(model, FromIntXQuantizationAwareTrainingConfig())
quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))
```

After:

```
activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
qat_config = QuantizationAwareTrainingConfig(activation_config, weight_config)
quantize_(model, qat_config)
train(model)
quantize_(model, FromQuantizationAwareTrainingConfig())
quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))
```

**Test Plan:**
`python test/quantization/test_qat.py`

@andrewor14 requested review from jerryzh168 and drisspg on July 25, 2025 23:31

pytorch-bot bot commented Jul 25, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2608

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 2125f2f with merge base 12ff479:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Jul 25, 2025
@andrewor14 added the topic: bc-breaking label on Jul 25, 2025
@jerryzh168
Contributor

since we are planning to break BC, I'm wondering if it is possible to merge

```
quantize_(model, FromQuantizationAwareTrainingConfig())
quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))
```

into a single step, or just extend `Int8DynamicActivationInt4WeightConfig` with different steps like AWQ: #2400
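A hypothetical sketch of what this suggestion could look like (illustrative only; the `step` argument below does not exist in torchao):

```
# Illustrative only: folds the "convert from QAT" step into the target
# post-training config so a single quantize_ call replaces the two calls
# above. The `step` keyword is hypothetical, mirroring the multi-step idea
# referenced for AWQ.
quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32, step="convert"))
```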

@andrewor14 marked this pull request as draft on July 28, 2025 14:36
@andrewor14
Contributor Author

Changing the design, closing this for now. Thanks @jerryzh168

@andrewor14 closed this on Jul 28, 2025
@andrewor14 deleted the generalize-qat-config branch on July 29, 2025 14:26