
Commit e6f6047

Update on "[ExecuTorch] XNNPACK: prefer qc over qb when gs == k for non-int4"
* Prefer channelwise over groupwise quantization when possible, for perf and for int8, which doesn't have groupwise support
* Fix bug / improve behavior for affine q/dq with gs == k in the per_channel path
* Refactor the is_per_channel_group state variable
* Add QuantParams.__str__()

TODO: improve affine quant primitives - T237476295

Differential Revision: [D82060758](https://our.internmc.facebook.com/intern/diff/D82060758/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D82060758/)!

[ghstack-poisoned]
1 parent 50994a5 commit e6f6047
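
The core of the change is the choice between the channelwise (qc) and groupwise (qb) lowering paths. As a rough illustration of the rule the diff below encodes, here is a minimal standalone sketch; the function `use_per_channel_group` and its arguments are hypothetical helpers, not part of the ExecuTorch API:

```python
def use_per_channel_group(group_size: int, input_channels: int, is_qc4w: bool) -> bool:
    """Illustrative only: True -> lower as groupwise (qb), False -> fall back to channelwise (qc)."""
    if is_qc4w:
        # int4 keeps the groupwise path even when group_size == input_channels,
        # since mapping qb4w onto qc4w currently passes incorrect scales to XNNPACK.
        return group_size <= input_channels
    # Non-int4 (e.g. int8): when one group spans the whole row (gs == k), groupwise
    # degenerates to channelwise, so prefer qc for perf and because int8 has no
    # groupwise kernels.
    return group_size < input_channels


# Assumed example shapes: a [64, 128] weight, i.e. input_channels (k) = 128.
assert use_per_channel_group(32, 128, is_qc4w=False)        # qb: groups smaller than a row
assert not use_per_channel_group(128, 128, is_qc4w=False)   # gs == k, non-int4 -> qc
assert use_per_channel_group(128, 128, is_qc4w=True)        # gs == k, int4 -> still qb
```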

backends/xnnpack/operators/quant_params.py

Lines changed: 26 additions & 10 deletions
@@ -89,15 +89,16 @@ def __init__(
         # Groupwise quantization for weight
         self.per_channel_group = False
         self.group_size = group_size
-
+
         tensor = q_input.meta["val"]
-
+
         if self.group_size > 0:
             assert (
                 self.per_channel is True
             ), "Only per channel quantization supports groupwise quantization"
             assert (
-                self.axis == 0, "Only axis 0 is supported for per channel groupwise quant"
+                self.axis == 0,
+                "Only axis 0 is supported for per channel groupwise quant",
             )
             assert (
                 cast(torch.Tensor, scale).ndim == 2
@@ -106,16 +107,23 @@ def __init__(
             input_channels = cast(torch.Tensor, scale).shape[1] * self.group_size
             # 2d weight tensor shape - [out_channels, in_channels]
             assert (
-                tensor.shape[1] == input_channels, "Invalid input channels for groupwise quant"
-            )
+                tensor.shape[1] == input_channels,
+                "Invalid input channels for groupwise quant",
+            )
             # Prefer per_channel over per_channel_group when group_size == input_channels for non int4 cases only
             # int4 case need more fixes to map qb4w to qc4w. Incorrect scales being passed down to xnnpack.
-            self.per_channel_group = self.group_size <= input_channels if self.is_qc4w else self.group_size < input_channels
+            self.per_channel_group = (
+                self.group_size <= input_channels
+                if self.is_qc4w
+                else self.group_size < input_channels
+            )
 
         if not self.per_channel_group:
             if cast(torch.Tensor, scale).ndim == 2:
                 # TODO: don't reshape scale for per_channel cases
-                assert (cast(torch.Tensor, scale).shape[1] == 1), "Invalid scale shape for per channel quantization"
+                assert (
+                    cast(torch.Tensor, scale).shape[1] == 1
+                ), "Invalid scale shape for per channel quantization"
                 scale = cast(torch.Tensor, scale).squeeze(1)
 
         if per_channel and not self.per_channel_group:
@@ -130,10 +138,18 @@ def __init__(
     def __str__(self) -> str:
         """String representation of QuantParams for debugging and logging."""
         assert isinstance(self.scale, float) or isinstance(self.scale, torch.Tensor)
-        scale_str = f"{self.scale}" if isinstance(self.scale, float) else f"tensor{tuple(self.scale.shape)}"
+        scale_str = (
+            f"{self.scale}"
+            if isinstance(self.scale, float)
+            else f"tensor{tuple(self.scale.shape)}"
+        )
         assert isinstance(self.zp, float) or isinstance(self.zp, torch.Tensor)
-        zp_str = f"{self.zp}" if isinstance(self.zp, float) else f"tensor{tuple(self.zp.shape)}"
-
+        zp_str = (
+            f"{self.zp}"
+            if isinstance(self.zp, float)
+            else f"tensor{tuple(self.zp.shape)}"
+        )
+
         return (
             f"QuantParams("
             f"per_channel={self.per_channel}, "

0 commit comments