Commit 6c92495

Update on "[ExecuTorch][XNNPACK] Rename linear weight partitioning flag for clarity"
Differential Revision: [D70372220](https://our.internmc.facebook.com/intern/diff/D70372220/)

[ghstack-poisoned]
1 parent 8812c65 commit 6c92495

File tree

3 files changed: +46 -9 lines changed

backends/xnnpack/partition/config/gemm_configs.py

Lines changed: 40 additions & 7 deletions

@@ -125,6 +125,7 @@ def get_deps(
             # detected precision but it is either disabled or not supported
             why(node, f"Unsupported precision type {precision}")
             return (False, [])
+        _, precision = self._overwrite_precision(node)
         valid_bias, bias_deps = self._get_bias_deps(node, ep, precision)
         valid_weight, weight_deps = self._get_weight_deps(node, ep, precision)
         valid_act, act_deps = self._get_act_deps(node, ep, precision)
@@ -139,11 +140,6 @@ def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
         gemm_deps = []
-        if precision == ConfigPrecisionType.FP32 and self.force_non_static_weights_for_f32_linear:
-            # if force_non_static_weights_for_f32_linear is enabled, then we
-            # do not partition the weight node
-            return (True, gemm_deps)
-
         if precision == ConfigPrecisionType.FP32:
             # First find the weight
             weight_node = get_input_node(node, self.weight_idx)
@@ -225,8 +221,11 @@ def _get_bias_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
         gemm_deps = []
-        if precision == ConfigPrecisionType.FP32 and self.force_non_static_weights_for_f32_linear:
-            # if force for_fp32_linear_as_matmul is enabled, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is enabled, then we
             # do not partition the weight node
             return (True, gemm_deps)
 
@@ -304,6 +303,14 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
     def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is enabled, then we
+            # do not partition the weight node
+            return (True, [])
+
         # Since we are in Linear, we may assume that the weights are indeed static.
         overwritten_linear_precision, new_precision = self._overwrite_precision(node)
         if new_precision == ConfigPrecisionType.FP32 and overwritten_linear_precision:
@@ -403,6 +410,19 @@ def __init__(self, **kwargs):
         self.src_partitions = None
         self.linear_modules = [torch.nn.functional.linear, torch.nn.Linear]
 
+    def _get_weight_deps(
+        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
+    ) -> Tuple[bool, List[torch.fx.Node]]:
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is on and we detected this as fp32, then we
+            # do not partition the weight node
+            return (True, [])
+
+        return super()._get_weight_deps(node, ep, precision)
+
     def get_deps(
         self,
         node: torch.fx.Node,
@@ -511,6 +531,19 @@ def __init__(self, **kwargs):
         self.weight_idx = 1
         self.act_idx = 0
 
+    def _get_weight_deps(
+        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
+    ) -> Tuple[bool, List[torch.fx.Node]]:
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is on and we detected this as fp32, then we
+            # do not partition the weight node
+            return (True, [])
+
+        return super()._get_weight_deps(node, ep, precision)
+
     def supported_precision_types(self):
         return [
             ConfigPrecisionType.FP32,
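
Aside: the overrides above all follow the same (valid, deps) contract, where returning (True, []) keeps the GEMM lowerable while leaving the weight node out of the partition. Below is a minimal standalone sketch of that contract; the class and function names and the simplified signature are illustrative, not the actual ExecuTorch code.

from typing import List, Tuple


class FakeNode:
    """Illustrative stand-in for torch.fx.Node (hypothetical, for this sketch only)."""

    def __init__(self, name: str) -> None:
        self.name = name


def get_weight_deps_sketch(
    weight_node: FakeNode,
    is_fp32: bool,
    force_non_static_weights_for_f32_linear: bool,
) -> Tuple[bool, List[FakeNode]]:
    # Mirrors the logic added in the diff: when the flag is set for an fp32
    # linear, the op stays valid to lower, but the weight is not claimed as a
    # static partition dependency.
    if is_fp32 and force_non_static_weights_for_f32_linear:
        return (True, [])
    # Otherwise the weight node itself becomes part of the partition.
    return (True, [weight_node])


w = FakeNode("linear.weight")
print(get_weight_deps_sketch(w, True, True))   # (True, [])  -> weight left out of the partition
print(get_weight_deps_sketch(w, True, False))  # (True, [w]) -> weight partitioned as usual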

backends/xnnpack/partition/config/xnnpack_config.py

Lines changed: 3 additions & 1 deletion

@@ -41,7 +41,9 @@ def __init__(self, **kwargs):
         super().__init__()
         self.enabled_precision_types = self.supported_precision_types()
         # Flag used in GEMMConfig()
-        self.force_non_static_weights_for_f32_linear = kwargs.get("force_non_static_weights_for_f32_linear", False)
+        self.force_non_static_weights_for_f32_linear = kwargs.get(
+            "force_non_static_weights_for_f32_linear", False
+        )
 
     def get_partition(
         self, node: torch.fx.Node, ep: ExportedProgram

backends/xnnpack/test/ops/test_lstm.py

Lines changed: 3 additions & 1 deletion

@@ -49,7 +49,9 @@ def test_lstm_with_force_non_static_weights_for_f32_linear(self):
             .export()
             .to_edge_transform_and_lower(
                 ToEdgeTransformAndLower(
-                    partitioners=[XnnpackPartitioner(force_non_static_weights_for_f32_linear=True)]
+                    partitioners=[
+                        XnnpackPartitioner(force_non_static_weights_for_f32_linear=True)
+                    ]
                 )
             )
             .check_not(["executorch_exir_dialects_edge__ops_aten_addmm_default"])
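
For reference, a minimal end-to-end sketch of using the renamed flag outside the test harness, assuming the import paths in the ExecuTorch tree at this revision; the TinyLinear model and its shapes are illustrative, not part of this change.

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower


class TinyLinear(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(8, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


model = TinyLinear().eval()
example_inputs = (torch.randn(1, 8),)
exported = torch.export.export(model, example_inputs)

# Pass the renamed flag to the partitioner: fp32 linear weights are then kept
# out of the static-weight partition (handled as non-static inputs instead).
edge = to_edge_transform_and_lower(
    exported,
    partitioner=[XnnpackPartitioner(force_non_static_weights_for_f32_linear=True)],
)
executorch_program = edge.to_executorch()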
