docs/source/api_ref_qat.rst (0 additions, 2 deletions)

```diff
@@ -42,8 +42,6 @@ Legacy QAT APIs
     :toctree: generated/
     :nosignatures:
 
-    IntXQuantizationAwareTrainingConfig
-    FromIntXQuantizationAwareTrainingConfig
     Int4WeightOnlyQATQuantizer
     linear.Int4WeightOnlyQATLinear
     Int8DynActInt4WeightQATQuantizer
```
test/prototype/test_embedding.py (5 additions, 6 deletions)

```diff
@@ -18,10 +18,9 @@
 )
 from torchao.quantization.granularity import PerAxis, PerGroup
 from torchao.quantization.qat import (
-    FromIntXQuantizationAwareTrainingConfig,
     Int4WeightOnlyEmbeddingQATQuantizer,
     IntxFakeQuantizeConfig,
-    IntXQuantizationAwareTrainingConfig,
+    QATConfig,
 )
 from torchao.quantization.quant_api import (
     Int8DynamicActivationIntxWeightConfig,
@@ -257,7 +256,7 @@ def test_identical_to_IntxWeightOnlyConfig(
         ],
         name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
     )
-    def test_identical_to_IntXQuantizationAwareTrainingConfig(
+    def test_identical_to_QATConfig(
         self, weight_dtype, granularity, mapping_type, scale_dtype, model_dtype
     ):
         # ASYMMETRIC in QAT is very different that PTQ configs
@@ -288,12 +287,12 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig(
         )
         quantize_(
             model,
-            IntXQuantizationAwareTrainingConfig(weight_config=weight_config),
+            QATConfig(weight_config=weight_config, step="prepare"),
             embedding_filter,
         )
         prepared_out = model(indices)
 
-        quantize_(model, FromIntXQuantizationAwareTrainingConfig(), embedding_filter)
+        quantize_(model, QATConfig(step="convert"), embedding_filter)
         quantize_(
             model,
             IntxWeightOnlyConfig(
@@ -355,7 +354,7 @@ def test_identical_to_Int4WeightOnlyEmbeddingQATQuantizer(
         prepared_out = model(indices)
 
         # Convert model method 1
-        quantize_(model, FromIntXQuantizationAwareTrainingConfig(), embedding_filter)
+        quantize_(model, QATConfig(step="convert"), embedding_filter)
         quantize_(
             model,
             IntxWeightOnlyConfig(
```
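For embeddings, the migration this file exercises boils down to the prepare/convert flow below. A minimal sketch, assuming a toy `nn.Embedding` model and a filter like the test's `embedding_filter`; the `QATConfig` and `IntxFakeQuantizeConfig` usage comes directly from the diff above.

```python
import torch

from torchao.quantization import quantize_
from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig

model = torch.nn.Sequential(torch.nn.Embedding(128, 64))

def embedding_filter(m, fqn):
    return isinstance(m, torch.nn.Embedding)

# Prepare: swap nn.Embedding for its fake-quantized counterpart
weight_config = IntxFakeQuantizeConfig(torch.int8, "per_channel")
quantize_(model, QATConfig(weight_config=weight_config, step="prepare"), embedding_filter)

# ... fine-tune the model here ...

# Convert: swap back to plain nn modules (actual quantization is a separate step)
quantize_(model, QATConfig(step="convert"), embedding_filter)
```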
```diff
@@ -20,10 +20,9 @@
 from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout, QDQLayout
 from torchao.quantization.granularity import PerAxis, PerGroup
 from torchao.quantization.qat import (
-    FromIntXQuantizationAwareTrainingConfig,
     Int8DynActInt4WeightQATQuantizer,
     IntxFakeQuantizeConfig,
-    IntXQuantizationAwareTrainingConfig,
+    QATConfig,
 )
 from torchao.quantization.quant_api import (
     Int8DynamicActivationInt4WeightConfig,
@@ -499,7 +498,7 @@ def test_identical_to_Int8DynamicActivationInt4WeightConfig(
             for model_dtype in [torch.float32, torch.bfloat16, torch.float16]
         ],
     )
-    def test_identical_to_IntXQuantizationAwareTrainingConfig(
+    def test_identical_to_QATConfig(
         self,
         weight_dtype,
         group_size,
@@ -545,7 +544,11 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig(
 
         quantize_(
             model,
-            IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
+            QATConfig(
+                activation_config=activation_config,
+                weight_config=weight_config,
+                step="prepare",
+            ),
         )
         try:
             prepared_out = model(activations)
@@ -555,7 +558,7 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig(
                 return
             raise e
 
-        quantize_(model, FromIntXQuantizationAwareTrainingConfig())
+        quantize_(model, QATConfig(step="convert"))
         quantize_(
             model,
             Int8DynamicActivationIntxWeightConfig(
@@ -606,7 +609,7 @@ def test_identical_to_Int8DynActInt4WeightQATQuantizer(
         prepared_out = model(activations)
 
         # Convert model method 1
-        quantize_(model, FromIntXQuantizationAwareTrainingConfig())
+        quantize_(model, QATConfig(step="convert"))
         quantize_(
             model,
             Int8DynamicActivationIntxWeightConfig(
```
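The linear tests in this file follow the same pattern, with an activation config added. Again a hedged sketch, assuming a toy model; the `IntxFakeQuantizeConfig` arguments are borrowed from tests in this PR.

```python
import torch

from torchao.quantization import quantize_
from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig

model = torch.nn.Sequential(torch.nn.Linear(64, 64))

# Fake-quantize int8 per-token dynamic activations and int8 per-channel weights
activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int8, "per_channel")
quantize_(
    model,
    QATConfig(
        activation_config=activation_config,
        weight_config=weight_config,
        step="prepare",
    ),
)

# ... fine-tune the model here ...

quantize_(model, QATConfig(step="convert"))
```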
test/quantization/test_qat.py (0 additions, 92 deletions)

```diff
@@ -9,7 +9,6 @@
 
 import copy
 import unittest
-import warnings
 from typing import List, Type
 
 import torch
@@ -39,8 +38,6 @@
 )
 from torchao.quantization.qat.api import (
     ComposableQATQuantizer,
-    FromIntXQuantizationAwareTrainingConfig,
-    IntXQuantizationAwareTrainingConfig,
     QATConfig,
     QATStep,
     initialize_fake_quantizers,
@@ -1718,95 +1715,6 @@ def test_qat_fp8a4w_quantizer(self):
         self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0)
         self.assertFalse(torch.equal(new_weight, prev_weight))
 
-    def test_legacy_quantize_api_e2e(self):
-        """
-        Test that the following two APIs are numerically equivalent:
-
-        New API:
-            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="prepare"))
-            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="convert"))
-
-        Old API:
-            quantize_(model, IntXQuantizationAwareTrainingConfig(...))
-            quantize_(model, FromIntXQuantizationAwareTrainingConfig())
-            quantize_(model, Int8DynamicActivationInt4WeightConfig())
-        """
-        group_size = 16
-        torch.manual_seed(self.SEED)
-        m = M()
-        baseline_model = copy.deepcopy(m)
-
-        # Baseline prepare
-        act_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
-        weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
-        old_qat_config = IntXQuantizationAwareTrainingConfig(act_config, weight_config)
-        quantize_(baseline_model, old_qat_config)
-
-        # QATConfig prepare
-        base_config = Int8DynamicActivationInt4WeightConfig(group_size=group_size)
-        quantize_(m, QATConfig(base_config, step="prepare"))
-
-        # Compare prepared values
-        torch.manual_seed(self.SEED)
-        x = m.example_inputs()
-        x2 = copy.deepcopy(x)
-        out = m(*x)
-        baseline_out = baseline_model(*x2)
-        torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
-
-        # Baseline convert
-        quantize_(baseline_model, FromIntXQuantizationAwareTrainingConfig())
-        quantize_(baseline_model, base_config)
-
-        # quantize_ convert
-        quantize_(m, QATConfig(base_config, step="convert"))
-
-        # Compare converted values
-        torch.manual_seed(self.SEED)
-        x = m.example_inputs()
-        x2 = copy.deepcopy(x)
-        out = m(*x)
-        baseline_out = baseline_model(*x2)
-        torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
-
-    def test_qat_api_deprecation(self):
-        """
-        Test that the appropriate deprecation warning is logged exactly once per class.
-        """
-        from torchao.quantization.qat import (
-            FakeQuantizeConfig,
-            FakeQuantizer,
-            from_intx_quantization_aware_training,
-            intx_quantization_aware_training,
-        )
-
-        # Reset deprecation warning state, otherwise we won't log warnings here
-        warnings.resetwarnings()
-
-        # Map from deprecated API to the args needed to instantiate it
-        deprecated_apis_to_args = {
-            IntXQuantizationAwareTrainingConfig: (),
-            FromIntXQuantizationAwareTrainingConfig: (),
-            intx_quantization_aware_training: (),
-            from_intx_quantization_aware_training: (),
-            FakeQuantizeConfig: (torch.int8, "per_channel"),
-            FakeQuantizer: (IntxFakeQuantizeConfig(torch.int8, "per_channel"),),
-        }
-
-        with warnings.catch_warnings(record=True) as _warnings:
-            # Call each deprecated API twice
-            for cls, args in deprecated_apis_to_args.items():
-                cls(*args)
-                cls(*args)
-
-        # Each call should trigger the warning only once
-        self.assertEqual(len(_warnings), len(deprecated_apis_to_args))
-        for w in _warnings:
-            self.assertIn(
-                "is deprecated and will be removed in a future release",
-                str(w.message),
-            )
-
     def test_qat_api_convert_no_quantization(self):
         """
         Test that `QATConfig(step="convert")` swaps back to nn modules without quantization.
```
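The deleted `test_legacy_quantize_api_e2e` docstring above documents the equivalence being retired; what remains is the base-config flow, where fake-quantization settings are derived from the target PTQ config. A minimal sketch, assuming a toy float32 model:

```python
import torch

from torchao.quantization import Int8DynamicActivationInt4WeightConfig, quantize_
from torchao.quantization.qat import QATConfig

model = torch.nn.Sequential(torch.nn.Linear(64, 64))
base_config = Int8DynamicActivationInt4WeightConfig(group_size=16)

# Prepare: insert fake quantization derived from the target PTQ config
quantize_(model, QATConfig(base_config, step="prepare"))

# ... fine-tune the model here ...

# Convert: remove fake quantization and apply the base config for real
quantize_(model, QATConfig(base_config, step="convert"))
```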
torchao/quantization/__init__.py (0 additions, 2 deletions)

```diff
@@ -64,7 +64,6 @@
     PlainLayout,
     TensorCoreTiledLayout,
     UIntXWeightOnlyConfig,
-    intx_quantization_aware_training,
     quantize_,
     swap_conv2d_1x1_to_linear,
 )
@@ -119,7 +118,6 @@
     "ALL_AUTOQUANT_CLASS_LIST",
     # top level API - manual
     "quantize_",
-    "intx_quantization_aware_training",
     "swap_conv2d_1x1_to_linear",
     "Int4DynamicActivationInt4WeightConfig",
     "Int8DynamicActivationInt4WeightConfig",
```
torchao/quantization/qat/__init__.py (0 additions, 13 deletions)

```diff
@@ -1,25 +1,19 @@
 from .api import (
     ComposableQATQuantizer,
-    FromIntXQuantizationAwareTrainingConfig,
-    IntXQuantizationAwareTrainingConfig,
     QATConfig,
     QATStep,
-    from_intx_quantization_aware_training,
     initialize_fake_quantizers,
-    intx_quantization_aware_training,
 )
 from .embedding import (
     FakeQuantizedEmbedding,
     Int4WeightOnlyEmbeddingQATQuantizer,
 )
 from .fake_quantize_config import (
-    FakeQuantizeConfig,
     FakeQuantizeConfigBase,
     Float8FakeQuantizeConfig,
     IntxFakeQuantizeConfig,
 )
 from .fake_quantizer import (
-    FakeQuantizer,
     FakeQuantizerBase,
     Float8FakeQuantizer,
     IntxFakeQuantizer,
@@ -50,11 +44,4 @@
     "Int4WeightOnlyEmbeddingQATQuantizer",
     "Int4WeightOnlyQATQuantizer",
     "Int8DynActInt4WeightQATQuantizer",
-    # for BC
-    "FakeQuantizer",
-    "FakeQuantizeConfig",
-    "from_intx_quantization_aware_training",
-    "FromIntXQuantizationAwareTrainingConfig",
-    "intx_quantization_aware_training",
-    "IntXQuantizationAwareTrainingConfig",
 ]
```
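For downstream code, the import-level migration implied by these removals is roughly:

```python
# Removed in this PR:
#   from torchao.quantization.qat import (
#       FakeQuantizeConfig,
#       FakeQuantizer,
#       FromIntXQuantizationAwareTrainingConfig,
#       IntXQuantizationAwareTrainingConfig,
#       from_intx_quantization_aware_training,
#       intx_quantization_aware_training,
#   )
# Still exported:
from torchao.quantization.qat import IntxFakeQuantizeConfig, IntxFakeQuantizer, QATConfig
```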