From 8812c6597fc8fcf19ae19679ad114d22524ee60a Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Fri, 28 Feb 2025 12:16:19 -0800
Subject: [PATCH 1/2] [ExecuTorch][XNNPACK] Rename linear weight partitioning flag for clarity

Differential Revision: [D70372220](https://our.internmc.facebook.com/intern/diff/D70372220/)

[ghstack-poisoned]
---
 .../xnnpack/partition/config/gemm_configs.py | 45 +++++--------------
 .../partition/config/xnnpack_config.py       |  2 +-
 backends/xnnpack/test/ops/test_linear.py     |  4 +-
 backends/xnnpack/test/ops/test_lstm.py       |  6 +--
 4 files changed, 18 insertions(+), 39 deletions(-)

diff --git a/backends/xnnpack/partition/config/gemm_configs.py b/backends/xnnpack/partition/config/gemm_configs.py
index e5cc2506e1f..f6ffa7ba6eb 100644
--- a/backends/xnnpack/partition/config/gemm_configs.py
+++ b/backends/xnnpack/partition/config/gemm_configs.py
@@ -96,9 +96,9 @@ def _detect_precision(self, node: torch.fx.Node) -> ConfigPrecisionType:
     def _overwrite_precision(self, node: torch.fx.Node):
         precision = self._detect_precision(node)
         if precision not in self.enabled_precision_types:
-            # detected precision is not enabled, lets try to partition it as fp32
+            # detected precision is not enabled, try to partition it as fp32
             if self.enabled_precision_types == [ConfigPrecisionType.FP32]:
-                # if only fp32 is enabled, then we can still partition fp32 gemms
+                # when only fp32 is enabled, then we can still partition fp32 gemms
                 # even with in a quantized graph
                 if precision in [
                     ConfigPrecisionType.STATIC_QUANT,
@@ -107,6 +107,7 @@ def _overwrite_precision(self, node: torch.fx.Node):
                     precision = ConfigPrecisionType.FP32
                     logging.info(f"Overwriting precision, partitioning {node} as FP32")
                     return True, precision
+
         return False, precision

     def get_deps(
@@ -124,7 +125,6 @@ def get_deps(
             # detected precision but it is either disabled or not supported
             why(node, f"Unsupported precision type {precision}")
             return (False, [])
-        _, precision = self._overwrite_precision(node)
         valid_bias, bias_deps = self._get_bias_deps(node, ep, precision)
         valid_weight, weight_deps = self._get_weight_deps(node, ep, precision)
         valid_act, act_deps = self._get_act_deps(node, ep, precision)
@@ -139,6 +139,11 @@ def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
         gemm_deps = []
+        if precision == ConfigPrecisionType.FP32 and self.force_non_static_weights_for_f32_linear:
+            # if force_non_static_weights_for_f32_linear is enabled, then we
+            # do not partition the weight node
+            return (True, gemm_deps)
+
         if precision == ConfigPrecisionType.FP32:
             # First find the weight
             weight_node = get_input_node(node, self.weight_idx)
@@ -220,8 +225,8 @@ def _get_bias_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
         gemm_deps = []
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force force_fp32_dynamic_linear is enabled, then we
+        if precision == ConfigPrecisionType.FP32 and self.force_non_static_weights_for_f32_linear:
+            # if force for_fp32_linear_as_matmul is enabled, then we
             # do not partition the weight node
             return (True, gemm_deps)

@@ -299,11 +304,6 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
     def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is enabled, then we
-            # do not partition the weight node
-            return (True, [])
-
         # Since we are in Linear, we may assume that the weights are indeed static.
         overwritten_linear_precision, new_precision = self._overwrite_precision(node)
         if new_precision == ConfigPrecisionType.FP32 and overwritten_linear_precision:
@@ -403,17 +403,6 @@ def __init__(self, **kwargs):
         self.src_partitions = None
         self.linear_modules = [torch.nn.functional.linear, torch.nn.Linear]

-    def _get_weight_deps(
-        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
-    ) -> Tuple[bool, List[torch.fx.Node]]:
-        # TODO(maxren, T210537195):
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
-            # do not partition the weight node
-            return (True, [])
-
-        return super()._get_weight_deps(node, ep, precision)
-
     def get_deps(
         self,
         node: torch.fx.Node,
@@ -495,11 +484,11 @@ def find_partition_args(input_node):
             node.args = old_args
             node.users = old_users

-            # When using force_fp32_dynamic_linear, we want to get_deps to overwrite the source partition nodes.
+            # When using force_non_static_weights_for_f32_linear, we want to get_deps to overwrite the source partition nodes.
             # Else we want to be greedy.
             ret_deps = (
                 list(set(deps) & set(src_partition.nodes))
-                if self.force_fp32_dynamic_linear
+                if self.force_non_static_weights_for_f32_linear
                 else list(set(deps) | set(src_partition.nodes))
             )

@@ -522,16 +511,6 @@ def __init__(self, **kwargs):
         self.weight_idx = 1
         self.act_idx = 0

-    def _get_weight_deps(
-        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
-    ) -> Tuple[bool, List[torch.fx.Node]]:
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
-            # do not partition the weight node
-            return (True, [])
-
-        return super()._get_weight_deps(node, ep, precision)
-
     def supported_precision_types(self):
         return [
             ConfigPrecisionType.FP32,
diff --git a/backends/xnnpack/partition/config/xnnpack_config.py b/backends/xnnpack/partition/config/xnnpack_config.py
index d261416a76f..7364aa0fb2c 100644
--- a/backends/xnnpack/partition/config/xnnpack_config.py
+++ b/backends/xnnpack/partition/config/xnnpack_config.py
@@ -41,7 +41,7 @@ def __init__(self, **kwargs):
         super().__init__()
         self.enabled_precision_types = self.supported_precision_types()
         # Flag used in GEMMConfig()
-        self.force_fp32_dynamic_linear = kwargs.get("force_fp32_dynamic_linear", False)
+        self.force_non_static_weights_for_f32_linear = kwargs.get("force_non_static_weights_for_f32_linear", False)

     def get_partition(
         self, node: torch.fx.Node, ep: ExportedProgram
diff --git a/backends/xnnpack/test/ops/test_linear.py b/backends/xnnpack/test/ops/test_linear.py
index f7f5f754eba..1f65288c447 100644
--- a/backends/xnnpack/test/ops/test_linear.py
+++ b/backends/xnnpack/test/ops/test_linear.py
@@ -892,7 +892,7 @@ def test_linear_qd8_as_fp32(self):
             },
         )

-    def test_linear_fp32_with_force_as_mm(self):
+    def test_linear_with_force_non_static_weights_for_f32_linear(self):
         def check_signature(
             signature: ExportGraphSignature,
             force_flag: bool,
@@ -925,7 +925,7 @@ def check_signature(
             inputs = module.get_inputs()
             tester = Tester(module, inputs).export()
             partitioner = XnnpackPartitioner(
-                force_fp32_dynamic_linear=force_flag
+                force_non_static_weights_for_f32_linear=force_flag
             )
             if legacy_mode:
                 tester.to_edge()
diff --git a/backends/xnnpack/test/ops/test_lstm.py b/backends/xnnpack/test/ops/test_lstm.py
index be209082b37..5192e8f1350 100644
--- a/backends/xnnpack/test/ops/test_lstm.py
+++ b/backends/xnnpack/test/ops/test_lstm.py
@@ -43,18 +43,18 @@ def test_fp32_lstm(self):
             .run_method_and_compare_outputs()
         )

-    def test_fp32_lstm_force_dynamic_linear(self):
+    def test_lstm_with_force_non_static_weights_for_f32_linear(self):
         (
             Tester(self.LSTMLinear(32, 32, 10), (torch.rand(1, 32, 32),))
             .export()
             .to_edge_transform_and_lower(
                 ToEdgeTransformAndLower(
-                    partitioners=[XnnpackPartitioner(force_fp32_dynamic_linear=True)]
+                    partitioners=[XnnpackPartitioner(force_non_static_weights_for_f32_linear=True)]
                 )
             )
             .check_not(["executorch_exir_dialects_edge__ops_aten_addmm_default"])
             # Weights are supplied as input to linears
-            # Biases are not owned by delegates when force_fp32_dynamic_linear is set
+            # Biases are not owned by delegates when force_non_static_weights_for_f32_linear is set
             .check(["p_lstm_weight_hh_l0", "p_lstm_weight_ih_l0", "p_lstm_bias"])
             .to_executorch()
             .serialize()

From 6c9249514258b0683980d796d5db1773f0ddddc8 Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Fri, 28 Feb 2025 13:04:59 -0800
Subject: [PATCH 2/2] Update on "[ExecuTorch][XNNPACK] Rename linear weight partitioning flag for clarity"

Differential Revision: [D70372220](https://our.internmc.facebook.com/intern/diff/D70372220/)

[ghstack-poisoned]
---
 .../xnnpack/partition/config/gemm_configs.py | 47 ++++++++++++++++---
 .../partition/config/xnnpack_config.py       |  4 +-
 backends/xnnpack/test/ops/test_lstm.py       |  4 +-
 3 files changed, 46 insertions(+), 9 deletions(-)

diff --git a/backends/xnnpack/partition/config/gemm_configs.py b/backends/xnnpack/partition/config/gemm_configs.py
index f6ffa7ba6eb..03d12e3cf24 100644
--- a/backends/xnnpack/partition/config/gemm_configs.py
+++ b/backends/xnnpack/partition/config/gemm_configs.py
@@ -125,6 +125,7 @@ def get_deps(
             # detected precision but it is either disabled or not supported
             why(node, f"Unsupported precision type {precision}")
             return (False, [])
+        _, precision = self._overwrite_precision(node)
         valid_bias, bias_deps = self._get_bias_deps(node, ep, precision)
         valid_weight, weight_deps = self._get_weight_deps(node, ep, precision)
         valid_act, act_deps = self._get_act_deps(node, ep, precision)
@@ -139,11 +140,6 @@ def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
         gemm_deps = []
-        if precision == ConfigPrecisionType.FP32 and self.force_non_static_weights_for_f32_linear:
-            # if force_non_static_weights_for_f32_linear is enabled, then we
-            # do not partition the weight node
-            return (True, gemm_deps)
-
         if precision == ConfigPrecisionType.FP32:
             # First find the weight
             weight_node = get_input_node(node, self.weight_idx)
@@ -225,8 +221,11 @@ def _get_bias_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
         gemm_deps = []
-        if precision == ConfigPrecisionType.FP32 and self.force_non_static_weights_for_f32_linear:
-            # if force for_fp32_linear_as_matmul is enabled, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is enabled, then we
             # do not partition the weight node
             return (True, gemm_deps)

@@ -304,6 +303,14 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
     def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is enabled, then we
+            # do not partition the weight node
+            return (True, [])
+
         # Since we are in Linear, we may assume that the weights are indeed static.
         overwritten_linear_precision, new_precision = self._overwrite_precision(node)
         if new_precision == ConfigPrecisionType.FP32 and overwritten_linear_precision:
@@ -403,6 +410,19 @@ def __init__(self, **kwargs):
         self.src_partitions = None
         self.linear_modules = [torch.nn.functional.linear, torch.nn.Linear]

+    def _get_weight_deps(
+        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
+    ) -> Tuple[bool, List[torch.fx.Node]]:
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is on and we detected this as fp32, then we
+            # do not partition the weight node
+            return (True, [])
+
+        return super()._get_weight_deps(node, ep, precision)
+
     def get_deps(
         self,
         node: torch.fx.Node,
@@ -511,6 +531,19 @@ def __init__(self, **kwargs):
         self.weight_idx = 1
         self.act_idx = 0

+    def _get_weight_deps(
+        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
+    ) -> Tuple[bool, List[torch.fx.Node]]:
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is on and we detected this as fp32, then we
+            # do not partition the weight node
+            return (True, [])
+
+        return super()._get_weight_deps(node, ep, precision)
+
     def supported_precision_types(self):
         return [
             ConfigPrecisionType.FP32,
diff --git a/backends/xnnpack/partition/config/xnnpack_config.py b/backends/xnnpack/partition/config/xnnpack_config.py
index 7364aa0fb2c..20018610fce 100644
--- a/backends/xnnpack/partition/config/xnnpack_config.py
+++ b/backends/xnnpack/partition/config/xnnpack_config.py
@@ -41,7 +41,9 @@ def __init__(self, **kwargs):
         super().__init__()
         self.enabled_precision_types = self.supported_precision_types()
         # Flag used in GEMMConfig()
-        self.force_non_static_weights_for_f32_linear = kwargs.get("force_non_static_weights_for_f32_linear", False)
+        self.force_non_static_weights_for_f32_linear = kwargs.get(
+            "force_non_static_weights_for_f32_linear", False
+        )

     def get_partition(
         self, node: torch.fx.Node, ep: ExportedProgram
diff --git a/backends/xnnpack/test/ops/test_lstm.py b/backends/xnnpack/test/ops/test_lstm.py
index 5192e8f1350..6c174b16f33 100644
--- a/backends/xnnpack/test/ops/test_lstm.py
+++ b/backends/xnnpack/test/ops/test_lstm.py
@@ -49,7 +49,9 @@ def test_lstm_with_force_non_static_weights_for_f32_linear(self):
             .export()
             .to_edge_transform_and_lower(
                 ToEdgeTransformAndLower(
-                    partitioners=[XnnpackPartitioner(force_non_static_weights_for_f32_linear=True)]
+                    partitioners=[
+                        XnnpackPartitioner(force_non_static_weights_for_f32_linear=True)
+                    ]
                 )
             )
             .check_not(["executorch_exir_dialects_edge__ops_aten_addmm_default"])
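
Usage sketch: the renamed flag is forwarded through XnnpackPartitioner's kwargs, exactly as the updated tests above pass it. Below is a minimal sketch of opting into the flag from user code under these patches; the toy module, the torch.export call, and the executorch.exir / xnnpack_partitioner import paths are assumptions for illustration and are not taken from this diff.

    import torch
    from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
    from executorch.exir import to_edge_transform_and_lower

    class TinyLinear(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(32, 32)

        def forward(self, x):
            return self.linear(x)

    model = TinyLinear().eval()
    example_inputs = (torch.rand(1, 32),)

    # force_non_static_weights_for_f32_linear replaces the old
    # force_fp32_dynamic_linear kwarg; with it set, fp32 linear weights are
    # left unpartitioned (supplied as inputs to the delegated linears)
    # instead of being embedded in the delegate payload.
    partitioner = XnnpackPartitioner(force_non_static_weights_for_f32_linear=True)

    exported = torch.export.export(model, example_inputs)
    edge = to_edge_transform_and_lower(exported, partitioner=[partitioner])
    executorch_program = edge.to_executorch()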