17 changes: 10 additions & 7 deletions backends/xnnpack/operators/node_visitor.py
@@ -309,27 +309,30 @@ def _check_per_channel_group_params(
         num_groups = cast(torch.Tensor, quant_params.scale).shape[1]
         assert (
             quant_params.axis == 0
-        ), "For per_channel_group quant, axis must be 0, but got {axis}"
+        ), f"For per_channel_group quant, axis must be 0, but got {quant_params.axis}"
         assert (
             len(dims) == 2
-        ), "For per_channel_group quant, expecting linear weights to be 2d, but got {len(dims)}"
+        ), f"For per_channel_group quant, expecting linear weights to be 2d, but got {len(dims)}"
         assert (
             num_groups > 0 and quant_params.group_size > 0
-        ), "For per_channel_group quant, num_groups and group_size must be > 0, but got num_groups: {num_groups}, group_size: {quant_params.group_size}"
+        ), f"For per_channel_group quant, num_groups and group_size must be > 0, but got num_groups: {num_groups}, group_size: {quant_params.group_size}"
         output_channels = dims[quant_params.axis]
         input_channels = dims[quant_params.axis ^ 1]
+        assert (
+            quant_params.group_size % 32 == 0
+        ), f"Delegation to XNNPACK requires group_size to be a multiple of 32, but got {quant_params.group_size}"
         assert (
             output_channels == cast(torch.Tensor, quant_params.scale).shape[0]
-        ), "For per_channel_group quant, expecting output channels to match scale.shape[0], gut got: {output_channels}, scale.shape[0]: {quant_params.scale.shape[0]}"
+        ), f"For per_channel_group quant, expecting output channels to match scale.shape[0], gut got: {output_channels}, scale.shape[0]: {quant_params.scale.shape[0]}"
         assert (
             input_channels % num_groups == 0
-        ), "For per_channel_group quant, expecting input channels to be divisible by num_groups, but got ic: {input_channels}, num_groups: {num_groups}"
+        ), f"For per_channel_group quant, expecting input channels to be divisible by num_groups, but got ic: {input_channels}, num_groups: {num_groups}"
         assert (
             input_channels % quant_params.group_size == 0
-        ), "For per_channel_group quant, expecting input channels to be divisible by group_size, but got ic: {input_channels}, group_size: {quant_params.group_size}"
+        ), f"For per_channel_group quant, expecting input channels to be divisible by group_size, but got ic: {input_channels}, group_size: {quant_params.group_size}"
         assert (
             input_channels / quant_params.group_size == num_groups
-        ), "For per_channel_group quant, expecting input channels // group_size == num_groups, but got ic: {input_channels}, group_size: {quant_params.group_size}, num_groups: {num_groups}"
+        ), f"For per_channel_group quant, expecting input channels // group_size == num_groups, but got ic: {input_channels}, group_size: {quant_params.group_size}, num_groups: {num_groups}"
 
         # For now group quantization is only supported for 4b weights
         assert quant_params.is_qc4w, "Only 4b group quantization is supported"
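For context, the block below is a minimal standalone sketch of what the updated checks enforce; it is not part of the PR, and check_group_params is a hypothetical helper with plain ints standing in for QuantParams. It shows the new multiple-of-32 requirement alongside the pre-existing divisibility constraints.

def check_group_params(input_channels: int, group_size: int, num_groups: int) -> None:
    # New in this diff: XNNPACK delegation requires group_size % 32 == 0.
    assert group_size % 32 == 0, (
        f"Delegation to XNNPACK requires group_size to be a multiple of 32, "
        f"but got {group_size}"
    )
    # Pre-existing relationships between input channels, group size and groups.
    assert input_channels % num_groups == 0
    assert input_channels % group_size == 0
    assert input_channels / group_size == num_groups


# A (64, 64) weight split into 32-wide groups passes all checks ...
check_group_params(input_channels=64, group_size=32, num_groups=2)

# ... while group_size=30 (as exercised by the new test below) would trip the
# multiple-of-32 assertion before the divisibility checks are reached:
# check_group_params(input_channels=60, group_size=30, num_groups=2)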
22 changes: 22 additions & 0 deletions backends/xnnpack/test/ops/linear.py
@@ -458,6 +458,28 @@ def test_qd8_fp16_per_token_weight_per_channel_group_int4(self):
                     lin_mod, inputs, group_size=bl, use_bias=use_bias, atol=1e-2
                 )
 
+    @unittest.skipIf(
+        not torchao_installed, "Per Channel Group Quantization Required TorchAO"
+    )
+    def test_qd8_fp32_per_token_groupwise_unsupported_groupsize(self):
+        # groupsize must be multiple of 32
+        lin_mod = BaseLinear(
+            in_size=1,
+            input_channels=60,
+            output_channels=60,
+            dtype=torch.float32,
+            use_bias=True,
+        )
+        inputs = lin_mod.get_inputs()
+
+        with self.assertRaisesRegex(
+            AssertionError,
+            "Delegation to XNNPACK requires group_size to be a multiple of 32, but got 30",
+        ):
+            self._test_groupwise_dq_linear(
+                lin_mod, inputs, group_size=30, use_bias=False, atol=1e-2
+            )
+
     def _test_linear(
         self,
         make_module,
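As a side note, this test only works together with the f-string conversion in node_visitor.py above: assertRaisesRegex matches against the rendered message, so the literal "30" has to be interpolated into it. The self-contained sketch below (plain unittest, a hypothetical check_group_size stand-in, no ExecuTorch imports) illustrates that interplay.

import unittest


def check_group_size(group_size: int) -> None:
    # Stand-in mirroring the new assertion in node_visitor.py, not the real code path.
    assert group_size % 32 == 0, (
        f"Delegation to XNNPACK requires group_size to be a multiple of 32, "
        f"but got {group_size}"
    )


class GroupSizeMessageTest(unittest.TestCase):
    def test_unsupported_group_size_message(self):
        # With a plain (non-f) string, the raised message would contain the
        # literal text "{group_size}" and this regex would not match.
        with self.assertRaisesRegex(
            AssertionError,
            "Delegation to XNNPACK requires group_size to be a multiple of 32, but got 30",
        ):
            check_group_size(30)


if __name__ == "__main__":
    unittest.main()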