Commit c4fc948

[ExecuTorch] XNNPACK: prefer qc over qb when gs == k for non-int4 (#14248)

When the group size (gs) equals the number of input channels (k), groupwise (qb) quantization degenerates to plain per-channel (qc) quantization, so such weights are now lowered through the simpler per-channel path. int4 weights are excluded: mapping qb4w to qc4w still needs more fixes, since incorrect scales were being passed down to XNNPACK.

Co-authored-by: Digant Desai <[email protected]>
1 parent 0e9d871 commit c4fc948

File tree: 2 files changed, +62 −9 lines changed

backends/xnnpack/operators/node_visitor.py
backends/xnnpack/operators/quant_params.py

backends/xnnpack/operators/node_visitor.py (4 additions, 4 deletions)

@@ -232,7 +232,7 @@ def get_per_channel_dtype(
     if quant_params.dtype == torch.int32:
         return XNNDatatype.xnn_datatype_qcint32
     elif quant_params.dtype == torch.int8:
-        if quant_params.is_per_channel_group:
+        if quant_params.per_channel_group:
             # 4-bit per channel group quantized weights
             # No 8-bit support yet
             assert (
@@ -282,7 +282,7 @@ def get_quant_params(
     buffer_idx = len(xnn_graph.constant_data)
     num_scales = scale.numel()

-    if quant_params.is_per_channel_group:
+    if quant_params.per_channel_group:
         scale = scale.to(torch.bfloat16)

     num_bytes = scale.untyped_storage().nbytes()
@@ -300,7 +300,7 @@ def get_quant_params(
         scale_name, bytes(scale_array), CONSTANT_TENSOR_ALIGNMENT
     )

-    if quant_params.is_per_channel_group:
+    if quant_params.per_channel_group:
         return PerChannelGroupQuant(
             scale=[],
             channel_dim=quant_params.axis,
@@ -335,7 +335,7 @@ def _check_per_channel_group_params(
 ) -> None:
     # Make sure things are lining up for per_channel_group quantization case
     # Has to be done this late because we don't have clean access to the actual tensor
-    assert quant_params.is_per_channel_group, "Not per_channel_group quantization"
+    assert quant_params.per_channel_group, "Not per_channel_group quantization"
     # linear weights will be in [oc, ic]. And per_channel quantization must be on axis 0
     num_groups = cast(torch.Tensor, quant_params.scale).shape[1]
     assert (
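For orientation, the flag these hunks rename (is_per_channel_group → per_channel_group) gates which XNNPACK weight datatype gets serialized. Below is a minimal, self-contained sketch of that dispatch shape; the qcint8 and qbint4 enum members are assumptions for illustration, not a claim about the exact XNNDatatype values used in node_visitor.py.

from enum import Enum, auto

class XNNDatatype(Enum):
    # Stand-in for the serialized XNNPACK datatype enum (member names assumed)
    xnn_datatype_qcint32 = auto()  # 32-bit per-channel
    xnn_datatype_qcint8 = auto()   # 8-bit per-channel
    xnn_datatype_qbint4 = auto()   # 4-bit per-channel-group (blockwise)

def pick_weight_dtype(dtype: str, per_channel_group: bool, is_qc4w: bool) -> XNNDatatype:
    """Sketch of the dispatch pattern in get_per_channel_dtype above."""
    if dtype == "int32":
        return XNNDatatype.xnn_datatype_qcint32
    if dtype == "int8":
        if per_channel_group:
            # Only 4-bit groupwise weights are supported; no 8-bit support yet
            assert is_qc4w, "8-bit per_channel_group weights are unsupported"
            return XNNDatatype.xnn_datatype_qbint4
        return XNNDatatype.xnn_datatype_qcint8
    raise ValueError(f"unsupported weight dtype: {dtype}")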

backends/xnnpack/operators/quant_params.py (58 additions, 5 deletions)

@@ -89,19 +89,39 @@ def __init__(
         # Groupwise quantization for weight
         self.per_channel_group = False
         self.group_size = group_size
+
+        tensor = q_input.meta["val"]
+
         if self.group_size > 0:
             assert (
                 self.per_channel is True
             ), "Only per channel quantization supports groupwise quantization"
             assert (
                 cast(torch.Tensor, scale).ndim == 2
             ), "Scale must be 2D for per channel groupwise quant"
-            self.per_channel_group = True
-            assert group_size > 0, "Group size must be greater than 0"
-        self.is_per_channel_group = self.per_channel and self.group_size > 0
+            # Assumed scale shape - [out_channels, in_channels / group_size]
+            input_channels = cast(torch.Tensor, scale).shape[1] * self.group_size
+            # 2d weight tensor shape - [out_channels, in_channels]
+            assert (
+                tensor.shape[1] == input_channels
+            ), "Invalid input channels for groupwise quant"
+            # Prefer per_channel over per_channel_group when group_size == input_channels,
+            # for non-int4 cases only. The int4 case needs more fixes to map qb4w to
+            # qc4w; incorrect scales were being passed down to XNNPACK.
+            self.per_channel_group = (
+                self.group_size <= input_channels
+                if self.is_qc4w
+                else self.group_size < input_channels
+            )
+
+            if not self.per_channel_group:
+                if cast(torch.Tensor, scale).ndim == 2:
+                    # TODO: don't reshape scale for per_channel cases
+                    assert (
+                        cast(torch.Tensor, scale).shape[1] == 1
+                    ), "Invalid scale shape for per channel quantization"
+                    scale = cast(torch.Tensor, scale).squeeze(1)

-        if per_channel and not self.is_per_channel_group:
-            tensor = q_input.meta["val"]
+        if per_channel and not self.per_channel_group:
             assert (
                 tensor.shape[self.axis] == cast(torch.Tensor, self.scale).shape[0]
             ), f"Invalid size of per channel quantization scales, axis: {self.axis}, scale size: {self.scale.shape}, tensor shape: {tensor.shape}"
@@ -110,6 +130,39 @@ def __init__(
                 tensor.shape[self.axis] == cast(torch.Tensor, self.zp).shape[0]
             ), f"Invalid size of per channel quantization zero-points, axis: {self.axis}, zp size: {self.zp.shape}, tensor shape: {tensor.shape}"

+    def __str__(self) -> str:
+        """String representation of QuantParams for debugging and logging."""
+        assert isinstance(self.scale, float) or isinstance(self.scale, torch.Tensor)
+        scale_str = (
+            f"{self.scale}"
+            if isinstance(self.scale, float)
+            else f"tensor{tuple(self.scale.shape)}"
+        )
+        assert isinstance(self.zp, float) or isinstance(self.zp, torch.Tensor)
+        zp_str = (
+            f"{self.zp}"
+            if isinstance(self.zp, float)
+            else f"tensor{tuple(self.zp.shape)}"
+        )
+
+        return (
+            f"QuantParams("
+            f"per_channel={self.per_channel}, "
+            f"per_channel_group={self.per_channel_group}, "
+            f"scale={scale_str}, "
+            f"zp={zp_str}, "
+            f"axis={self.axis}, "
+            f"dtype={self.dtype}, "
+            f"qmin={self.qmin}, "
+            f"qmax={self.qmax}, "
+            f"is_dynamic={self.is_dynamic}, "
+            f"is_input={self.is_input}, "
+            f"is_output={self.is_output}, "
+            f"group_size={self.group_size}, "
+            f"is_qc4w={self.is_qc4w}"
+            f")"
+        )
+
     def quantize_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
         # Do nothing if already quantized by the Quantizer
         if tensor.dtype == self.dtype:
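With __str__ in place, a QuantParams instance can be logged directly. A hypothetical call site follows; the instance and all printed field values are invented for illustration:

# `quant_params` would come from the usual construction path during lowering
# (e.g. one of the QuantParams factory classmethods); values below are invented.
print(quant_params)
# QuantParams(per_channel=True, per_channel_group=False, scale=tensor(8,),
# zp=tensor(8,), axis=0, dtype=torch.int8, qmin=-128, qmax=127,
# is_dynamic=False, is_input=False, is_output=False, group_size=0, is_qc4w=False)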
