Make AWQ more general #2400

Merged · 1 commit · Aug 1, 2025
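This PR replaces the observer-based prototype AWQ flow (`insert_awq_observer_` / `awq_uintx`) with a step-based `AWQConfig` that wraps an existing base quantization config. A minimal sketch of the new prepare → calibrate → convert flow, assuming the API exercised in the updated tests (the toy model and random calibration data here are stand-ins, not from the PR):

```python
import torch
from torchao.prototype.awq import AWQConfig, AWQStep
from torchao.quantization import Int4WeightOnlyConfig, quantize_

# Toy single-linear model standing in for the tests' ToyLinearModel.
model = torch.nn.Sequential(torch.nn.Linear(512, 256, bias=False))
model = model.eval().to(torch.bfloat16).to("cuda")
calibration_data = [
    torch.randn(5, 512, dtype=torch.bfloat16, device="cuda") for _ in range(10)
]

base_config = Int4WeightOnlyConfig(group_size=128)

# Step 1: PREPARE wraps the base config and inserts AWQ observers.
quantize_(model, AWQConfig(base_config, step=AWQStep.PREPARE))

# Step 2: run calibration batches through the model to record activation statistics.
for example in calibration_data:
    model(example)

# Step 3: CONVERT applies the equalization scales and quantizes the
# weights with the wrapped base config.
quantize_(model, AWQConfig(base_config, step=AWQStep.CONVERT))
```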
6 changes: 6 additions & 0 deletions test/core/test_config.py
@@ -19,6 +19,10 @@
config_from_dict,
config_to_dict,
)
from torchao.prototype.awq import (
AWQConfig,
AWQStep,
)
from torchao.quantization.quant_api import (
FbgemmConfig,
Float8DynamicActivationFloat8WeightConfig,
@@ -79,6 +83,8 @@
"linear2": Int8DynamicActivationInt4WeightConfig(),
}
),
AWQConfig(Int4WeightOnlyConfig(group_size=128), step=AWQStep.PREPARE_FOR_LOADING),
AWQConfig(Int4WeightOnlyConfig(group_size=128), step="prepare_for_loading"),
]

if TORCH_VERSION_AT_LEAST_2_6:
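The new entries above add `AWQConfig` to the configs covered by the existing serialization round-trip test. A hedged sketch of that round trip, assuming `config_to_dict` / `config_from_dict` come from `torchao.core.config` (the import source is truncated in this hunk) and that `AWQConfig` exposes its `step` field as an attribute:

```python
from torchao.core.config import config_from_dict, config_to_dict  # assumed import path
from torchao.prototype.awq import AWQConfig, AWQStep
from torchao.quantization import Int4WeightOnlyConfig

config = AWQConfig(Int4WeightOnlyConfig(group_size=128), step=AWQStep.PREPARE_FOR_LOADING)

# Round-trip through the plain-dict form; the reconstructed config should
# preserve both the wrapped base config and the step field.
reloaded = config_from_dict(config_to_dict(config))
assert isinstance(reloaded, AWQConfig)
assert reloaded.step == AWQStep.PREPARE_FOR_LOADING  # assumes `step` is a public attribute
```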
345 changes: 203 additions & 142 deletions test/prototype/test_awq.py
@@ -3,29 +3,30 @@
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import os
from copy import deepcopy
import copy
import tempfile
import unittest

import pytest
import torch
from torch.testing._internal.common_utils import (
TestCase,
run_tests,
)

from torchao.quantization import quantize_
from torchao.testing.utils import skip_if_rocm
from torchao.prototype.awq import AWQConfig, AWQStep
from torchao.quantization import FbgemmConfig, Int4WeightOnlyConfig, quantize_
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_3,
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_6,
_is_fbgemm_genai_gpu_available,
)

if TORCH_VERSION_AT_LEAST_2_3:
from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_


class ToyLinearModel(torch.nn.Module):
def __init__(self, m=512, n=256, k=128):
super().__init__()
self.linear1 = torch.nn.Linear(m, n, bias=False)
self.linear2 = torch.nn.Linear(n, k, bias=False)
self.linear3 = torch.nn.Linear(k, 1, bias=False)
self.linear3 = torch.nn.Linear(k, 64, bias=False)

def example_inputs(
self, batch_size, sequence_length=10, dtype=torch.bfloat16, device="cuda"
@@ -44,137 +45,197 @@ def forward(self, x):
return x


devices = ["cpu", "cuda"]
# torch.uintx dtypes are introduced in 2.3
if TORCH_VERSION_AT_LEAST_2_3:
qdtypes = (torch.uint4, torch.uint7)
else:
qdtypes = ()


@pytest.fixture(autouse=True)
def run_before_and_after_tests():
yield
torch._dynamo.reset() # reset cache between tests


@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("qdtype", qdtypes)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch")
@pytest.mark.skip("Temporarily skipping to unpin nightiles")
def test_awq_loading(device, qdtype):
if qdtype == torch.uint4 and device == "cpu":
pytest.skip("uint4 not supported on cpu")

dataset_size = 100
l1, l2, l3 = 512, 256, 128
original_dtype = torch.bfloat16 # tinygemm kernel only uses bfloat16 inputs
quant_dtype = qdtype
group_size = 128
n_calibration_examples = 10
n_validation_examples = 10
sequence_length = 5

m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
dataset = m.example_inputs(
dataset_size,
sequence_length=sequence_length,
dtype=original_dtype,
device=device,
)
calibration_data = dataset[:n_calibration_examples]

# calibrate
insert_awq_observer_(
m,
n_validation_examples,
sequence_length,
quant_dtype=quant_dtype,
group_size=group_size,
)

for example in calibration_data:
m(example.to(device))

# quantize
is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
quantize_(
m, awq_uintx(quant_dtype=quant_dtype, group_size=group_size), is_observed_linear
)

model_save_path = "awq_model.pth"
torch.save(m, model_save_path)
loaded_model = torch.load(model_save_path)
os.remove(model_save_path)

if torch.cuda.is_available():
@unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available")
@unittest.skipIf(
not _is_fbgemm_genai_gpu_available(),
reason="need to install fbgemm_gpu_genai package",
)
@unittest.skipIf(
not TORCH_VERSION_AT_LEAST_2_6,
reason="torch.int4 needs torch 2.6+, can remove after we are not using FbgemmConfig",
)
class TestAWQ(TestCase):
def test_awq_config(self):
base_config = Int4WeightOnlyConfig()
AWQConfig(base_config, step=AWQStep.PREPARE)
AWQConfig(base_config, step=AWQStep.PREPARE_FOR_LOADING)
AWQConfig(base_config, step=AWQStep.CONVERT)

AWQConfig(base_config, step="prepare")
AWQConfig(base_config, step="prepare_for_loading")
AWQConfig(base_config, step="convert")

with self.assertRaisesRegex(ValueError, "is not one of"):
AWQConfig(base_config, step="not_supported")

def test_awq_functionality(self):
device = "cuda"
dataset_size = 100
l1, l2, l3 = 512, 256, 128
original_dtype = torch.bfloat16 # tinygemm kernel only uses bfloat16 inputs
group_size = 128
n_calibration_examples = 10
sequence_length = 5

m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)

# baseline quantization
base_config = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, group_size],
preshuffle=False,
)
m_baseline = copy.deepcopy(m)
quantize_(m_baseline, base_config)

# awq quantization
dataset = m.example_inputs(
dataset_size,
sequence_length=sequence_length,
dtype=original_dtype,
device=device,
)
ref_out = torch.cat([m(d.squeeze(0)) for d in dataset])

calibration_data = dataset[:n_calibration_examples]

quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
quantize_(m, quant_config)

for example in calibration_data:
m(example)

quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
quantize_(m, quant_config)

awq_out = torch.cat([m(d.squeeze(0)) for d in dataset])
baseline_out = torch.cat([m_baseline(d.squeeze(0)) for d in dataset])

loss_awq = (ref_out - awq_out).pow(2).mean().item()
loss_base = (ref_out - baseline_out).pow(2).mean().item()
assert loss_awq < loss_base

def test_awq_loading(self):
device = "cuda"
dataset_size = 100
l1, l2, l3 = 512, 256, 128
original_dtype = torch.bfloat16 # tinygemm kernel only uses bfloat16 inputs
group_size = 128
n_calibration_examples = 10
sequence_length = 5

m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
dataset = m.example_inputs(
dataset_size,
sequence_length=sequence_length,
dtype=original_dtype,
device=device,
)
calibration_data = dataset[:n_calibration_examples]

# calibrate
base_config = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, group_size],
preshuffle=False,
)
quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
quantize_(m, quant_config)

for example in calibration_data:
m(example)

# quantize
quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
quantize_(m, quant_config)

with tempfile.NamedTemporaryFile() as f:
torch.save(m.state_dict(), f)
f.seek(0)
state_dict = torch.load(f)

loaded_model = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
loaded_model.load_state_dict(state_dict, assign=True)

m = torch.compile(m, fullgraph=True)
loaded_model = torch.compile(loaded_model, fullgraph=True)

awq_out = torch.cat([m(d.squeeze(0)) for d in dataset])
awq_save_load_out = torch.cat([loaded_model(d.squeeze(0)) for d in dataset])

assert awq_out is not None
assert awq_save_load_out is not None
assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)

def test_awq_loading_vllm(self):
"""Simulate weight loading in vllm:
* prepare model weight to the same format (awq weight)
* use weight.copy_(state_dict["weight"]) to copy over the quantized weights from checkpoint

There is also a slicing op that is ommitted here, overall e2e is tested in tests in vllm repo
"""
device = "cuda"
dataset_size = 100
l1, l2, l3 = 512, 256, 128
original_dtype = torch.bfloat16 # tinygemm kernel only uses bfloat16 inputs
group_size = 128
n_calibration_examples = 10
sequence_length = 5

m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
dataset = m.example_inputs(
dataset_size,
sequence_length=sequence_length,
dtype=original_dtype,
device=device,
)
calibration_data = dataset[:n_calibration_examples]

# calibrate
base_config = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, group_size],
preshuffle=False,
)
quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
quantize_(m, quant_config)

for example in calibration_data:
m(example)

# quantize
quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
quantize_(m, quant_config)

with tempfile.NamedTemporaryFile() as f:
torch.save(m.state_dict(), f)
f.seek(0)
state_dict = torch.load(f)

loaded_model = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
quant_config = AWQConfig(base_config, step=AWQStep.PREPARE_FOR_LOADING)
quantize_(loaded_model, quant_config)

loaded_model.linear1.weight.copy_(state_dict["linear1.weight"])
loaded_model.linear2.weight.copy_(state_dict["linear2.weight"])
loaded_model.linear3.weight.copy_(state_dict["linear3.weight"])

m = torch.compile(m, fullgraph=True)
loaded_model = torch.compile(loaded_model, fullgraph=True)

awq_out = torch.cat([m(i.squeeze(0)) for i in dataset])
awq_save_load_out = torch.cat([loaded_model(i.squeeze(0)) for i in dataset])

assert awq_out is not None
assert awq_save_load_out is not None
assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)


@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@skip_if_rocm("ROCm enablement in progress")
def test_save_weights_only():
dataset_size = 100
l1, l2, l3 = 512, 256, 128
original_dtype = torch.bfloat16
quant_dtype = torch.uint4
device = "cuda"
group_size = 128
n_calibration_examples = 10
n_validation_examples = 10
sequence_length = 5

m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
m2 = deepcopy(m)
dataset = m.example_inputs(
dataset_size,
sequence_length=sequence_length,
dtype=original_dtype,
device=device,
)
calibration_data = dataset[:n_calibration_examples]

# calibrate
insert_awq_observer_(
m,
n_validation_examples,
sequence_length,
quant_dtype=quant_dtype,
group_size=group_size,
)

for example in calibration_data:
m(example.to(device))

# quantize
is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
quantize_(
m, awq_uintx(quant_dtype=quant_dtype, group_size=group_size), is_observed_linear
)

model_save_path = "awq_model.pth"
torch.save(m.state_dict(), model_save_path)
m2.load_state_dict(
torch.load(model_save_path), assign=True
) # load weights only
os.remove(model_save_path)

m = torch.compile(m, fullgraph=True)
m2 = torch.compile(m2, fullgraph=True)

awq_out = torch.cat([m(i.squeeze(0)) for i in dataset])
awq_save_load_out = torch.cat([m2(i.squeeze(0)) for i in dataset])

assert awq_out is not None
assert awq_save_load_out is not None
assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
awq_out = torch.cat([m(d.squeeze(0)) for d in dataset])
awq_save_load_out = torch.cat([loaded_model(d.squeeze(0)) for d in dataset])

assert awq_out is not None
assert awq_save_load_out is not None
assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)


if __name__ == "__main__":
run_tests()
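For reference, the loading path covered by `test_awq_loading_vllm` can be summarized as: quantize a fresh model with `AWQStep.PREPARE_FOR_LOADING` so its weights already have the AWQ-quantized layout, then copy the checkpoint tensors over in place. A condensed sketch under the same assumptions as the test (toy model; the checkpoint path is hypothetical):

```python
import torch
from torchao.prototype.awq import AWQConfig, AWQStep
from torchao.quantization import Int4WeightOnlyConfig, quantize_

# Fresh serving-side model; weights are random before loading.
model = torch.nn.Sequential(torch.nn.Linear(512, 256, bias=False))
model = model.eval().to(torch.bfloat16).to("cuda")

# PREPARE_FOR_LOADING puts the weights in the same quantized format as an
# AWQ checkpoint, so saved tensors can be copied over in place.
quantize_(
    model,
    AWQConfig(Int4WeightOnlyConfig(group_size=128), step=AWQStep.PREPARE_FOR_LOADING),
)

# The serving framework (e.g. vllm) then copies the saved quantized weights:
# state_dict = torch.load("awq_checkpoint.pt")   # hypothetical checkpoint path
# model[0].weight.copy_(state_dict["0.weight"])  # in-place copy, as in the test
```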