Commit c64b0e1

[ExecuTorch][XNNPACK] Rename linear weight partitioning flag for clarity
Differential Revision: [D70372220](https://our.internmc.facebook.com/intern/diff/D70372220/)
ghstack-source-id: 268924243
Pull Request resolved: #8844
1 parent 1bdc2f1 commit c64b0e1
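
This commit renames the XNNPACK partitioner flag `force_fp32_dynamic_linear` to `force_non_static_weights_for_f32_linear`; the behavior is unchanged, only the keyword read from the partitioner's kwargs is different. A minimal usage sketch, assuming the usual import path from the ExecuTorch repo layout (the import itself is not part of this diff):

```python
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

# After this change only the renamed keyword is read by the GEMM configs;
# the old force_fp32_dynamic_linear keyword no longer has any effect.
partitioner = XnnpackPartitioner(force_non_static_weights_for_f32_linear=True)
```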

4 files changed (+18, -39 lines changed)

backends/xnnpack/partition/config/gemm_configs.py

Lines changed: 12 additions & 33 deletions
@@ -96,9 +96,9 @@ def _detect_precision(self, node: torch.fx.Node) -> ConfigPrecisionType:
     def _overwrite_precision(self, node: torch.fx.Node):
         precision = self._detect_precision(node)
         if precision not in self.enabled_precision_types:
-            # detected precision is not enabled, lets try to partition it as fp32
+            # detected precision is not enabled, try to partition it as fp32
             if self.enabled_precision_types == [ConfigPrecisionType.FP32]:
-                # if only fp32 is enabled, then we can still partition fp32 gemms
+                # when only fp32 is enabled, then we can still partition fp32 gemms
                 # even with in a quantized graph
                 if precision in [
                     ConfigPrecisionType.STATIC_QUANT,
@@ -107,6 +107,7 @@ def _overwrite_precision(self, node: torch.fx.Node):
                     precision = ConfigPrecisionType.FP32
                     logging.info(f"Overwriting precision, partitioning {node} as FP32")
                     return True, precision
+
         return False, precision

     def get_deps(
@@ -124,7 +125,6 @@ def get_deps(
             # detected precision but it is either disabled or not supported
             why(node, f"Unsupported precision type {precision}")
             return (False, [])
-        _, precision = self._overwrite_precision(node)
         valid_bias, bias_deps = self._get_bias_deps(node, ep, precision)
         valid_weight, weight_deps = self._get_weight_deps(node, ep, precision)
         valid_act, act_deps = self._get_act_deps(node, ep, precision)
@@ -139,6 +139,11 @@ def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
         gemm_deps = []
+        if precision == ConfigPrecisionType.FP32 and self.force_non_static_weights_for_f32_linear:
+            # if force_non_static_weights_for_f32_linear is enabled, then we
+            # do not partition the weight node
+            return (True, gemm_deps)
+
         if precision == ConfigPrecisionType.FP32:
             # First find the weight
             weight_node = get_input_node(node, self.weight_idx)
@@ -220,8 +225,8 @@ def _get_bias_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
         gemm_deps = []
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force_fp32_dynamic_linear is enabled, then we
+        if precision == ConfigPrecisionType.FP32 and self.force_non_static_weights_for_f32_linear:
+            # if force_non_static_weights_for_f32_linear is enabled, then we
             # do not partition the weight node
             return (True, gemm_deps)

@@ -299,11 +304,6 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
     def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is enabled, then we
-            # do not partition the weight node
-            return (True, [])
-
         # Since we are in Linear, we may assume that the weights are indeed static.
         overwritten_linear_precision, new_precision = self._overwrite_precision(node)
         if new_precision == ConfigPrecisionType.FP32 and overwritten_linear_precision:
@@ -403,17 +403,6 @@ def __init__(self, **kwargs):
         self.src_partitions = None
         self.linear_modules = [torch.nn.functional.linear, torch.nn.Linear]

-    def _get_weight_deps(
-        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
-    ) -> Tuple[bool, List[torch.fx.Node]]:
-        # TODO(maxren, T210537195):
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
-            # do not partition the weight node
-            return (True, [])
-
-        return super()._get_weight_deps(node, ep, precision)
-
     def get_deps(
         self,
         node: torch.fx.Node,
@@ -495,11 +484,11 @@ def find_partition_args(input_node):
             node.args = old_args
             node.users = old_users

-            # When using force_fp32_dynamic_linear, we want to get_deps to overwrite the source partition nodes.
+            # When using force_non_static_weights_for_f32_linear, we want to get_deps to overwrite the source partition nodes.
             # Else we want to be greedy.
             ret_deps = (
                 list(set(deps) & set(src_partition.nodes))
-                if self.force_fp32_dynamic_linear
+                if self.force_non_static_weights_for_f32_linear
                 else list(set(deps) | set(src_partition.nodes))
             )

@@ -522,16 +511,6 @@ def __init__(self, **kwargs):
         self.weight_idx = 1
         self.act_idx = 0

-    def _get_weight_deps(
-        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
-    ) -> Tuple[bool, List[torch.fx.Node]]:
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
-            # do not partition the weight node
-            return (True, [])
-
-        return super()._get_weight_deps(node, ep, precision)
-
     def supported_precision_types(self):
         return [
             ConfigPrecisionType.FP32,
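
The net effect in gemm_configs.py: the early return for the flag now lives once in the base GEMM config's `_get_weight_deps` (and `_get_bias_deps`), so the per-subclass `_get_weight_deps` overrides that duplicated it are deleted. A simplified sketch of that pattern, with stand-in class names rather than the actual ExecuTorch classes:

```python
from typing import List, Tuple


class GemmConfigSketch:
    """Stand-in for the base GEMM config, to illustrate the flag's short-circuit."""

    def __init__(self, force_non_static_weights_for_f32_linear: bool = False):
        self.force_non_static_weights_for_f32_linear = force_non_static_weights_for_f32_linear

    def _get_weight_deps(self, node, precision: str) -> Tuple[bool, List]:
        gemm_deps: List = []
        if precision == "fp32" and self.force_non_static_weights_for_f32_linear:
            # Do not claim the weight node: it is treated as a non-static
            # input to the delegate instead of a baked-in constant.
            return (True, gemm_deps)
        # ... normal static-weight dependency gathering would go here ...
        return (True, gemm_deps)


class LinearConfigSketch(GemmConfigSketch):
    # No override needed any more: the flag handling is inherited from the base.
    pass
```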

backends/xnnpack/partition/config/xnnpack_config.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ def __init__(self, **kwargs):
         super().__init__()
         self.enabled_precision_types = self.supported_precision_types()
         # Flag used in GEMMConfig()
-        self.force_fp32_dynamic_linear = kwargs.get("force_fp32_dynamic_linear", False)
+        self.force_non_static_weights_for_f32_linear = kwargs.get("force_non_static_weights_for_f32_linear", False)

     def get_partition(
         self, node: torch.fx.Node, ep: ExportedProgram
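
Because the flag is read with `kwargs.get(..., False)`, a caller that still passes the old keyword gets no error; the lookup simply misses and the flag stays `False`. A small sketch of that lookup behavior (the plain dict stands in for the partitioner's kwargs):

```python
kwargs = {"force_fp32_dynamic_linear": True}  # caller not yet migrated

# Mirrors the updated xnnpack_config.py line: only the renamed key is
# looked up, so the stale key above is ignored and the flag resolves to False.
flag = kwargs.get("force_non_static_weights_for_f32_linear", False)
print(flag)  # -> False
```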

backends/xnnpack/test/ops/test_linear.py

Lines changed: 2 additions & 2 deletions
@@ -892,7 +892,7 @@ def test_linear_qd8_as_fp32(self):
             },
         )

-    def test_linear_fp32_with_force_as_mm(self):
+    def test_linear_with_force_non_static_weights_for_f32_linear(self):
         def check_signature(
             signature: ExportGraphSignature,
             force_flag: bool,
@@ -925,7 +925,7 @@ def check_signature(
                 inputs = module.get_inputs()
                 tester = Tester(module, inputs).export()
                 partitioner = XnnpackPartitioner(
-                    force_fp32_dynamic_linear=force_flag
+                    force_non_static_weights_for_f32_linear=force_flag
                 )
                 if legacy_mode:
                     tester.to_edge()
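
The renamed test still runs with both flag values and checks the exported program's graph signature to see whether the linear weight ends up as a lifted parameter or as a plain input. A rough, self-contained sketch of that kind of signature inspection using torch.export directly (a toy module, not the module used by the test):

```python
import torch
from torch.export import export

module = torch.nn.Linear(4, 4)
ep = export(module, (torch.randn(1, 4),))
sig = ep.graph_signature

# Lifted parameters (e.g. the linear weight/bias) are listed here; with the
# force flag set, the XNNPACK-lowered program is expected to keep the weight
# as a non-static input to the delegate rather than a delegate-owned constant.
print(sig.parameters)
print(sig.user_inputs)
```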

backends/xnnpack/test/ops/test_lstm.py

Lines changed: 3 additions & 3 deletions
@@ -43,18 +43,18 @@ def test_fp32_lstm(self):
             .run_method_and_compare_outputs()
         )

-    def test_fp32_lstm_force_dynamic_linear(self):
+    def test_lstm_with_force_non_static_weights_for_f32_linear(self):
         (
             Tester(self.LSTMLinear(32, 32, 10), (torch.rand(1, 32, 32),))
             .export()
             .to_edge_transform_and_lower(
                 ToEdgeTransformAndLower(
-                    partitioners=[XnnpackPartitioner(force_fp32_dynamic_linear=True)]
+                    partitioners=[XnnpackPartitioner(force_non_static_weights_for_f32_linear=True)]
                 )
             )
             .check_not(["executorch_exir_dialects_edge__ops_aten_addmm_default"])
             # Weights are supplied as input to linears
-            # Biases are not owned by delegates when force_fp32_dynamic_linear is set
+            # Biases are not owned by delegates when force_non_static_weights_for_f32_linear is set
             .check(["p_lstm_weight_hh_l0", "p_lstm_weight_ih_l0", "p_lstm_bias"])
             .to_executorch()
             .serialize()
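
Outside of the Tester harness, the same flag flows through the standard lowering entry point. A sketch of the equivalent non-test pipeline, assuming `to_edge_transform_and_lower` from `executorch.exir` (the module and entry-point names here are not part of this diff):

```python
import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower
from torch.export import export


class LinearModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(32, 32)

    def forward(self, x):
        return self.linear(x)


# Export, partition with the renamed flag, and lower to an ExecuTorch program.
ep = export(LinearModule(), (torch.randn(1, 32),))
edge = to_edge_transform_and_lower(
    ep,
    partitioner=[XnnpackPartitioner(force_non_static_weights_for_f32_linear=True)],
)
et_program = edge.to_executorch()
```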
