Skip to content

Commit bd793c3

Browse files
authored
Fix pyre errors
Differential Revision: D79828275 Pull Request resolved: #13202
1 parent 97a3aac commit bd793c3

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

exir/backend/test/demos/rpc/executor_backend_partitioner.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from typing import final
99

1010
import torch
11+
import torch.fx
12+
1113
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
1214
generate_pattern_op_partitions,
1315
)
@@ -65,8 +67,9 @@ def partition(self, edge_exported_program: ExportedProgram) -> PartitionResult:
6567
partition_tags[delegation_tag] = self.delegation_spec
6668

6769
# Tag the delegate submodules
68-
if node.args[0].op == "get_attr":
69-
node.args[0].meta["delegation_tag"] = delegation_tag
70+
arg0 = node.args[0]
71+
if isinstance(arg0, torch.fx.Node) and arg0.op == "get_attr":
72+
arg0.meta["delegation_tag"] = delegation_tag
7073

7174
return PartitionResult(
7275
tagged_exported_program=edge_exported_program,

exir/backend/test/test_backends_nested.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,8 +197,11 @@ def _partition_graph_module(
197197
and node.target is torch.ops.higher_order.cond
198198
):
199199
# Tag the arguments that take in the submodules to cond
200-
node.args[1].meta["delegation_tag"] = delegation_tag
201-
node.args[2].meta["delegation_tag"] = delegation_tag
200+
arg1, arg2 = node.args[1], node.args[2]
201+
if isinstance(arg1, torch.fx.Node):
202+
arg1.meta["delegation_tag"] = delegation_tag
203+
if isinstance(arg2, torch.fx.Node):
204+
arg2.meta["delegation_tag"] = delegation_tag
202205
node.meta["delegation_tag"] = delegation_tag
203206
partition_tags[delegation_tag] = self.delegation_spec
204207
return partition_tags

0 commit comments

Comments
 (0)