From 1bee2fb1941f2cf5ae0029b6c9657c5c81c630b1 Mon Sep 17 00:00:00 2001
From: youn17
Date: Fri, 10 Oct 2025 15:02:30 +0900
Subject: [PATCH] make awq tests support parallel execution

---
 test/prototype/test_awq.py | 252 ++++++++++++++++++-------------------
 1 file changed, 120 insertions(+), 132 deletions(-)

diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py
index 70bca35f90..ce813d8076 100644
--- a/test/prototype/test_awq.py
+++ b/test/prototype/test_awq.py
@@ -79,6 +79,7 @@ def forward(self, x):
     "cpu": [Int4WeightOnlyConfig(group_size=128, int4_packing_format="opaque")],
     "xpu": [Int4WeightOnlyConfig(group_size=128, int4_packing_format="plain_int32")],
 }
+configs = [(d, c) for d in devices for c in device_to_base_configs[d]]
 
 
 class TestAWQ(TestCase):
@@ -95,109 +96,100 @@ def test_awq_config(self):
         with self.assertRaisesRegex(ValueError, "is not one of"):
             AWQConfig(base_config, step="not_supported")
 
-    @parametrize("device", devices)
-    def test_awq_functionality(self, device):
+    @parametrize("device,base_config", configs)
+    def test_awq_functionality(self, device, base_config):
         dataset_size = 10
         l1, l2, l3 = 512, 256, 128
         original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
         sequence_length = 5
 
-        assert device in device_to_base_configs, "Unsupported device: {}".format(device)
-        base_configs = device_to_base_configs[device]
-
-        for base_config in base_configs:
-            m = ToyLinearModel(l1, l2, l3, device=device, dtype=original_dtype).eval()
-            m_baseline = copy.deepcopy(m)
-
-            dataset = m.example_inputs(
-                dataset_size,
-                sequence_length=sequence_length,
-            )
-            # for test, we use calibration_data = dataset so that awq is
-            # guranteed to be better than baseline
-            # in reality, calibration_data will be a small subset or a different
-            # dataset
-            calibration_data = dataset
-            # concatenatd inputs
-            input_cat = torch.cat(calibration_data, dim=-2)
-            ref_out = m(input_cat)
-
-            # baseline quantization
-            quantize_(m_baseline, base_config)
-
-            # awq quantization
-            quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
-            quantize_(m, quant_config)
-
-            for example in calibration_data:
-                m(example)
-
-            quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
-            quantize_(m, quant_config)
-
-            # evaluating on calibration data set to remove any uncertainty
-            awq_out = m(input_cat)
-            baseline_out = m_baseline(input_cat)
-
-            loss_awq = (ref_out - awq_out).pow(2).mean().item()
-            loss_base = (ref_out - baseline_out).pow(2).mean().item()
-            assert loss_awq <= loss_base
-
-    @parametrize("device", devices)
-    def test_awq_loading(self, device):
+        m = ToyLinearModel(l1, l2, l3, device=device, dtype=original_dtype).eval()
+        m_baseline = copy.deepcopy(m)
+
+        dataset = m.example_inputs(
+            dataset_size,
+            sequence_length=sequence_length,
+        )
+        # for test, we use calibration_data = dataset so that awq is
+        # guaranteed to be better than baseline
+        # in reality, calibration_data will be a small subset or a different
+        # dataset
+        calibration_data = dataset
+        input_cat = torch.cat(calibration_data, dim=-2)
+        ref_out = m(input_cat)
+
+        # baseline quantization
+        quantize_(m_baseline, base_config)
+
+        # awq quantization
+        quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
+        quantize_(m, quant_config)
+
+        for example in calibration_data:
+            m(example)
+
+        quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
+        quantize_(m, quant_config)
+
+        # evaluating on calibration data set to remove any uncertainty
+        awq_out = m(input_cat)
+        baseline_out = m_baseline(input_cat)
+
+        loss_awq = (ref_out - awq_out).pow(2).mean().item()
+        loss_base = (ref_out - baseline_out).pow(2).mean().item()
+        assert loss_awq <= loss_base
+
+    @parametrize("device,base_config", configs)
+    def test_awq_loading(self, device, base_config):
         dataset_size = 10
         l1, l2, l3 = 512, 256, 128
         original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
         sequence_length = 5
 
-        assert device in device_to_base_configs, "Unsupported device: {}".format(device)
-        base_configs = device_to_base_configs[device]
-
-        for base_config in base_configs:
-            m = ToyLinearModel(l1, l2, l3, device=device, dtype=original_dtype).eval()
-            dataset = m.example_inputs(
-                dataset_size,
-                sequence_length=sequence_length,
-            )
-            # for test purpose, we don't need to get a subset
-            calibration_data = dataset
-            # concatenatd inputs
-            input_cat = torch.cat(calibration_data, dim=-2)
+        m = ToyLinearModel(l1, l2, l3, device=device, dtype=original_dtype).eval()
+        dataset = m.example_inputs(
+            dataset_size,
+            sequence_length=sequence_length,
+        )
+        # for test purpose, we don't need to get a subset
+        calibration_data = dataset
+        # concatenated inputs
+        input_cat = torch.cat(calibration_data, dim=-2)
 
-            # calibrate
+        # calibrate
 
-            quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
-            quantize_(m, quant_config)
+        quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
+        quantize_(m, quant_config)
 
-            for example in calibration_data:
-                m(example)
+        for example in calibration_data:
+            m(example)
 
-            # quantize
-            quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
-            quantize_(m, quant_config)
+        # quantize
+        quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
+        quantize_(m, quant_config)
 
-            with tempfile.NamedTemporaryFile() as f:
-                torch.save(m.state_dict(), f)
-                f.seek(0)
-                state_dict = torch.load(f)
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(m.state_dict(), f)
+            f.seek(0)
+            state_dict = torch.load(f)
 
-            loaded_model = ToyLinearModel(
-                l1, l2, l3, device=device, dtype=original_dtype
-            ).eval()
-            loaded_model.load_state_dict(state_dict, assign=True)
+        loaded_model = ToyLinearModel(
+            l1, l2, l3, device=device, dtype=original_dtype
+        ).eval()
+        loaded_model.load_state_dict(state_dict, assign=True)
 
-            m = torch.compile(m, fullgraph=True)
-            loaded_model = torch.compile(loaded_model, fullgraph=True)
+        m = torch.compile(m, fullgraph=True)
+        loaded_model = torch.compile(loaded_model, fullgraph=True)
 
-            awq_out = m(input_cat)
-            awq_save_load_out = loaded_model(input_cat)
+        awq_out = m(input_cat)
+        awq_save_load_out = loaded_model(input_cat)
 
-            assert awq_out is not None
-            assert awq_save_load_out is not None
-            assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
+        assert awq_out is not None
+        assert awq_save_load_out is not None
+        assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
 
-    @parametrize("device", devices)
-    def test_awq_loading_vllm(self, device):
+    @parametrize("device,base_config", configs)
+    def test_awq_loading_vllm(self, device, base_config):
         """Simulate weight loading in vllm:
         * prepare model weight to the same format (awq weight)
         * use weight.copy_(state_dict["weight"]) to copy over the quantized weights from checkpoint
@@ -209,55 +201,51 @@ def test_awq_loading_vllm(self, device):
         original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
         sequence_length = 5
 
-        assert device in device_to_base_configs, "Unsupported device: {}".format(device)
-        base_configs = device_to_base_configs[device]
-
-        for base_config in base_configs:
-            m = ToyLinearModel(l1, l2, l3, device=device, dtype=original_dtype).eval()
-            dataset = m.example_inputs(
-                dataset_size,
-                sequence_length=sequence_length,
-            )
-            # for test purpose, we don't need to get a subset
-            calibration_data = dataset
-            # concatenatd inputs
-            input_cat = torch.cat(calibration_data, dim=-2)
-
-            # calibrate
-            quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
-            quantize_(m, quant_config)
-
-            for example in calibration_data:
-                m(example)
-
-            # quantize
-            quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
-            quantize_(m, quant_config)
-
-            with tempfile.NamedTemporaryFile() as f:
-                torch.save(m.state_dict(), f)
-                f.seek(0)
-                state_dict = torch.load(f)
-
-            loaded_model = ToyLinearModel(
-                l1, l2, l3, device=device, dtype=original_dtype
-            ).eval()
-            quant_config = AWQConfig(base_config, step=AWQStep.PREPARE_FOR_LOADING)
-            quantize_(loaded_model, quant_config)
-
-            loaded_model.linear1.weight.copy_(state_dict["linear1.weight"])
-            loaded_model.linear2.weight.copy_(state_dict["linear2.weight"])
-            loaded_model.linear3.weight.copy_(state_dict["linear3.weight"])
-
-            m = torch.compile(m, fullgraph=True)
-            loaded_model = torch.compile(loaded_model, fullgraph=True)
-
-            awq_out = m(input_cat)
-            awq_save_load_out = loaded_model(input_cat)
-
-            assert awq_out is not None
-            assert awq_save_load_out is not None
-            assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
+        m = ToyLinearModel(l1, l2, l3, device=device, dtype=original_dtype).eval()
+        dataset = m.example_inputs(
+            dataset_size,
+            sequence_length=sequence_length,
+        )
+        # for test purpose, we don't need to get a subset
+        calibration_data = dataset
+        # concatenated inputs
+        input_cat = torch.cat(calibration_data, dim=-2)
+
+        # calibrate
+        quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
+        quantize_(m, quant_config)
+
+        for example in calibration_data:
+            m(example)
+
+        # quantize
+        quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
+        quantize_(m, quant_config)
+
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(m.state_dict(), f)
+            f.seek(0)
+            state_dict = torch.load(f)
+
+        loaded_model = ToyLinearModel(
+            l1, l2, l3, device=device, dtype=original_dtype
+        ).eval()
+        quant_config = AWQConfig(base_config, step=AWQStep.PREPARE_FOR_LOADING)
+        quantize_(loaded_model, quant_config)
+
+        loaded_model.linear1.weight.copy_(state_dict["linear1.weight"])
+        loaded_model.linear2.weight.copy_(state_dict["linear2.weight"])
+        loaded_model.linear3.weight.copy_(state_dict["linear3.weight"])
+
+        m = torch.compile(m, fullgraph=True)
+        loaded_model = torch.compile(loaded_model, fullgraph=True)
+
+        awq_out = m(input_cat)
+        awq_save_load_out = loaded_model(input_cat)
+
+        assert awq_out is not None
+        assert awq_save_load_out is not None
+        assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
 
 
 instantiate_parametrized_tests(TestAWQ)
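
Note (not part of the patch): below is a minimal, runnable sketch of the
parametrization pattern this commit adopts. torch's `parametrize` expands each
`(device, base_config)` tuple from the module-level `configs` list into its own
named test case, so a parallel runner such as pytest-xdist (`pytest -n auto`)
can schedule every device/config pair independently, instead of serializing a
`for base_config in base_configs:` loop inside one test. `DummyConfig`,
`TestParametrization`, and `test_pairs` are hypothetical stand-ins, not names
from test_awq.py.

    # Sketch only: mirrors the flattened-configs pattern from the patch.
    from dataclasses import dataclass

    from torch.testing._internal.common_utils import (
        TestCase,
        instantiate_parametrized_tests,
        parametrize,
        run_tests,
    )


    @dataclass(frozen=True)
    class DummyConfig:  # hypothetical stand-in for Int4WeightOnlyConfig
        group_size: int


    devices = ["cpu"]
    device_to_base_configs = {"cpu": [DummyConfig(64), DummyConfig(128)]}

    # One flat list of (device, config) pairs; each tuple becomes a
    # separately named test that a parallel runner can schedule on its own.
    configs = [(d, c) for d in devices for c in device_to_base_configs[d]]


    class TestParametrization(TestCase):
        @parametrize("device,base_config", configs)
        def test_pairs(self, device, base_config):
            # Each generated test sees exactly one (device, config) pair.
            self.assertIn(device, device_to_base_configs)
            self.assertIn(base_config, device_to_base_configs[device])


    instantiate_parametrized_tests(TestParametrization)

    if __name__ == "__main__":
        run_tests()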