diff --git a/backends/xnnpack/test/ops/test_linear.py b/backends/xnnpack/test/ops/test_linear.py
index 401ba3de0d1..393b831c2a4 100644
--- a/backends/xnnpack/test/ops/test_linear.py
+++ b/backends/xnnpack/test/ops/test_linear.py
@@ -66,8 +66,14 @@ def __init__(
     def forward(self, x):
         return self.linear(x)
 
-    def get_inputs(self):
-        return (torch.randn(1, self.in_size, self.ic).to(self.op_dtype),)
+    def get_inputs(self, rank=3):
+        # rank = 3 as default to inflate the act rank by 1 in batch dim
+        # This is to make sure we don't specialize on 2D shapes.
+        inp = torch.randn(self.in_size, self.ic).to(self.op_dtype)
+        for _ in range(rank - 2):
+            inp = inp.unsqueeze(0)
+        assert inp.ndim == rank
+        return (inp,)
 
 
 class AddMMModule(torch.nn.Module):
@@ -165,26 +171,19 @@ def forward(self, x):
     def get_inputs(self):
         return (torch.rand(self.in_size, self.input_size, dtype=torch.float),)
 
+
 class ParallelLinear(torch.nn.Module):
     def __init__(self, input_size, output_size):
         super().__init__()
-        self.linear1_weight = torch.nn.Parameter(
-            torch.rand(output_size, input_size)
-        )
+        self.linear1_weight = torch.nn.Parameter(torch.rand(output_size, input_size))
         self.linear1_bias = torch.nn.Parameter(torch.rand(output_size))
-        self.linear2_weight = torch.nn.Parameter(
-            torch.rand(output_size, input_size)
-        )
+        self.linear2_weight = torch.nn.Parameter(torch.rand(output_size, input_size))
         self.linear2_bias = torch.nn.Parameter(torch.rand(output_size))
 
     def forward(self, x, y):
-        a = torch.nn.functional.linear(
-            x, self.linear1_weight, self.linear1_bias
-        )
-        b = torch.nn.functional.linear(
-            y, self.linear2_weight, self.linear2_bias
-        )
+        a = torch.nn.functional.linear(x, self.linear1_weight, self.linear1_bias)
+        b = torch.nn.functional.linear(y, self.linear2_weight, self.linear2_bias)
         return a + b
 
 
@@ -219,7 +218,7 @@ def _test_linear(
         num_batch_dims=1,
         quant_type=None,
         dtype: torch.dtype = torch.float,
-        atol=1e-03,
+        atol=1e-03,  # TODO(T212995726): Investigate right atol for rand[n] inputs
     ):
         """
         Helper function to test linear op with different configurations.
@@ -312,7 +311,7 @@ def _test_dqlinear(
         is_per_channel=False,
         uses_bias=False,
         qconfig: Optional[QuantizationConfig] = None,
-        atol=5e-02,
+        atol=5e-02,  # TODO(T212995726): Investigate right atol for rand[n] inputs
     ):
         """
         Helper function to test dynamic quantized linear op with different configurations.
@@ -364,8 +363,8 @@ def _test_groupwise_dq_linear(
         use_bias: bool = False,
         group_size: int = 8,
         num_linears: int = 1,
-        atol: float = 5e-3,
-        rtol: float = 5e-3,
+        atol: float = 5e-3,  # TODO(T212995726): Investigate right atol for rand[n] inputs
+        rtol: float = 5e-3,  # TODO(T212995726): Investigate right rtol for rand[n] inputs
     ):
         """
         Helper function to test groupwise dynamic quantized linear op with different configurations.
@@ -420,7 +419,7 @@ def _test_linear_overwrite_precision(
         uses_bias: bool,
         quant_type: str,
         quant_node_checks: List[Dict[str, int]],
-        atol: float = 1e-03,
+        atol: float = 1e-03,  # TODO(T212995726): Investigate right atol for rand[n] inputs
    ):
        """
        This test is to test the overwrite precision of linear op.
@@ -517,6 +516,96 @@ def get_qnode_checks(quant_node_checks, dialect):
         # qtol=bool(quant_config), atol=atol
         # )
 
+    def _test_qd8_per_channel_linear(self, dtype: torch.dtype = torch.float):
+        for uses_bias in (False, True):
+            module = BaseLinear(
+                in_size=8,
+                input_channels=13,
+                output_channels=17,
+                dtype=dtype,
+                use_bias=uses_bias,
+            )
+            inputs = module.get_inputs()
+
+            self._test_dqlinear(
+                module,
+                inputs,
+                dynamic_shapes=({1: torch.export.Dim("batch", max=100)},),
+                is_per_channel=True,
+                uses_bias=uses_bias,
+            )
+
+    def _test_qd8_per_channel_4w_linear(self, dtype: torch.dtype = torch.float):
+        qconfig = self._get_4b_dqconfig()
+        input_channels = [2, 63]
+        output_channels = [1, 127]
+        batches = [
+            2,
+        ]
+        use_bias = [False, True]
+        dtypes = [
+            dtype,
+        ]
+
+        for bs, bias, ipc, opc, dtype in product(
+            batches,
+            use_bias,
+            input_channels,
+            output_channels,
+            dtypes,
+        ):
+            module = BaseLinear(
+                in_size=bs,
+                input_channels=ipc,
+                output_channels=opc,
+                dtype=dtype,
+                use_bias=bias,
+            )
+            inputs = module.get_inputs()
+
+            self._test_dqlinear(
+                module,
+                inputs,
+                dynamic_shapes=({1: torch.export.Dim("batch", max=100)},),
+                is_per_channel=True,
+                uses_bias=bias,
+                qconfig=qconfig,
+                atol=5e-2,  # TODO(T212995726): Investigate right atol for rand[n] inputs
+            )
+
+    def _test_qd8_per_token_weight_per_channel_group_int4(
+        self, dtype: torch.dtype = torch.float
+    ):
+        M_sizes = [1, 2, 17, 31]
+        K_sizes = [32, 32, 64, 128]
+        bl_sizes = [32, 32, 32, 64]
+        N_sizes = [2, 17, 92, 128]
+
+        for use_bias in [True, False]:
+            for M, K, bl, N in zip(M_sizes, K_sizes, bl_sizes, N_sizes):
+                lin_mod = BaseLinear(
+                    in_size=M,
+                    input_channels=K,
+                    output_channels=N,
+                    dtype=dtype,
+                    use_bias=use_bias,
+                )
+
+                inputs = lin_mod.get_inputs()
+                # Half requires slightly higher atol, but if you look at error it is not that bad:
+                # Difference: max: 0.00140380859375, abs: 0.00140380859375, mean abs error: 0.00042724609375.
+                # -- Model vs. Reference --
+                # Numel: 4, 4
+                # Median: -0.05023193359375, -0.0516357421875
+                # Mean: 0.2373046875, 0.237060546875
+                # Max: 1.0078125, 1.0078125
+                # Min: -0.08465576171875, -0.08441162109375
+                atol = (
+                    1e-2 if dtype == torch.half else 5e-3
+                )  # TODO(T212995726): Investigate right atol for rand[n] inputs
+                self._test_groupwise_dq_linear(
+                    lin_mod, inputs, group_size=bl, use_bias=use_bias, atol=atol
+                )
 
     def test_fp16_linear(self):
         for use_bias in (True, False):
@@ -528,7 +617,7 @@ def test_fp16_linear(self):
                     num_batch_dims=num_batch_dims,
                     uses_bias=use_bias,
                     dtype=torch.float16,
-                    atol=5e-2,
+                    atol=5e-2,  # TODO(T212995726): Investigate right atol for rand[n] inputs
                 )
 
     def test_fp32_linear(self):
