From ebaad4d086af956e83300d1ee6245b9ca4eb7727 Mon Sep 17 00:00:00 2001
From: Max Ren
Date: Fri, 6 Dec 2024 09:37:53 -0800
Subject: [PATCH] Check group size is divisible by 32 (#6941)

Summary:
Currently, when checking the per_channel_group quantization parameters, we
don't verify that group_size is a multiple of 32. This constraint was added
after the original checks here were implemented. Let's add the multiple-of-32
check here.

Reviewed By: malfet, digantdesai

Differential Revision: D66131456
---
 backends/xnnpack/operators/node_visitor.py | 17 ++++++++++-------
 backends/xnnpack/test/ops/linear.py        | 22 ++++++++++++++++++++++
 2 files changed, 32 insertions(+), 7 deletions(-)

diff --git a/backends/xnnpack/operators/node_visitor.py b/backends/xnnpack/operators/node_visitor.py
index de48748f8f4..018ce1e568f 100644
--- a/backends/xnnpack/operators/node_visitor.py
+++ b/backends/xnnpack/operators/node_visitor.py
@@ -309,27 +309,30 @@ def _check_per_channel_group_params(
         num_groups = cast(torch.Tensor, quant_params.scale).shape[1]
         assert (
             quant_params.axis == 0
-        ), "For per_channel_group quant, axis must be 0, but got {axis}"
+        ), f"For per_channel_group quant, axis must be 0, but got {quant_params.axis}"
         assert (
             len(dims) == 2
-        ), "For per_channel_group quant, expecting linear weights to be 2d, but got {len(dims)}"
+        ), f"For per_channel_group quant, expecting linear weights to be 2d, but got {len(dims)}"
         assert (
             num_groups > 0 and quant_params.group_size > 0
-        ), "For per_channel_group quant, num_groups and group_size must be > 0, but got num_groups: {num_groups}, group_size: {quant_params.group_size}"
+        ), f"For per_channel_group quant, num_groups and group_size must be > 0, but got num_groups: {num_groups}, group_size: {quant_params.group_size}"
         output_channels = dims[quant_params.axis]
         input_channels = dims[quant_params.axis ^ 1]
+        assert (
+            quant_params.group_size % 32 == 0
+        ), f"Delegation to XNNPACK requires group_size to be a multiple of 32, but got {quant_params.group_size}"
         assert (
             output_channels == cast(torch.Tensor, quant_params.scale).shape[0]
-        ), "For per_channel_group quant, expecting output channels to match scale.shape[0], gut got: {output_channels}, scale.shape[0]: {quant_params.scale.shape[0]}"
+        ), f"For per_channel_group quant, expecting output channels to match scale.shape[0], but got: {output_channels}, scale.shape[0]: {quant_params.scale.shape[0]}"
         assert (
             input_channels % num_groups == 0
-        ), "For per_channel_group quant, expecting input channels to be divisible by num_groups, but got ic: {input_channels}, num_groups: {num_groups}"
+        ), f"For per_channel_group quant, expecting input channels to be divisible by num_groups, but got ic: {input_channels}, num_groups: {num_groups}"
         assert (
             input_channels % quant_params.group_size == 0
-        ), "For per_channel_group quant, expecting input channels to be divisible by group_size, but got ic: {input_channels}, group_size: {quant_params.group_size}"
+        ), f"For per_channel_group quant, expecting input channels to be divisible by group_size, but got ic: {input_channels}, group_size: {quant_params.group_size}"
         assert (
             input_channels / quant_params.group_size == num_groups
-        ), "For per_channel_group quant, expecting input channels // group_size == num_groups, but got ic: {input_channels}, group_size: {quant_params.group_size}, num_groups: {num_groups}"
+        ), f"For per_channel_group quant, expecting input channels // group_size == num_groups, but got ic: {input_channels}, group_size: {quant_params.group_size}, num_groups: {num_groups}"
 
         # For now group quantization is only supported for 4b weights
         assert quant_params.is_qc4w, "Only 4b group quantization is supported"
diff --git a/backends/xnnpack/test/ops/linear.py b/backends/xnnpack/test/ops/linear.py
index cc96bd53da8..348e36bd0cf 100644
--- a/backends/xnnpack/test/ops/linear.py
+++ b/backends/xnnpack/test/ops/linear.py
@@ -458,6 +458,28 @@ def test_qd8_fp16_per_token_weight_per_channel_group_int4(self):
                 lin_mod, inputs, group_size=bl, use_bias=use_bias, atol=1e-2
             )
 
+    @unittest.skipIf(
+        not torchao_installed, "Per Channel Group Quantization Requires TorchAO"
+    )
+    def test_qd8_fp32_per_token_groupwise_unsupported_groupsize(self):
+        # group_size must be a multiple of 32
+        lin_mod = BaseLinear(
+            in_size=1,
+            input_channels=60,
+            output_channels=60,
+            dtype=torch.float32,
+            use_bias=True,
+        )
+        inputs = lin_mod.get_inputs()
+
+        with self.assertRaisesRegex(
+            AssertionError,
+            "Delegation to XNNPACK requires group_size to be a multiple of 32, but got 30",
+        ):
+            self._test_groupwise_dq_linear(
+                lin_mod, inputs, group_size=30, use_bias=False, atol=1e-2
+            )
+
     def _test_linear(
         self,
         make_module,