252 changes: 120 additions & 132 deletions test/prototype/test_awq.py
@@ -79,6 +79,7 @@ def forward(self, x):
"cpu": [Int4WeightOnlyConfig(group_size=128, int4_packing_format="opaque")],
"xpu": [Int4WeightOnlyConfig(group_size=128, int4_packing_format="plain_int32")],
}
configs = [(d, c) for d in devices for c in device_to_base_configs[d]]
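# Illustrative aside: a minimal, self-contained sketch (assumed standalone example,
# not part of this file) of how a flattened (device, base_config) list drives
# tuple-style parametrization with torch's test utilities, as the test class below
# does. Device names and config strings here are placeholders, not the real objects.
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

# Placeholder mapping; the real test maps devices to Int4WeightOnlyConfig instances.
device_to_base_configs = {
    "cpu": ["opaque"],
    "cuda": ["preshuffled", "plain"],
}
configs = [(d, c) for d in device_to_base_configs for c in device_to_base_configs[d]]


class TestSketch(TestCase):
    @parametrize("device,base_config", configs)
    def test_pairs(self, device, base_config):
        # each (device, base_config) tuple becomes one generated test case
        self.assertIn(base_config, device_to_base_configs[device])


instantiate_parametrized_tests(TestSketch)

if __name__ == "__main__":
    run_tests()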


class TestAWQ(TestCase):
@@ -95,109 +96,100 @@ def test_awq_config(self):
with self.assertRaisesRegex(ValueError, "is not one of"):
AWQConfig(base_config, step="not_supported")

@parametrize("device", devices)
def test_awq_functionality(self, device):
@parametrize("device,base_config", configs)
def test_awq_functionality(self, device, base_config):
dataset_size = 10
l1, l2, l3 = 512, 256, 128
original_dtype = torch.bfloat16 # tinygemm kernel only uses bfloat16 inputs
sequence_length = 5

assert device in device_to_base_configs, "Unsupported device: {}".format(device)
base_configs = device_to_base_configs[device]

for base_config in base_configs:
m = ToyLinearModel(l1, l2, l3, device=device, dtype=original_dtype).eval()
m_baseline = copy.deepcopy(m)

dataset = m.example_inputs(
dataset_size,
sequence_length=sequence_length,
)
# for the test, we use calibration_data = dataset so that awq is
# guaranteed to be better than baseline
# in reality, calibration_data will be a small subset or a different
# dataset
calibration_data = dataset
# concatenated inputs
input_cat = torch.cat(calibration_data, dim=-2)
ref_out = m(input_cat)

# baseline quantization
quantize_(m_baseline, base_config)

# awq quantization
quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
quantize_(m, quant_config)

for example in calibration_data:
m(example)

quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
quantize_(m, quant_config)

# evaluating on calibration data set to remove any uncertainty
awq_out = m(input_cat)
baseline_out = m_baseline(input_cat)

loss_awq = (ref_out - awq_out).pow(2).mean().item()
loss_base = (ref_out - baseline_out).pow(2).mean().item()
assert loss_awq <= loss_base

@parametrize("device", devices)
def test_awq_loading(self, device):
m = ToyLinearModel(l1, l2, l3, device=device, dtype=original_dtype).eval()
m_baseline = copy.deepcopy(m)

dataset = m.example_inputs(
dataset_size,
sequence_length=sequence_length,
)
# for the test, we use calibration_data = dataset so that awq is
# guaranteed to be better than baseline
# in reality, calibration_data will be a small subset or a different
# dataset
calibration_data = dataset
input_cat = torch.cat(calibration_data, dim=-2)
ref_out = m(input_cat)

# baseline quantization
quantize_(m_baseline, base_config)

# awq quantization
quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
quantize_(m, quant_config)

for example in calibration_data:
m(example)

quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
quantize_(m, quant_config)

# evaluating on calibration data set to remove any uncertainty
awq_out = m(input_cat)
baseline_out = m_baseline(input_cat)

loss_awq = (ref_out - awq_out).pow(2).mean().item()
loss_base = (ref_out - baseline_out).pow(2).mean().item()
assert loss_awq <= loss_base

@parametrize("device,base_config", configs)
def test_awq_loading(self, device, base_config):
dataset_size = 10
l1, l2, l3 = 512, 256, 128
original_dtype = torch.bfloat16 # tinygemm kernel only uses bfloat16 inputs
sequence_length = 5

assert device in device_to_base_configs, "Unsupported device: {}".format(device)
base_configs = device_to_base_configs[device]

for base_config in base_configs:
m = ToyLinearModel(l1, l2, l3, device=device, dtype=original_dtype).eval()
dataset = m.example_inputs(
dataset_size,
sequence_length=sequence_length,
)
# for test purposes, we don't need to take a subset
calibration_data = dataset
# concatenated inputs
input_cat = torch.cat(calibration_data, dim=-2)
m = ToyLinearModel(l1, l2, l3, device=device, dtype=original_dtype).eval()
dataset = m.example_inputs(
dataset_size,
sequence_length=sequence_length,
)
# for test purposes, we don't need to take a subset
calibration_data = dataset
# concatenated inputs
input_cat = torch.cat(calibration_data, dim=-2)

# calibrate
# calibrate

quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
quantize_(m, quant_config)
quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
quantize_(m, quant_config)

for example in calibration_data:
m(example)
for example in calibration_data:
m(example)

# quantize
quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
quantize_(m, quant_config)
# quantize
quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
quantize_(m, quant_config)

with tempfile.NamedTemporaryFile() as f:
torch.save(m.state_dict(), f)
f.seek(0)
state_dict = torch.load(f)
with tempfile.NamedTemporaryFile() as f:
torch.save(m.state_dict(), f)
f.seek(0)
state_dict = torch.load(f)

loaded_model = ToyLinearModel(
l1, l2, l3, device=device, dtype=original_dtype
).eval()
loaded_model.load_state_dict(state_dict, assign=True)
loaded_model = ToyLinearModel(
l1, l2, l3, device=device, dtype=original_dtype
).eval()
loaded_model.load_state_dict(state_dict, assign=True)

m = torch.compile(m, fullgraph=True)
loaded_model = torch.compile(loaded_model, fullgraph=True)
m = torch.compile(m, fullgraph=True)
loaded_model = torch.compile(loaded_model, fullgraph=True)

awq_out = m(input_cat)
awq_save_load_out = loaded_model(input_cat)
awq_out = m(input_cat)
awq_save_load_out = loaded_model(input_cat)

assert awq_out is not None
assert awq_save_load_out is not None
assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
assert awq_out is not None
assert awq_save_load_out is not None
assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)

@parametrize("device", devices)
def test_awq_loading_vllm(self, device):
@parametrize("device,base_config", configs)
def test_awq_loading_vllm(self, device, base_config):
"""Simulate weight loading in vllm:
* prepare model weight to the same format (awq weight)
* use weight.copy_(state_dict["weight"]) to copy over the quantized weights from checkpoint
Expand All @@ -209,55 +201,51 @@ def test_awq_loading_vllm(self, device):
original_dtype = torch.bfloat16 # tinygemm kernel only uses bfloat16 inputs
sequence_length = 5

assert device in device_to_base_configs, "Unsupported device: {}".format(device)
base_configs = device_to_base_configs[device]

for base_config in base_configs:
m = ToyLinearModel(l1, l2, l3, device=device, dtype=original_dtype).eval()
dataset = m.example_inputs(
dataset_size,
sequence_length=sequence_length,
)
# for test purposes, we don't need to take a subset
calibration_data = dataset
# concatenated inputs
input_cat = torch.cat(calibration_data, dim=-2)

# calibrate
quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
quantize_(m, quant_config)

for example in calibration_data:
m(example)

# quantize
quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
quantize_(m, quant_config)

with tempfile.NamedTemporaryFile() as f:
torch.save(m.state_dict(), f)
f.seek(0)
state_dict = torch.load(f)

loaded_model = ToyLinearModel(
l1, l2, l3, device=device, dtype=original_dtype
).eval()
quant_config = AWQConfig(base_config, step=AWQStep.PREPARE_FOR_LOADING)
quantize_(loaded_model, quant_config)

loaded_model.linear1.weight.copy_(state_dict["linear1.weight"])
loaded_model.linear2.weight.copy_(state_dict["linear2.weight"])
loaded_model.linear3.weight.copy_(state_dict["linear3.weight"])

m = torch.compile(m, fullgraph=True)
loaded_model = torch.compile(loaded_model, fullgraph=True)

awq_out = m(input_cat)
awq_save_load_out = loaded_model(input_cat)

assert awq_out is not None
assert awq_save_load_out is not None
assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
m = ToyLinearModel(l1, l2, l3, device=device, dtype=original_dtype).eval()
dataset = m.example_inputs(
dataset_size,
sequence_length=sequence_length,
)
# for test purposes, we don't need to take a subset
calibration_data = dataset
# concatenated inputs
input_cat = torch.cat(calibration_data, dim=-2)

# calibrate
quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
quantize_(m, quant_config)

for example in calibration_data:
m(example)

# quantize
quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
quantize_(m, quant_config)

with tempfile.NamedTemporaryFile() as f:
torch.save(m.state_dict(), f)
f.seek(0)
state_dict = torch.load(f)

loaded_model = ToyLinearModel(
l1, l2, l3, device=device, dtype=original_dtype
).eval()
quant_config = AWQConfig(base_config, step=AWQStep.PREPARE_FOR_LOADING)
quantize_(loaded_model, quant_config)

loaded_model.linear1.weight.copy_(state_dict["linear1.weight"])
loaded_model.linear2.weight.copy_(state_dict["linear2.weight"])
loaded_model.linear3.weight.copy_(state_dict["linear3.weight"])

m = torch.compile(m, fullgraph=True)
loaded_model = torch.compile(loaded_model, fullgraph=True)

awq_out = m(input_cat)
awq_save_load_out = loaded_model(input_cat)

assert awq_out is not None
assert awq_save_load_out is not None
assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)


instantiate_parametrized_tests(TestAWQ)