@@ -600,84 +689,109 @@ def test_qs8_linear(self):
                     quant_type="per_tensor",
                 )
 
-    def test_qd8_per_channel_linear(self):
-        for uses_bias in (False, True):
-            inputs = (torch.randn(2, 4),)
-            module = torch.nn.Linear(4, 5, bias=uses_bias)
+    # Tests for q[dp]8-f16-qc8w
+    def test_qd8_f16_per_channel_linear(self):
+        self._test_qd8_per_channel_linear(dtype=torch.half)
 
-            self._test_dqlinear(
-                module,
-                inputs,
-                dynamic_shapes=({0: torch.export.Dim("batch", max=100)},),
-                is_per_channel=True,
-                uses_bias=uses_bias,
-            )
+    # Tests for q[dp]8-f32-qc8w
+    def test_qd8_f32_per_channel_linear(self):
+        self._test_qd8_per_channel_linear(dtype=torch.float)
 
+    # Tests for q[dp]8-f16-qc4w
+    def test_linear_qd8_f16_per_channel_int4(self):
+        self._test_qd8_per_channel_4w_linear(dtype=torch.half)
 
-    def test_qd8_per_channel_4w_linear(self):
-        qconfig = self._get_4b_dqconfig()
-        input_channels = [2, 63]
-        output_channels = [1, 8, 127]
-        batches = [2, 2]
-        use_bias = [False, True]
+    # Tests for q[dp]8-f32-qc4w
+    def test_linear_qd8_f32_per_channel_int4(self):
+        self._test_qd8_per_channel_4w_linear(dtype=torch.float)
 
