From 4a2950047f01c8083118d7ecf2d2b6a71476e2e0 Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Thu, 27 Feb 2025 12:45:08 -0800
Subject: [PATCH 1/2] [ExecuTorch][XNNPACK] Don't partition per_tensor weights
 with qd8

This is not supported, so we shouldn't partition it. Add an expectedFailure
test to indicate that this is not supported.

Differential Revision: [D70343584](https://our.internmc.facebook.com/intern/diff/D70343584/)

[ghstack-poisoned]
---
 .../xnnpack/partition/config/gemm_configs.py | 30 +++++++++++++------
 backends/xnnpack/test/ops/test_linear.py     | 26 +++++++++++++---
 backends/xnnpack/utils/quant_utils.py        |  7 +++++
 3 files changed, 50 insertions(+), 13 deletions(-)

diff --git a/backends/xnnpack/partition/config/gemm_configs.py b/backends/xnnpack/partition/config/gemm_configs.py
index 872ba355c70..45dc3b23d56 100644
--- a/backends/xnnpack/partition/config/gemm_configs.py
+++ b/backends/xnnpack/partition/config/gemm_configs.py
@@ -21,6 +21,7 @@
     is_dynamic_qdq,
     is_per_channel,
     is_per_channel_group,
+    is_per_tensor,
     is_qparam,
     is_quant,
 )
@@ -66,8 +67,6 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
             return False

         is_valid, _ = self.get_deps(node, ep)
-        if not is_valid:
-            why(node, "Failed to get valid dependent nodes.")
         return is_valid

     def get_node_and_deps(
@@ -123,6 +122,7 @@ def get_deps(
         precision = self._detect_precision(node)
         if precision not in self.supported_precision_types():
             # detected precision but it is either disabled or not supported
+            why(node, f"Unsupported precision type {precision}")
             return (False, [])
         _, precision = self._overwrite_precision(node)
         valid_bias, bias_deps = self._get_bias_deps(node, ep, precision)
@@ -143,7 +143,8 @@ def _get_weight_deps(
             # First find the weight
             weight_node = get_input_node(node, self.weight_idx)
             if not is_param_node(ep, weight_node):
-                return (False, [])  # weight must be a static param
+                why(node, "Expected weight to be a static param")
+                return (False, [])
             gemm_deps.append(weight_node)

             return (True, gemm_deps)
@@ -151,18 +152,25 @@ def _get_weight_deps(
         # Quantized Weight deps
         dequant_node = get_input_node(node, self.weight_idx)
         if not is_dequant(dequant_node):
+            why(node, "Expected weight to have a dequantized node")
             return False, []
         gemm_deps.append(dequant_node)
         weight = get_input_node(dequant_node, 0)
         if not is_param_node(ep, weight):
+            why(node, "Expected weight to be a static param")
             return False, []
         gemm_deps.append(weight)

         if is_per_channel(dequant_node) or is_per_channel_group(dequant_node):
             if len(dequant_node.all_input_nodes) < 2:
                 # Expected channel quantized to have scale/zp nodes
+                why(node, "Expected channel quantized to have scale/zp nodes")
                 return False, []

+            if is_per_tensor(dequant_node) and precision == ConfigPrecisionType.DYNAMIC_QUANT:
+                why(node, "XNNPACK does not support per tensor quantized weights for dynamic quantization of activations")
+                return False, []
+
             gemm_deps.extend(dequant_node.all_input_nodes[1:3])

         return (True, gemm_deps)
@@ -174,7 +182,7 @@ def _get_output_deps(
             # Look for fused activations and tail end quant node
             node_users = list(node.users.keys())
             if len(node_users) != 1:
-                # Expect quantized node to have a single output (fused act or dequant)
+                why(node, "Expected quantized node to have a single output")
                 return False, []

             # Check if the quantized pattern has a fused activation
@@ -190,6 +198,7 @@

             if not is_quant(n_output):
                 # Expected gemm_node --> fused_act (optional) --> dequant
+                why(node, "Expected output node to have a dequantized node")
                 return (False, [])
             gemm_deps.append(n_output)
         elif precision == ConfigPrecisionType.FP32:
@@ -219,7 +228,8 @@ def _get_bias_deps(
             bias_node = get_input_node(node, self.bias_idx)
             if bias_node:
                 if not is_param_node(ep, bias_node):
-                    return (False, [])  # bias node must be a static param
+                    why(node, "Expected bias to be a static param")
+                    return (False, [])
                 gemm_deps.append(bias_node)

         return (True, gemm_deps)
@@ -233,7 +243,7 @@ def _get_act_deps(
         else:
             dq_input = get_input_node(node, self.act_idx)
             if not is_dequant(dq_input):
-                # Expected static quant input to be dequant node
+                why(node, "Expected act input to be dequant node")
                 return False, []
             gemm_deps.append(dq_input)
             if precision == ConfigPrecisionType.STATIC_QUANT:
@@ -243,6 +253,7 @@
             # q input node
             q_input = get_input_node(dq_input, 0)
             if not is_quant(q_input):
+                why(node, "Expected dequant input to be quant node")
                 return (False, [])

             gemm_deps.append(q_input)
@@ -250,20 +261,20 @@
             if is_affine_qdq(q_input):
                 q_input_args = extract_qdq_affine_op_args_for_decomposed_ops(q_input)
             if not (is_node(q_input_args[1]) and is_node(q_input_args[2])):
-                # expected to find getitem node from choose qparam
+                why(node, "expected to find getitem node from choose qparam")
                 return (False, [])

             getitem1 = q_input_args[1]
             getitem2 = q_input_args[2]
             if not (is_getitem(getitem1) and is_getitem(getitem2)):
-                # expected getitem node from choose qparam
+                why(node, "expected getitem node from choose qparam")
                 return (False, [])

             gemm_deps.extend([getitem1, getitem2])
             choose_qparam = get_input_node(getitem1, 0)
             if not is_qparam(choose_qparam):
-                # expected to find choose_qparam node
+                why(node, "expected to find choose_qparam node")
                 return (False, [])
             gemm_deps.append(choose_qparam)
             return (True, gemm_deps)
@@ -471,6 +482,7 @@ def find_partition_args(input_node):
             # there can only be a single output node in partition
             or len(src_partition.output_nodes) != 1
         ):
+            why(node, "invalid source partition")
             return (False, [])

         # map addmm's args to the source partition linear's inputs and users
diff --git a/backends/xnnpack/test/ops/test_linear.py b/backends/xnnpack/test/ops/test_linear.py
index 30bb4f0aba2..f7f5f754eba 100644
--- a/backends/xnnpack/test/ops/test_linear.py
+++ b/backends/xnnpack/test/ops/test_linear.py
@@ -520,7 +520,7 @@ def get_qnode_checks(quant_node_checks, dialect):
         #     qtol=bool(quant_config), atol=atol
         # )

-    def _test_qd8_per_channel_linear(self, dtype: torch.dtype = torch.float):
+    def _test_qd8_linear(self, dtype: torch.dtype = torch.float, is_per_channel: bool = True):
         for uses_bias in (False, True):
             module = BaseLinear(
                 in_size=8,
@@ -535,7 +535,7 @@ def _test_qd8_per_channel_linear(self, dtype: torch.dtype = torch.float):
                 module,
                 inputs,
                 dynamic_shapes=({1: torch.export.Dim("batch", max=100)},),
-                is_per_channel=True,
+                is_per_channel=is_per_channel,
                 uses_bias=uses_bias,
             )

@@ -695,11 +695,29 @@ def test_qs8_linear(self):

     # Tests for q[dp]8-f16-qc8w
     def test_qd8_f16_per_channel_linear(self):
-        self._test_qd8_per_channel_linear(dtype=torch.half)
+        self._test_qd8_linear(dtype=torch.half)
+
+    @unittest.expectedFailure
+    def test_qd8_f16_per_tensor_linear(self):
+        """
+        XNNPACK doesn't support per_tensor quantized weights for the dynamically quantized linear op.
+        This test verifies that we can't lower a linear op with per_tensor quantized weights, unlike per_channel.
+        """
+        self._test_qd8_linear(dtype=torch.half, is_per_channel=False)

     # Tests for q[dp]8-f32-qc8w
     def test_qd8_f32_per_channel_linear(self):
-        self._test_qd8_per_channel_linear(dtype=torch.float)
+        self._test_qd8_linear(dtype=torch.float)
+
+    @unittest.expectedFailure
+    def test_qd8_f32_per_tensor_linear(self):
+        """
+        XNNPACK doesn't support per_tensor quantized weights for the dynamically quantized linear op.
+        This test verifies that we can't lower a linear op with per_tensor quantized weights, unlike per_channel.
+        """
+        self._test_qd8_linear(dtype=torch.float, is_per_channel=False)

     # Tests for q[dp]8-f16-qc4w
     def test_linear_qd8_f16_per_channel_int4(self):
diff --git a/backends/xnnpack/utils/quant_utils.py b/backends/xnnpack/utils/quant_utils.py
index 7c035757a6f..affb1d56509 100644
--- a/backends/xnnpack/utils/quant_utils.py
+++ b/backends/xnnpack/utils/quant_utils.py
@@ -88,6 +88,13 @@
     return is_per_channel or is_affine_per_channel_group


+def is_per_tensor(node: torch.fx.Node) -> bool:
+    if not (is_quant(node) or is_dequant(node)):
+        return False
+
+    is_per_tensor = "per_tensor" in node.target.__name__  # pyre-ignore
+
+    return is_per_tensor and not (is_per_channel(node))
 def is_affine_qdq(node: torch.fx.Node) -> bool:
     if not (is_quant(node) or is_dequant(node)):

From 399b3a27327229985859a52e45baea959b1ea9cc Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Thu, 27 Feb 2025 17:14:05 -0800
Subject: [PATCH 2/2] Update on "[ExecuTorch][XNNPACK] Don't partition
 per_tensor weights with qd8"

This is not supported, so we shouldn't partition it. Add an expectedFailure
test to indicate that this is not supported.

Differential Revision: [D70343584](https://our.internmc.facebook.com/intern/diff/D70343584/)

[ghstack-poisoned]
---
 backends/xnnpack/partition/config/gemm_configs.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/backends/xnnpack/partition/config/gemm_configs.py b/backends/xnnpack/partition/config/gemm_configs.py
index 45dc3b23d56..e5cc2506e1f 100644
--- a/backends/xnnpack/partition/config/gemm_configs.py
+++ b/backends/xnnpack/partition/config/gemm_configs.py
@@ -161,17 +161,18 @@ def _get_weight_deps(
             return False, []
         gemm_deps.append(weight)

+        if is_per_tensor(dequant_node) and precision == ConfigPrecisionType.DYNAMIC_QUANT:
+            why(node, "XNNPACK does not support per tensor quantized weights for dynamic quantization of activations")
+            return False, []
+
         if is_per_channel(dequant_node) or is_per_channel_group(dequant_node):
             if len(dequant_node.all_input_nodes) < 2:
                 # Expected channel quantized to have scale/zp nodes
                 why(node, "Expected channel quantized to have scale/zp nodes")
                 return False, []

-            if is_per_tensor(dequant_node) and precision == ConfigPrecisionType.DYNAMIC_QUANT:
-                why(node, "XNNPACK does not support per tensor quantized weights for dynamic quantization of activations")
-                return False, []
-
             gemm_deps.extend(dequant_node.all_input_nodes[1:3])
+
         return (True, gemm_deps)

     def _get_output_deps(
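
Reviewer note, not part of the patches above: the sketch below is one way to reproduce the case this change now rejects, using the standard PT2E dynamic-quantization flow. `TinyLinear` is a made-up stand-in module, and the exact import paths for `XNNPACKQuantizer` and `get_symmetric_quantization_config` differ between PyTorch and ExecuTorch releases, so treat this as an illustrative sketch rather than part of the change.

```python
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)


class TinyLinear(torch.nn.Module):
    """Hypothetical stand-in model with a single linear layer."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 4)

    def forward(self, x):
        return self.linear(x)


model = TinyLinear().eval()
example_inputs = (torch.randn(1, 8),)

# qd8 + per-channel weights is the supported combination. Setting
# is_per_channel=False requests per_tensor weights, which is exactly the
# case the partitioner change above refuses to delegate.
quantizer = XNNPACKQuantizer()
quantizer.set_global(
    get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True)
)

exported = torch.export.export_for_training(model, example_inputs).module()
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)  # one forward pass so the weight observers run
converted = convert_pt2e(prepared)
```

With `is_per_channel=False`, lowering the converted graph through `XnnpackPartitioner` should now leave the dynamically quantized linear unpartitioned, with the `why(...)` message added in this PR explaining the reason, instead of failing later inside the XNNPACK delegate.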