Commit 1bdc2f1

[ExecuTorch][XNNPACK] Don't partition per_tensor weights with qd8

Pull Request resolved: #8787

XNNPACK does not support per-tensor quantized weights together with dynamically quantized (qd8) activations, so the partitioner should not claim such nodes. Also add an expectedFailure test to document that this combination is unsupported.

ghstack-source-id: 268894461
@exported-using-ghexport
Differential Revision: D70343584 (https://our.internmc.facebook.com/intern/diff/D70343584/)

Parent: 09b592b
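For context, the quantization setup this commit stops partitioning can be produced roughly as follows with the PT2E XNNPACKQuantizer. This is a sketch only: the import path and the get_symmetric_quantization_config keyword arguments are assumed from the PyTorch 2 export quantization APIs used by the XNNPACK tests, and may differ across releases.

import torch
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

quantizer = XNNPACKQuantizer()
# is_dynamic=True      -> dynamically quantized (qd8) activations
# is_per_channel=False -> per-tensor quantized weights
# XNNPACK has no kernel for this pairing, so the partitioner now leaves such
# GEMMs un-partitioned instead of failing later during lowering.
quantizer.set_global(
    get_symmetric_quantization_config(is_per_channel=False, is_dynamic=True)
)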

3 files changed: +51 -13 lines

backends/xnnpack/partition/config/gemm_configs.py

Lines changed: 22 additions & 9 deletions

@@ -21,6 +21,7 @@
     is_dynamic_qdq,
     is_per_channel,
     is_per_channel_group,
+    is_per_tensor,
     is_qparam,
     is_quant,
 )
@@ -66,8 +67,6 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
             return False
 
         is_valid, _ = self.get_deps(node, ep)
-        if not is_valid:
-            why(node, "Failed to get valid dependent nodes.")
         return is_valid
 
     def get_node_and_deps(
@@ -123,6 +122,7 @@ def get_deps(
         precision = self._detect_precision(node)
         if precision not in self.supported_precision_types():
             # detected precision but it is either disabled or not supported
+            why(node, f"Unsupported precision type {precision}")
             return (False, [])
         _, precision = self._overwrite_precision(node)
         valid_bias, bias_deps = self._get_bias_deps(node, ep, precision)
@@ -143,27 +143,36 @@ def _get_weight_deps(
             # First find the weight
             weight_node = get_input_node(node, self.weight_idx)
             if not is_param_node(ep, weight_node):
-                return (False, [])  # weight must be a static param
+                why(node, "Expected weight to be a static param")
+                return (False, [])
             gemm_deps.append(weight_node)
 
             return (True, gemm_deps)
         else:
             # Quantized Weight deps
             dequant_node = get_input_node(node, self.weight_idx)
             if not is_dequant(dequant_node):
+                why(node, "Expected weight to have a dequantized node")
                 return False, []
             gemm_deps.append(dequant_node)
             weight = get_input_node(dequant_node, 0)
             if not is_param_node(ep, weight):
+                why(node, "Expected weight to be a static param")
                 return False, []
             gemm_deps.append(weight)
 
+            if is_per_tensor(dequant_node) and precision == ConfigPrecisionType.DYNAMIC_QUANT:
+                why(node, "XNNPACK does not support per tensor quantized weights for dynamic quantization of activations")
+                return False, []
+
             if is_per_channel(dequant_node) or is_per_channel_group(dequant_node):
                 if len(dequant_node.all_input_nodes) < 2:
                     # Expected channel quantized to have scale/zp nodes
+                    why(node, "Expected channel quantized to have scale/zp nodes")
                     return False, []
 
                 gemm_deps.extend(dequant_node.all_input_nodes[1:3])
+
             return (True, gemm_deps)
 
     def _get_output_deps(
@@ -174,7 +183,7 @@ def _get_output_deps(
             # Look for fused activations and tail end quant node
             node_users = list(node.users.keys())
             if len(node_users) != 1:
-                # Expect quantized node to have a single output (fused act or dequant)
+                why(node, "Expected quantized node to have a single output")
                 return False, []
 
             # Check if the quantized pattern has a fused activation
@@ -190,6 +199,7 @@ def _get_output_deps(
 
             if not is_quant(n_output):
                 # Expected gemm_node --> fused_act (optional) --> dequant
+                why(node, "Expected output node to have a dequantized node")
                 return (False, [])
             gemm_deps.append(n_output)
         elif precision == ConfigPrecisionType.FP32:
@@ -219,7 +229,8 @@ def _get_bias_deps(
             bias_node = get_input_node(node, self.bias_idx)
             if bias_node:
                 if not is_param_node(ep, bias_node):
-                    return (False, [])  # bias node must be a static param
+                    why(node, "Expected bias to be a static param")
+                    return (False, [])
                 gemm_deps.append(bias_node)
 
         return (True, gemm_deps)
@@ -233,7 +244,7 @@ def _get_act_deps(
         else:
             dq_input = get_input_node(node, self.act_idx)
             if not is_dequant(dq_input):
-                # Expected static quant input to be dequant node
+                why(node, "Expected act input to be dequant node")
                 return False, []
             gemm_deps.append(dq_input)
             if precision == ConfigPrecisionType.STATIC_QUANT:
@@ -243,27 +254,28 @@ def _get_act_deps(
             # q input node
             q_input = get_input_node(dq_input, 0)
             if not is_quant(q_input):
+                why(node, "Expected dequant input to be quant node")
                 return (False, [])
 
             gemm_deps.append(q_input)
             q_input_args = q_input.args
             if is_affine_qdq(q_input):
                 q_input_args = extract_qdq_affine_op_args_for_decomposed_ops(q_input)
             if not (is_node(q_input_args[1]) and is_node(q_input_args[2])):
-                # expected to find getitem node from choose qparam
+                why(node, "expected to find getitem node from choose qparam")
                 return (False, [])
 
             getitem1 = q_input_args[1]
             getitem2 = q_input_args[2]
 
             if not (is_getitem(getitem1) and is_getitem(getitem2)):
-                # expected getitem node from choose qparam
+                why(node, "expected getitem node from choose qparam")
                 return (False, [])
 
             gemm_deps.extend([getitem1, getitem2])
             choose_qparam = get_input_node(getitem1, 0)
             if not is_qparam(choose_qparam):
-                # expected to find choose_qparam node
+                why(node, "expected to find choose_qparam node")
                 return (False, [])
             gemm_deps.append(choose_qparam)
             return (True, gemm_deps)
@@ -471,6 +483,7 @@ def find_partition_args(input_node):
             # there can only be a single output node in partition
             or len(src_partition.output_nodes) != 1
         ):
+            why(node, "invalid source partition")
            return (False, [])
 
         # map addmm's args to the source partition linear's inputs and users
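Pulled out of the diff above, the new weight-layout constraint amounts to the following standalone check. This is a sketch only: the import paths are assumed from the ExecuTorch package layout, and the real _get_weight_deps additionally collects the dequantize and weight nodes as partition dependencies.

import torch

from executorch.backends.xnnpack.partition.config.xnnpack_config import ConfigPrecisionType
from executorch.backends.xnnpack.utils.quant_utils import is_per_tensor


def weight_layout_is_partitionable(dequant_node: torch.fx.Node, precision) -> bool:
    # qd8 (dynamically quantized activations) requires per-channel or
    # per-channel-group quantized weights in XNNPACK; per-tensor weights are
    # rejected so the GEMM stays on the fallback path instead of failing at
    # lowering time.
    if is_per_tensor(dequant_node) and precision == ConfigPrecisionType.DYNAMIC_QUANT:
        return False
    return True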

backends/xnnpack/test/ops/test_linear.py

Lines changed: 22 additions & 4 deletions

@@ -520,7 +520,7 @@ def get_qnode_checks(quant_node_checks, dialect):
         #     qtol=bool(quant_config), atol=atol
         # )
 
-    def _test_qd8_per_channel_linear(self, dtype: torch.dtype = torch.float):
+    def _test_qd8_linear(self, dtype: torch.dtype = torch.float, is_per_channel: bool = True):
         for uses_bias in (False, True):
             module = BaseLinear(
                 in_size=8,
@@ -535,7 +535,7 @@ def _test_qd8_per_channel_linear(self, dtype: torch.dtype = torch.float):
                 module,
                 inputs,
                 dynamic_shapes=({1: torch.export.Dim("batch", max=100)},),
-                is_per_channel=True,
+                is_per_channel=is_per_channel,
                 uses_bias=uses_bias,
             )
 
@@ -695,11 +695,29 @@ def test_qs8_linear(self):
 
     # Tests for q[dp]8-f16-qc8w
     def test_qd8_f16_per_channel_linear(self):
-        self._test_qd8_per_channel_linear(dtype=torch.half)
+        self._test_qd8_linear(dtype=torch.half)
+
+    @unittest.expectedFailure
+    def test_qd8_f16_per_tensor_linear(self):
+        """
+        XNNPACK doesn't support per_tensor quantized weights for dynamic quantized linear op.
+        This test is to verify that we can't lower per_tensor quantized weights to per_channel quantized weights.
+        """
+        self._test_qd8_linear(dtype=torch.half, is_per_channel=False)
+
+
 
     # Tests for q[dp]8-f32-qc8w
     def test_qd8_f32_per_channel_linear(self):
-        self._test_qd8_per_channel_linear(dtype=torch.float)
+        self._test_qd8_linear(dtype=torch.float)
+
+    @unittest.expectedFailure
+    def test_qd8_f32_per_tensor_linear(self):
+        """
+        XNNPACK doesn't support per_tensor quantized weights for dynamic quantized linear op.
+        This test is to verify that we can't lower per_tensor quantized weights to per_channel quantized weights.
+        """
+        self._test_qd8_linear(dtype=torch.half, is_per_channel=False)
 
     # Tests for q[dp]8-f16-qc4w
     def test_linear_qd8_f16_per_channel_int4(self):

backends/xnnpack/utils/quant_utils.py

Lines changed: 7 additions & 0 deletions

@@ -88,6 +88,13 @@ def is_per_channel(node: torch.fx.Node) -> bool:
 
     return is_per_channel or is_affine_per_channel_group
 
+def is_per_tensor(node: torch.fx.Node) -> bool:
+    if not (is_quant(node) or is_dequant(node)):
+        return False
+
+    is_per_tensor = "per_tensor" in node.target.__name__  # pyre-ignore
+
+    return is_per_tensor and not (is_per_channel(node))
 
 def is_affine_qdq(node: torch.fx.Node) -> bool:
     if not (is_quant(node) or is_dequant(node)):
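The new helper keys purely on the operator name of the quantize/dequantize node, the same node.target.__name__ convention used above. As an illustration only, with the standard quantized_decomposed ops (the op handles and the exact __name__ strings are assumptions about the PyTorch quantized_decomposed library):

import torch

per_tensor_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default
per_channel_op = torch.ops.quantized_decomposed.dequantize_per_channel.default

# is_per_tensor looks for "per_tensor" in the target's __name__ and additionally
# excludes per-channel(-group) variants via is_per_channel.
print("per_tensor" in per_tensor_op.__name__)   # expected: True
print("per_tensor" in per_channel_op.__name__)  # expected: False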
