diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index f523cb091c..280c50ad8f 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -213,6 +213,9 @@ def forward(self, x):
 
 class TestQAT(TestCase):
     SEED = 123
+    GPU_DEVICES = (["cuda"] if torch.cuda.is_available() else []) + (
+        ["xpu"] if torch.xpu.is_available() else []
+    )
 
     def test_fake_quantize_per_channel_group(self):
         n_bit = 4
@@ -306,6 +309,7 @@ def _set_ptq_weight(
         self,
         ptq_linear: torch.nn.Module,
         qat_linear: torch.nn.Module,
+        device="cuda"
     ):
         """
         Set the weight to the quantized version of the given fp32 weights,
@@ -341,13 +345,14 @@ def _set_ptq_weight(
             ptq_linear.zeros = zp
         elif isinstance(ptq_linear, WeightOnlyInt4Linear):
             assert isinstance(qat_linear, Int4WeightOnlyQATLinear)
+            assert device == torch.device("cuda") or device == torch.device("xpu"), "Device must be either CUDA or XPU"
             (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor(
                 qat_linear.weight,
                 n_bit,
                 group_size,
             )
             q_weight = torch.ops.aten._convert_weight_to_int4pack(
-                q_weight.to("cuda"),
+                q_weight.to(device),
                 qat_linear.inner_k_tiles,
             )
             ptq_linear.weight = q_weight
@@ -600,91 +605,93 @@ def _assert_close_4w(self, val, ref):
         print(mean_err)
         self.assertTrue(mean_err < 0.05)
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(len(GPU_DEVICES) == 0, "skipping when GPU is not available")
     def test_qat_4w_primitives(self):
-        n_bit = 4
-        group_size = 32
-        inner_k_tiles = 8
-        scales_precision = torch.bfloat16
-        device = torch.device("cuda")
-        dtype = torch.bfloat16
-        torch.manual_seed(self.SEED)
-        x = torch.randn(100, 256, dtype=dtype, device=device)
-        weight = torch.randn(512, 256, dtype=dtype, device=device)
-
-        # PTQ
-        (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor(
-            weight,
-            n_bit,
-            group_size,
-            scales_precision,
-        )
-        q_weight = torch.ops.aten._convert_weight_to_int4pack(
-            q_weight.to(device),
-            inner_k_tiles,
-        )
-        ptq_out = torch.ops.aten._weight_int4pack_mm(
-            x, q_weight, group_size, scales_and_zeros
-        )
+        for device in self.GPU_DEVICES:
+            n_bit = 4
+            group_size = 32
+            inner_k_tiles = 8
+            scales_precision = torch.bfloat16
+            device = torch.device(device)
+            dtype = torch.bfloat16
+            torch.manual_seed(self.SEED)
+            x = torch.randn(100, 256, dtype=dtype, device=device)
+            weight = torch.randn(512, 256, dtype=dtype, device=device)
+
+            # PTQ
+            (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor(
+                weight,
+                n_bit,
+                group_size,
+                scales_precision,
+            )
+            q_weight = torch.ops.aten._convert_weight_to_int4pack(
+                q_weight.to(device),
+                inner_k_tiles,
+            )
+            ptq_out = torch.ops.aten._weight_int4pack_mm(
+                x, q_weight, group_size, scales_and_zeros
+            )
 
-        # QAT
-        block_size = (1, group_size)
-        quant_min = 0
-        quant_max = 2**n_bit - 1
-        scales, zero_points = get_groupwise_affine_qparams(
-            weight,
-            n_bit,
-            group_size,
-            scales_precision,
-        )
-        w_fq = _fake_quantize_affine(
-            weight,
-            block_size,
-            scales,
-            zero_points,
-            torch.int32,
-            quant_min,
-            quant_max,
-            zero_point_domain=ZeroPointDomain.FLOAT,
-        )
-        qat_out = torch.nn.functional.linear(x, w_fq)
+            # QAT
+            block_size = (1, group_size)
+            quant_min = 0
+            quant_max = 2**n_bit - 1
+            scales, zero_points = get_groupwise_affine_qparams(
+                weight,
+                n_bit,
+                group_size,
+                scales_precision,
+            )
+            w_fq = _fake_quantize_affine(
+                weight,
+                block_size,
+                scales,
+                zero_points,
+                torch.int32,
+                quant_min,
+                quant_max,
+                zero_point_domain=ZeroPointDomain.FLOAT,
+            )
+            qat_out = torch.nn.functional.linear(x, w_fq)
 
-        self._assert_close_4w(qat_out, ptq_out)
+            self._assert_close_4w(qat_out, ptq_out)
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(len(GPU_DEVICES) == 0, "skipping when GPU is not available")
     def test_qat_4w_linear(self):
-        from torchao.quantization.GPTQ import WeightOnlyInt4Linear
-        from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear
+        for device in self.GPU_DEVICES:
+            from torchao.quantization.GPTQ import WeightOnlyInt4Linear
+            from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear
+
+            group_size = 128
+            device = torch.device(device)
+            dtype = torch.bfloat16
+            torch.manual_seed(self.SEED)
+            qat_linear = Int4WeightOnlyQATLinear(
+                256,
+                688,
+                bias=False,
+                groupsize=group_size,
+                device=device,
+            )
+            ptq_linear = WeightOnlyInt4Linear(
+                256,
+                688,
+                bias=False,
+                groupsize=group_size,
+                device=device,
+            )
 
-        group_size = 128
-        device = torch.device("cuda")
-        dtype = torch.bfloat16
-        torch.manual_seed(self.SEED)
-        qat_linear = Int4WeightOnlyQATLinear(
-            256,
-            688,
-            bias=False,
-            groupsize=group_size,
-            device=device,
-        )
-        ptq_linear = WeightOnlyInt4Linear(
-            256,
-            688,
-            bias=False,
-            groupsize=group_size,
-            device=device,
-        )
+            # Force the weights to be the same
+            self._set_ptq_weight(ptq_linear, qat_linear, device=device)
 
-        # Force the weights to be the same
-        self._set_ptq_weight(ptq_linear, qat_linear)
-
-        # Compare linear values
-        torch.manual_seed(self.SEED)
-        x = torch.randn(100, 256, dtype=dtype, device=device)
-        x2 = copy.deepcopy(x)
-        qat_out = qat_linear(x)
-        ptq_out = ptq_linear(x2)
-        self._assert_close_4w(qat_out, ptq_out)
+            # Compare linear values
+            torch.manual_seed(self.SEED)
+            x = torch.randn(100, 256, dtype=dtype, device=device)
+            x2 = copy.deepcopy(x)
+            qat_out = qat_linear(x)
+            ptq_out = ptq_linear(x2)
+            self._assert_close_4w(qat_out, ptq_out)
 
     def test_qat_4w_quantizer_gradients(self):
         from torchao.quantization.qat import Int4WeightOnlyQATQuantizer
@@ -692,50 +699,56 @@ def test_qat_4w_quantizer_gradients(self):
         quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8)
         self._test_qat_quantized_gradients(quantizer)
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(len(GPU_DEVICES) == 0, "skipping when GPU is not available")
     def test_qat_4w_quantizer(self):
-        from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
-        from torchao.quantization.qat import Int4WeightOnlyQATQuantizer
-
-        group_size = 32
-        inner_k_tiles = 8
-        device = torch.device("cuda")
-        dtype = torch.bfloat16
-        torch.manual_seed(self.SEED)
-        m = M().to(device).to(dtype)
-        m2 = copy.deepcopy(m)
-        qat_quantizer = Int4WeightOnlyQATQuantizer(
-            groupsize=group_size,
-            inner_k_tiles=inner_k_tiles,
-        )
-        ptq_quantizer = Int4WeightOnlyQuantizer(
-            groupsize=group_size,
-            inner_k_tiles=inner_k_tiles,
-        )
-        qat_model = qat_quantizer.prepare(m)
-        ptq_model = ptq_quantizer.quantize(m2)
-
-        # Compare model values
-        torch.manual_seed(self.SEED)
-        x = [i.to(device).to(dtype) for i in m.example_inputs()]
-        x2 = copy.deepcopy(x)
-        qat_out = qat_model(*x)
-        ptq_out = ptq_model(*x2)
-        self._assert_close_4w(qat_out, ptq_out)
-
-        # Convert QAT model and compare model values
-        converted_model = qat_quantizer.convert(qat_model)
-        converted_out = converted_model(*x)
-        torch.testing.assert_close(converted_out, ptq_out, atol=0, rtol=0)
-
-        # Compare converted state dict
-        ptq_state_dict = ptq_model.state_dict()
-        converted_state_dict = converted_model.state_dict()
-        self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys())
-        for k in ptq_state_dict.keys():
-            torch.testing.assert_close(
-                ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0
+        for device in self.GPU_DEVICES:
+            if device == "xpu":
+                self.skipTest("Skipped due to https://github.com/intel/torch-xpu-ops/issues/1770")
+            from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
+            from torchao.quantization.qat import Int4WeightOnlyQATQuantizer
+
+            group_size = 32
+            inner_k_tiles = 8
+            dtype = torch.bfloat16
+            device = torch.device(device)
+            torch.manual_seed(self.SEED)
+            m = M().to(device).to(dtype)
+            print(next(m.parameters()).device)
+            m2 = copy.deepcopy(m)
+            print(next(m2.parameters()).device)
+            qat_quantizer = Int4WeightOnlyQATQuantizer(
+                groupsize=group_size,
+                inner_k_tiles=inner_k_tiles,
+            )
+            ptq_quantizer = Int4WeightOnlyQuantizer(
+                groupsize=group_size,
+                inner_k_tiles=inner_k_tiles,
+                device=device
             )
+            qat_model = qat_quantizer.prepare(m)
+            ptq_model = ptq_quantizer.quantize(m2)
+
+            # Compare model values
+            torch.manual_seed(self.SEED)
+            x = [i.to(device).to(dtype) for i in m.example_inputs()]
+            x2 = copy.deepcopy(x)
+            qat_out = qat_model(*x)
+            ptq_out = ptq_model(*x2)
+            self._assert_close_4w(qat_out, ptq_out)
+
+            # Convert QAT model and compare model values
+            converted_model = qat_quantizer.convert(qat_model)
+            converted_out = converted_model(*x)
+            torch.testing.assert_close(converted_out, ptq_out, atol=0, rtol=0)
+
+            # Compare converted state dict
+            ptq_state_dict = ptq_model.state_dict()
+            converted_state_dict = converted_model.state_dict()
+            self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys())
+            for k in ptq_state_dict.keys():
+                torch.testing.assert_close(
+                    ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0
+                )
 
     class _MyQATQuantizer(TwoStepQuantizer):
         """