-        for bs, bias, ipc, opc in product(
-            batches,
-            use_bias,
-            input_channels,
-            output_channels,
-        ):
-            inputs = (torch.rand(bs, ipc),)
-            module = torch.nn.Linear(ipc, opc, bias=bias)
+    # Tests for q[dp]8-f16-qb4w
+    @unittest.skipIf(
+        not torchao_installed, "Per Channel Group Quantization Required TorchAO"
+    )
+    def test_linear_qd8_f16_per_token_weight_per_channel_group_int4(self):
+        self._test_qd8_per_token_weight_per_channel_group_int4(dtype=torch.half)
 
-            self._test_dqlinear(
-                module,
-                inputs,
-                dynamic_shapes=({0: torch.export.Dim("batch", max=100)},),
-                is_per_channel=True,
-                uses_bias=bias,
-                qconfig=qconfig,
+    # Tests for q[dp]8-f32-qb4w
+    @unittest.skipIf(
+        not torchao_installed, "Per Channel Group Quantization Required TorchAO"
+    )
+    def test_linear_qd8_f32_per_token_weight_per_channel_group_int4(self):
+        self._test_qd8_per_token_weight_per_channel_group_int4(dtype=torch.float)
+
+    @unittest.skipIf(
+        not torchao_installed, "Per Channel Group Quantization Required TorchAO"
+    )
+    def test_linear_qd8_per_token_groupwise_unsupported_groupsize(self):
+        # groupsize must be multiple of 32
+        for dtype in [torch.float, torch.half]:
+            lin_mod = BaseLinear(
+                in_size=1,
+                input_channels=60,
+                output_channels=60,
+                dtype=dtype,
+                use_bias=True,
             )
+            inputs = lin_mod.get_inputs()
+
+            with self.assertRaisesRegex(
+                AssertionError,
+                "Delegation to XNNPACK requires group_size to be a multiple of 32, but got 30",
+            ):
+                self._test_groupwise_dq_linear(
+                    lin_mod, inputs, group_size=30, use_bias=False, atol=1e-2
+                )
 
     def test_qd8_per_channel_linear_parallel(self):
         in_size = 2
         input_size = 4
         output_size = 5
+        for dtype in torch.float, torch.half:
+            inputs = (
+                torch.rand(in_size, input_size, dtype=dtype),
+                torch.rand(in_size, input_size, dtype=dtype),
+            )
+            batch_dim = torch.export.Dim("batch", max=100)
+            dynamic_shapes = ({0: batch_dim}, {0: batch_dim})
 
-        inputs = (
-            torch.rand(in_size, input_size, dtype=torch.float),
-            torch.rand(in_size, input_size, dtype=torch.float),
-        )
-        batch_dim = torch.export.Dim("batch", max=100)
-        dynamic_shapes = ({0: batch_dim}, {0: batch_dim})
-
-        self._test_dqlinear(
-            ParallelLinear(input_size=input_size, output_size=output_size),
-            inputs,
-            dynamic_shapes=dynamic_shapes,
-            linear_count=2,
-            is_per_channel=True,
-            uses_bias=True,
-        )
+            self._test_dqlinear(
+                ParallelLinear(input_size=input_size, output_size=output_size).to(
+                    dtype
+                ),
+                inputs,
+                dynamic_shapes=dynamic_shapes,
+                linear_count=2,
+                is_per_channel=True,
+                uses_bias=True,
+            )
 
     def test_qd8_per_channel_linear_with_two_batch(self):
         in_size = 2
-        input_size = 4
-        output_size = 5
-
-        linear = torch.nn.Linear(input_size, output_size)
-        inputs = (torch.randn(2, in_size, input_size, dtype=torch.float),)
-        batch_dim = torch.export.Dim("batch", max=100)
-        dynamic_shapes = ({0: batch_dim, 1: batch_dim},)
-
-        self._test_dqlinear(
-            linear,
-            inputs,
-            dynamic_shapes=dynamic_shapes,
-            linear_count=1,
-            is_per_channel=True,
-            uses_bias=True,
-        )
+        input_size = 14
+        output_size = 15
+
+        for dtype in torch.float, torch.half:
+            for use_bias in (False, True):
+                linear = BaseLinear(
+                    in_size=in_size,
+                    input_channels=input_size,
+                    output_channels=output_size,
+                    dtype=dtype,
+                    use_bias=use_bias,
+                )
+                # Create inputs with two batch dimensions, i.e. 3D activation
+                inputs = (torch.randn(in_size, in_size, input_size).to(dtype),)
+                batch_dim = torch.export.Dim("batch", max=100)
+                dynamic_shapes = ({0: batch_dim, 1: batch_dim},)
+
+                self._test_dqlinear(
+                    linear,
+                    inputs,
+                    dynamic_shapes=dynamic_shapes,
+                    linear_count=1,
+                    is_per_channel=True,
+                    uses_bias=True,
+                )
 
     def test_qd8_per_channel_linear_sequential(self):
         lin_mod = LinearSequential()
@@ -691,7 +805,7 @@ def test_qd8_per_channel_linear_sequential(self):
             linear_count=2,
             is_per_channel=True,
             uses_bias=True,
-            atol=1e-1,
+            atol=1e-1,  # TODO(T212995726): Investigate right atol for rand[n] inputs
         )
 
     def test_qd8_per_channel_linear_parallel_and_sequential(self):
