
Commit 35ba5d5

[ExecuTorch][XNNPACK] Don't partition per_tensor weights with qd8
XNNPACK does not support per-tensor quantized weights when the activations are dynamically quantized (qd8), so the partitioner should not delegate these GEMM nodes. Add expectedFailure tests to document that this configuration is unsupported.

Differential Revision: [D70343584](https://our.internmc.facebook.com/intern/diff/D70343584/)

ghstack-source-id: 268822901
Pull Request resolved: #8787
1 parent 09b592b commit 35ba5d5
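
For context, a minimal sketch (not part of this commit) of the two PT2E quantizer configurations involved; it assumes the standard XNNPACKQuantizer API from torch.ao, and exact import paths may vary by PyTorch version. The per-channel dynamic config continues to be partitioned, while the per-tensor dynamic config is the case the partitioner now declines.

# Illustrative sketch only; the quantizer API below is assumed, not taken from this diff.
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

# qd8 with per-channel weights (qc8w): supported by XNNPACK, still delegated.
per_channel_dynamic = XNNPACKQuantizer().set_global(
    get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True)
)

# qd8 with per-tensor weights: unsupported by XNNPACK, so after this change the
# partitioner leaves these GEMMs unpartitioned instead of failing at lowering.
per_tensor_dynamic = XNNPACKQuantizer().set_global(
    get_symmetric_quantization_config(is_per_channel=False, is_dynamic=True)
)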

3 files changed: +50 -13 lines changed


backends/xnnpack/partition/config/gemm_configs.py

Lines changed: 21 additions & 9 deletions
@@ -21,6 +21,7 @@
     is_dynamic_qdq,
     is_per_channel,
     is_per_channel_group,
+    is_per_tensor,
     is_qparam,
     is_quant,
 )
@@ -66,8 +67,6 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
             return False
 
         is_valid, _ = self.get_deps(node, ep)
-        if not is_valid:
-            why(node, "Failed to get valid dependent nodes.")
         return is_valid
 
     def get_node_and_deps(
@@ -123,6 +122,7 @@ def get_deps(
         precision = self._detect_precision(node)
         if precision not in self.supported_precision_types():
             # detected precision but it is either disabled or not supported
+            why(node, f"Unsupported precision type {precision}")
             return (False, [])
         _, precision = self._overwrite_precision(node)
         valid_bias, bias_deps = self._get_bias_deps(node, ep, precision)
@@ -143,26 +143,34 @@ def _get_weight_deps(
             # First find the weight
             weight_node = get_input_node(node, self.weight_idx)
             if not is_param_node(ep, weight_node):
-                return (False, [])  # weight must be a static param
+                why(node, "Expected weight to be a static param")
+                return (False, [])
             gemm_deps.append(weight_node)
 
             return (True, gemm_deps)
         else:
             # Quantized Weight deps
             dequant_node = get_input_node(node, self.weight_idx)
             if not is_dequant(dequant_node):
+                why(node, "Expected weight to have a dequantized node")
                 return False, []
             gemm_deps.append(dequant_node)
             weight = get_input_node(dequant_node, 0)
             if not is_param_node(ep, weight):
+                why(node, "Expected weight to be a static param")
                 return False, []
             gemm_deps.append(weight)
 
             if is_per_channel(dequant_node) or is_per_channel_group(dequant_node):
                 if len(dequant_node.all_input_nodes) < 2:
                     # Expected channel quantized to have scale/zp nodes
+                    why(node, "Expected channel quantized to have scale/zp nodes")
                     return False, []
 
+            if is_per_tensor(dequant_node) and precision == ConfigPrecisionType.DYNAMIC_QUANT:
+                why(node, "XNNPACK does not support per tensor quantized weights for dynamic quantization of activations")
+                return False, []
+
             gemm_deps.extend(dequant_node.all_input_nodes[1:3])
             return (True, gemm_deps)
 
@@ -174,7 +182,7 @@ def _get_output_deps(
             # Look for fused activations and tail end quant node
             node_users = list(node.users.keys())
             if len(node_users) != 1:
-                # Expect quantized node to have a single output (fused act or dequant)
+                why(node, "Expected quantized node to have a single output")
                 return False, []
 
             # Check if the quantized pattern has a fused activation
@@ -190,6 +198,7 @@ def _get_output_deps(
 
             if not is_quant(n_output):
                 # Expected gemm_node --> fused_act (optional) --> dequant
+                why(node, "Expected output node to have a dequantized node")
                 return (False, [])
             gemm_deps.append(n_output)
         elif precision == ConfigPrecisionType.FP32:
@@ -219,7 +228,8 @@ def _get_bias_deps(
             bias_node = get_input_node(node, self.bias_idx)
             if bias_node:
                 if not is_param_node(ep, bias_node):
-                    return (False, [])  # bias node must be a static param
+                    why(node, "Expected bias to be a static param")
+                    return (False, [])
                 gemm_deps.append(bias_node)
 
         return (True, gemm_deps)
@@ -233,7 +243,7 @@ def _get_act_deps(
         else:
             dq_input = get_input_node(node, self.act_idx)
             if not is_dequant(dq_input):
-                # Expected static quant input to be dequant node
+                why(node, "Expected act input to be dequant node")
                 return False, []
             gemm_deps.append(dq_input)
             if precision == ConfigPrecisionType.STATIC_QUANT:
@@ -243,27 +253,28 @@ def _get_act_deps(
             # q input node
             q_input = get_input_node(dq_input, 0)
             if not is_quant(q_input):
+                why(node, "Expected dequant input to be quant node")
                 return (False, [])
 
             gemm_deps.append(q_input)
             q_input_args = q_input.args
             if is_affine_qdq(q_input):
                 q_input_args = extract_qdq_affine_op_args_for_decomposed_ops(q_input)
             if not (is_node(q_input_args[1]) and is_node(q_input_args[2])):
-                # expected to find getitem node from choose qparam
+                why(node, "expected to find getitem node from choose qparam")
                 return (False, [])
 
             getitem1 = q_input_args[1]
             getitem2 = q_input_args[2]
 
             if not (is_getitem(getitem1) and is_getitem(getitem2)):
-                # expected getitem node from choose qparam
+                why(node, "expected getitem node from choose qparam")
                 return (False, [])
 
             gemm_deps.extend([getitem1, getitem2])
             choose_qparam = get_input_node(getitem1, 0)
             if not is_qparam(choose_qparam):
-                # expected to find choose_qparam node
+                why(node, "expected to find choose_qparam node")
                 return (False, [])
             gemm_deps.append(choose_qparam)
             return (True, gemm_deps)
@@ -471,6 +482,7 @@ def find_partition_args(input_node):
             # there can only be a single output node in partition
            or len(src_partition.output_nodes) != 1
        ):
+            why(node, "invalid source partition")
            return (False, [])
 
        # map addmm's args to the source partition linear's inputs and users
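
Read in isolation, the new weight constraint amounts to the predicate below; this is a paraphrase of the hunk above for clarity, not additional code from the commit. Per-tensor weight dequantization is only rejected when the activations are dynamically quantized; per-tensor weights under static quantization remain partitionable.

# Paraphrased sketch of the new guard in _get_weight_deps (not code from this commit).
def weight_is_partitionable(dequant_node, precision) -> bool:
    # XNNPACK has no qd8 GEMM kernels that take per-tensor quantized weights.
    if is_per_tensor(dequant_node) and precision == ConfigPrecisionType.DYNAMIC_QUANT:
        return False
    return True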

backends/xnnpack/test/ops/test_linear.py

Lines changed: 22 additions & 4 deletions
@@ -520,7 +520,7 @@ def get_qnode_checks(quant_node_checks, dialect):
         # qtol=bool(quant_config), atol=atol
         # )
 
-    def _test_qd8_per_channel_linear(self, dtype: torch.dtype = torch.float):
+    def _test_qd8_linear(self, dtype: torch.dtype = torch.float, is_per_channel: bool = True):
         for uses_bias in (False, True):
             module = BaseLinear(
                 in_size=8,
@@ -535,7 +535,7 @@ def _test_qd8_per_channel_linear(self, dtype: torch.dtype = torch.float):
                 module,
                 inputs,
                 dynamic_shapes=({1: torch.export.Dim("batch", max=100)},),
-                is_per_channel=True,
+                is_per_channel=is_per_channel,
                 uses_bias=uses_bias,
             )
 
@@ -695,11 +695,29 @@ def test_qs8_linear(self):
 
     # Tests for q[dp]8-f16-qc8w
     def test_qd8_f16_per_channel_linear(self):
-        self._test_qd8_per_channel_linear(dtype=torch.half)
+        self._test_qd8_linear(dtype=torch.half)
+
+    @unittest.expectedFailure
+    def test_qd8_f16_per_tensor_linear(self):
+        """
+        XNNPACK doesn't support per_tensor quantized weights for dynamic quantized linear op.
+        This test is to verify that we can't lower per_tensor quantized weights to per_channel quantized weights.
+        """
+        self._test_qd8_linear(dtype=torch.half, is_per_channel=False)
+
+
 
     # Tests for q[dp]8-f32-qc8w
     def test_qd8_f32_per_channel_linear(self):
-        self._test_qd8_per_channel_linear(dtype=torch.float)
+        self._test_qd8_linear(dtype=torch.float)
+
+    @unittest.expectedFailure
+    def test_qd8_f32_per_tensor_linear(self):
+        """
+        XNNPACK doesn't support per_tensor quantized weights for dynamic quantized linear op.
+        This test is to verify that we can't lower per_tensor quantized weights to per_channel quantized weights.
+        """
+        self._test_qd8_linear(dtype=torch.half, is_per_channel=False)
 
     # Tests for q[dp]8-f16-qc4w
     def test_linear_qd8_f16_per_channel_int4(self):

backends/xnnpack/utils/quant_utils.py

Lines changed: 7 additions & 0 deletions
@@ -88,6 +88,13 @@ def is_per_channel(node: torch.fx.Node) -> bool:
 
     return is_per_channel or is_affine_per_channel_group
 
+def is_per_tensor(node: torch.fx.Node) -> bool:
+    if not (is_quant(node) or is_dequant(node)):
+        return False
+
+    is_per_tensor = "per_tensor" in node.target.__name__  # pyre-ignore
+
+    return is_per_tensor and not (is_per_channel(node))
 
 def is_affine_qdq(node: torch.fx.Node) -> bool:
     if not (is_quant(node) or is_dequant(node)):