@@ -1877,6 +1890,7 @@ def _test_quantize_api_against_ptq(
         target_convert_sqnr: float,
         dtype: torch.dtype = torch.bfloat16,
         module_type: str = "linear",
+        device="cuda"
     ):
         """
         Test the following:
@@ -1888,15 +1902,16 @@ def _test_quantize_api_against_ptq(
 
             quantize_(model, base_config)
         """
+        assert device == "cuda" or device == "xpu", "Device must be either CUDA or XPU"
 
         torch.manual_seed(self.SEED)
         if module_type == "linear":
-            m = M().to(dtype).cuda()
-            example_inputs = (m.example_inputs()[0].to(dtype).cuda(),)
+            m = M().to(dtype).to(device)
+            example_inputs = (m.example_inputs()[0].to(dtype).to(device),)
             filter_fn = lambda m, fqn: isinstance(m, torch.nn.Linear)
         elif module_type == "embedding":
-            m = M3().to(dtype).cuda()
-            example_inputs = (m.example_inputs()[0].cuda(),)
+            m = M3().to(dtype).to(device)
+            example_inputs = (m.example_inputs()[0].to(device),)
             filter_fn = lambda m, fqn: isinstance(m, torch.nn.Embedding)
         else:
             raise ValueError(f"Unknown module type {module_type}")