@@ -709,87 +823,10 @@ def test_qd8_per_channel_linear_parallel_and_sequential(self):
             linear_count=3,
             is_per_channel=True,
             uses_bias=True,
-            atol=1e-1,
+            atol=1e-1,  # TODO(T212995726): Investigate right atol for rand[n] inputs
         )
 
-    @unittest.skipIf(
-        not torchao_installed, "Per Channel Group Quantization Required TorchAO"
-    )
-    def test_qd8_fp32_per_token_weight_per_channel_group_int4(self):
-        M_sizes = [1, 2, 17, 31]
-        K_sizes = [32, 32, 64, 128]
-        bl_sizes = [32, 32, 32, 64]
-        N_sizes = [2, 17, 92, 128]
-
-        for use_bias in [True, False]:
-            for M, K, bl, N in zip(M_sizes, K_sizes, bl_sizes, N_sizes):
-                lin_mod = BaseLinear(
-                    input_channels=K,
-                    output_channels=N,
-                    dtype=torch.float,
-                    use_bias=use_bias,
-                )
-
-                inputs = (torch.randn(1, M, K),)
-                self._test_groupwise_dq_linear(
-                    lin_mod, inputs, group_size=bl, use_bias=use_bias
-                )
-
-    @unittest.skipIf(
-        not torchao_installed, "Per Channel Group Quantization Required TorchAO"
-    )
-    def test_qd8_fp16_per_token_weight_per_channel_group_int4(self):
-        M_sizes = [1, 2, 17, 31]
-        K_sizes = [32, 32, 64, 128]
-        bl_sizes = [32, 32, 32, 64]
-        N_sizes = [2, 17, 92, 128]
-
-        for use_bias in [True, False]:
-            for M, K, bl, N in zip(M_sizes, K_sizes, bl_sizes, N_sizes):
-                lin_mod = BaseLinear(
-                    in_size=M,
-                    input_channels=K,
-                    output_channels=N,
-                    dtype=torch.float16,
-                    use_bias=use_bias,
-                )
-
-                inputs = lin_mod.get_inputs()
-                # This requires slightly higher atol, but if you look at error it is not that bad:
-                # Difference: max: 0.00140380859375, abs: 0.00140380859375, mean abs error: 0.00042724609375.
-                # -- Model vs. Reference --
-                # Numel: 4, 4
-                # Median: -0.05023193359375, -0.0516357421875
-                # Mean: 0.2373046875, 0.237060546875
-                # Max: 1.0078125, 1.0078125
-                # Min: -0.08465576171875, -0.08441162109375
-                self._test_groupwise_dq_linear(
-                    lin_mod, inputs, group_size=bl, use_bias=use_bias, atol=1e-2
-                )
-
-    @unittest.skipIf(
-        not torchao_installed, "Per Channel Group Quantization Required TorchAO"
-    )
-    def test_qd8_fp32_per_token_groupwise_unsupported_groupsize(self):
-        # groupsize must be multiple of 32
-        lin_mod = BaseLinear(
-            in_size=1,
-            input_channels=60,
-            output_channels=60,
-            dtype=torch.float32,
-            use_bias=True,
-        )
-        inputs = lin_mod.get_inputs()
-
-        with self.assertRaisesRegex(
-            AssertionError,
-            "Delegation to XNNPACK requires group_size to be a multiple of 32, but got 30",
-        ):
-            self._test_groupwise_dq_linear(
-                lin_mod, inputs, group_size=30, use_bias=False, atol=1e-2
-            )
-
-    def test_qs8_as_fp32(self):
+    def test_linear_qs8_as_fp32(self):
         for use_bias in (True, False):
             self._test_linear_overwrite_precision(
                 lambda in_size, out_size: torch.nn.Linear(
@@ -803,7 +840,7 @@ def test_qs8_as_fp32(self):
                 },
             )
 
-    def test_qc8_as_fp32(self):
+    def test_linear_qc8_as_fp32(self):
         for use_bias in (True, False):
             self._test_linear_overwrite_precision(
                 lambda in_size, out_size: torch.nn.Linear(
@@ -818,7 +855,7 @@ def test_qc8_as_fp32(self):
                 },
            )
 
-    def test_qd8_as_fp32(self):
+    def test_linear_qd8_as_fp32(self):
         for use_bias in (True, False):
             self._test_linear_overwrite_precision(
                 lambda in_size, out_size: torch.nn.Linear(
@@ -832,4 +869,3 @@ def test_qd8_as_fp32(self):
                 "dequantize_per_channel.default": 1,  # 1: weight
             },
         )
-
\ No newline at end of file