
Commit d7a47cd

[ExecuTorch][XNNPACK] Rename linear weight partitioning flag for clarity
Pull Request resolved: #8844
Differential Revision: [D70372220](https://our.internmc.facebook.com/intern/diff/D70372220/)
ghstack-source-id: 269069584
1 parent 1bdc2f1 commit d7a47cd
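
The rename is purely a naming change at the call site: callers pass the new keyword to `XnnpackPartitioner` exactly as they passed the old one. A minimal sketch mirroring the test changes below (the import path is the one commonly used in ExecuTorch examples):

```python
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

# Previously: XnnpackPartitioner(force_fp32_dynamic_linear=True)
# After this change, the same behavior is requested with the clearer flag name:
partitioner = XnnpackPartitioner(force_non_static_weights_for_f32_linear=True)
```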

4 files changed: +35 additions, -19 deletions


backends/xnnpack/partition/config/gemm_configs.py

Lines changed: 25 additions & 13 deletions
@@ -96,9 +96,9 @@ def _detect_precision(self, node: torch.fx.Node) -> ConfigPrecisionType:
     def _overwrite_precision(self, node: torch.fx.Node):
         precision = self._detect_precision(node)
         if precision not in self.enabled_precision_types:
-            # detected precision is not enabled, lets try to partition it as fp32
+            # detected precision is not enabled, try to partition it as fp32
             if self.enabled_precision_types == [ConfigPrecisionType.FP32]:
-                # if only fp32 is enabled, then we can still partition fp32 gemms
+                # when only fp32 is enabled, then we can still partition fp32 gemms
                 # even with in a quantized graph
                 if precision in [
                     ConfigPrecisionType.STATIC_QUANT,

@@ -107,6 +107,7 @@ def _overwrite_precision(self, node: torch.fx.Node):
                     precision = ConfigPrecisionType.FP32
                     logging.info(f"Overwriting precision, partitioning {node} as FP32")
                     return True, precision
+
         return False, precision

     def get_deps(

@@ -220,8 +221,11 @@ def _get_bias_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
         gemm_deps = []
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force force_fp32_dynamic_linear is enabled, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is enabled, then we
             # do not partition the weight node
             return (True, gemm_deps)

@@ -299,8 +303,11 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
     def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is enabled, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is enabled, then we
             # do not partition the weight node
             return (True, [])

@@ -406,9 +413,11 @@ def __init__(self, **kwargs):
     def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
-        # TODO(maxren, T210537195):
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is on and we detected this as fp32, then we
             # do not partition the weight node
             return (True, [])

@@ -495,11 +504,11 @@ def find_partition_args(input_node):
         node.args = old_args
         node.users = old_users

-        # When using force_fp32_dynamic_linear, we want to get_deps to overwrite the source partition nodes.
+        # When using force_non_static_weights_for_f32_linear, we want to get_deps to overwrite the source partition nodes.
         # Else we want to be greedy.
         ret_deps = (
             list(set(deps) & set(src_partition.nodes))
-            if self.force_fp32_dynamic_linear
+            if self.force_non_static_weights_for_f32_linear
             else list(set(deps) | set(src_partition.nodes))
         )

@@ -525,8 +534,11 @@ def __init__(self, **kwargs):
     def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is on and we detected this as fp32, then we
             # do not partition the weight node
             return (True, [])
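
The same guard now appears in each GEMM config's dependency helpers. A simplified, self-contained sketch of that pattern follows; the class and helper are illustrative stand-ins, not the real `GEMMConfig`:

```python
from typing import List, Tuple

import torch


class WeightDepsGuardSketch:
    """Illustrative stand-in for the guard shared by the GEMM configs above."""

    def __init__(self, force_non_static_weights_for_f32_linear: bool = False) -> None:
        self.force_non_static_weights_for_f32_linear = (
            force_non_static_weights_for_f32_linear
        )

    def _get_weight_deps(
        self, node: torch.fx.Node, precision_is_fp32: bool
    ) -> Tuple[bool, List[torch.fx.Node]]:
        if precision_is_fp32 and self.force_non_static_weights_for_f32_linear:
            # Skip partitioning the weight node so it is treated as a
            # non-static (runtime) input rather than a baked-in constant.
            return (True, [])
        # Otherwise the weight node itself would be collected as a dependency
        # (the real configs gather it from the graph; elided here).
        return (True, [node])
```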

backends/xnnpack/partition/config/xnnpack_config.py

Lines changed: 3 additions & 1 deletion
@@ -41,7 +41,9 @@ def __init__(self, **kwargs):
         super().__init__()
         self.enabled_precision_types = self.supported_precision_types()
         # Flag used in GEMMConfig()
-        self.force_fp32_dynamic_linear = kwargs.get("force_fp32_dynamic_linear", False)
+        self.force_non_static_weights_for_f32_linear = kwargs.get(
+            "force_non_static_weights_for_f32_linear", False
+        )

     def get_partition(
         self, node: torch.fx.Node, ep: ExportedProgram
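
As the hunk above shows, the flag is read from `**kwargs` with a `False` default, so callers that never pass it keep the previous behavior. A tiny sketch of that plumbing, using a hypothetical stand-in class rather than `XNNPartitionerConfig` itself:

```python
class ConfigFlagSketch:
    """Hypothetical stand-in showing only the flag handling from the diff above."""

    def __init__(self, **kwargs):
        # An absent kwarg falls back to False: FP32 linear weights stay static
        # unless the caller explicitly opts in to non-static weights.
        self.force_non_static_weights_for_f32_linear = kwargs.get(
            "force_non_static_weights_for_f32_linear", False
        )


assert ConfigFlagSketch().force_non_static_weights_for_f32_linear is False
assert ConfigFlagSketch(
    force_non_static_weights_for_f32_linear=True
).force_non_static_weights_for_f32_linear is True
```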

backends/xnnpack/test/ops/test_linear.py

Lines changed: 2 additions & 2 deletions
@@ -892,7 +892,7 @@ def test_linear_qd8_as_fp32(self):
            },
        )

-    def test_linear_fp32_with_force_as_mm(self):
+    def test_linear_with_force_non_static_weights_for_f32_linear(self):
        def check_signature(
            signature: ExportGraphSignature,
            force_flag: bool,

@@ -925,7 +925,7 @@ def check_signature(
            inputs = module.get_inputs()
            tester = Tester(module, inputs).export()
            partitioner = XnnpackPartitioner(
-                force_fp32_dynamic_linear=force_flag
+                force_non_static_weights_for_f32_linear=force_flag
            )
            if legacy_mode:
                tester.to_edge()

backends/xnnpack/test/ops/test_lstm.py

Lines changed: 5 additions & 3 deletions
@@ -43,18 +43,20 @@ def test_fp32_lstm(self):
            .run_method_and_compare_outputs()
        )

-    def test_fp32_lstm_force_dynamic_linear(self):
+    def test_lstm_with_force_non_static_weights_for_f32_linear(self):
        (
            Tester(self.LSTMLinear(32, 32, 10), (torch.rand(1, 32, 32),))
            .export()
            .to_edge_transform_and_lower(
                ToEdgeTransformAndLower(
-                    partitioners=[XnnpackPartitioner(force_fp32_dynamic_linear=True)]
+                    partitioners=[
+                        XnnpackPartitioner(force_non_static_weights_for_f32_linear=True)
+                    ]
                )
            )
            .check_not(["executorch_exir_dialects_edge__ops_aten_addmm_default"])
            # Weights are supplied as input to linears
-            # Biases are not owned by delegates when force_fp32_dynamic_linear is set
+            # Biases are not owned by delegates when force_non_static_weights_for_f32_linear is set
            .check(["p_lstm_weight_hh_l0", "p_lstm_weight_ih_l0", "p_lstm_bias"])
            .to_executorch()
            .serialize()
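
Outside the Tester harness, the renamed flag slots into the usual export-and-lower flow. A hedged end-to-end sketch; the module, shapes, and variable names are hypothetical, and the export and lowering calls are the standard ExecuTorch APIs rather than anything introduced by this commit:

```python
import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower

# Hypothetical FP32 linear model used purely for illustration.
model = torch.nn.Linear(32, 32).eval()
example_inputs = (torch.rand(1, 32),)

exported = torch.export.export(model, example_inputs)
lowered = to_edge_transform_and_lower(
    exported,
    partitioner=[
        XnnpackPartitioner(force_non_static_weights_for_f32_linear=True)
    ],
)
executorch_program = lowered.to_executorch()
```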