@@ -1971,20 +1986,22 @@ def test_quantize_api_int4(self, version: int, packing_format: Int4PackingFormat
             target_convert_sqnr=float("inf"),
         )
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(len(GPU_DEVICES) == 0, "skipping when GPU is not available")
     def test_quantize_api_int8_int4(self):
         """
         Test the following:
             quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="prepare"))
            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="convert"))
         """
-        self._test_quantize_api_against_ptq(
-            Int8DynamicActivationInt4WeightConfig(group_size=32),
-            target_prepare_sqnr=30,
-            target_convert_sqnr=float("inf"),
-        )
+        for device in self.GPU_DEVICES:
+            self._test_quantize_api_against_ptq(
+                Int8DynamicActivationInt4WeightConfig(group_size=32),
+                target_prepare_sqnr=30,
+                target_convert_sqnr=float("inf"),
+                device=device
+            )
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(len(GPU_DEVICES) == 0, "skipping when GPU is not available")
     @parametrize(
         "weight_dtype, weight_granularity, dtype",
         [
@@ -2000,16 +2017,18 @@ def test_quantize_api_int8_intx(self, weight_dtype, weight_granularity, dtype):
            quantize_(model, QATConfig(Int8DynamicActivationIntxWeightConfig(), step="prepare"))
            quantize_(model, QATConfig(Int8DynamicActivationIntxWeightConfig(), step="convert"))
         """
-        self._test_quantize_api_against_ptq(
-            Int8DynamicActivationIntxWeightConfig(
-                weight_dtype=weight_dtype, weight_granularity=weight_granularity
-            ),
-            target_prepare_sqnr=float("inf"),
-            target_convert_sqnr=float("inf"),
-            dtype=dtype,
-        )
+        for device in self.GPU_DEVICES:
+            self._test_quantize_api_against_ptq(
+                Int8DynamicActivationIntxWeightConfig(
+                    weight_dtype=weight_dtype, weight_granularity=weight_granularity
+                ),
+                target_prepare_sqnr=float("inf"),
+                target_convert_sqnr=float("inf"),
+                dtype=dtype,
+                device=device
+            )
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(len(GPU_DEVICES) == 0, "skipping when GPU is not available")
    @parametrize(
         "weight_dtype, granularity, dtype, module_type",
         [
@@ -2026,13 +2045,15 @@ def test_quantize_api_intx(self, weight_dtype, granularity, dtype, module_type):
            quantize_(model, QATConfig(IntxWeightOnlyConfig(), step="prepare"))
            quantize_(model, QATConfig(IntxWeightOnlyConfig(), step="convert"))
         """
-        self._test_quantize_api_against_ptq(
-            IntxWeightOnlyConfig(weight_dtype=weight_dtype, granularity=granularity),
-            target_prepare_sqnr=float("inf"),
-            target_convert_sqnr=float("inf"),
-            dtype=dtype,
-            module_type=module_type,
-        )
+        for device in self.GPU_DEVICES:
+            self._test_quantize_api_against_ptq(
+                IntxWeightOnlyConfig(weight_dtype=weight_dtype, granularity=granularity),
+                target_prepare_sqnr=float("inf"),
+                target_convert_sqnr=float("inf"),
+                dtype=dtype,
+                module_type=module_type,
+                device=device
+            )
 
     def test_infer_fp8_int4_config(self):
         """