Skip to content

Commit b2ef921

Browse files
cccclaifacebook-github-bot
authored andcommitted
Wrap partition tags as part of partition result (#269)
Summary: Pull Request resolved: #269 `partition_tags` is an output from the `partition` function and it'll be cleaner to wrap it as part of the outputs instead of the instance attributes. This diff is supposed be a noop Reviewed By: mergennachin, angelayi Differential Revision: D49116131 fbshipit-source-id: 0943c6367e2aa902e28af7f058bd2904f871beac
1 parent 1909b32 commit b2ef921

File tree

13 files changed

+182
-119
lines changed

13 files changed

+182
-119
lines changed

backends/qnnpack/partition/qnnpack_partitioner.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@
1919
from executorch.backends.transforms.addmm_mm_to_linear import (
2020
apply_addmm_mm_to_linear_transform,
2121
)
22-
from executorch.exir.backend.partitioner import DelegationSpec, Partitioner
22+
from executorch.exir.backend.partitioner import (
23+
DelegationSpec,
24+
Partitioner,
25+
PartitionResult,
26+
)
2327
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
2428

2529
logging.basicConfig(level=logging.INFO)
@@ -35,7 +39,6 @@ def __init__(self, delegate_name, patterns):
3539
self.patterns = patterns
3640

3741
self.delegation_spec = DelegationSpec(delegate_name, [])
38-
self.partition_tags: Dict[str, DelegationSpec] = {}
3942

4043
@staticmethod
4144
def check_partitions(partitions: Union[dict, list]) -> None:
@@ -48,9 +51,9 @@ def check_partitions(partitions: Union[dict, list]) -> None:
4851
else:
4952
log.info(f"Found {pl} subgraphs to be partitioned.")
5053

51-
def partition(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
54+
def partition(self, graph_module: torch.fx.GraphModule) -> PartitionResult:
5255
raise NotImplementedError("This is not meant to be used directly.")
53-
return graph_module
56+
return PartitionResult(tagged_graph=graph_module, partition_tags={})
5457

5558

5659
class _SingleOpDelegatePartitioner(_BasePartitioner):
@@ -72,7 +75,7 @@ def __init__(
7275
self.transforms = transforms
7376

7477
# override
75-
def partition(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
78+
def partition(self, graph_module: torch.fx.GraphModule) -> PartitionResult:
7679
# TODO delete this since we are not allowed to do this
7780
if self.transforms is not None:
7881
for transform in self.transforms: # pyre-ignore
@@ -107,7 +110,7 @@ def partition(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
107110

108111
# Mapping from delegation tag to match set
109112
tag_mapping = {}
110-
113+
partition_tags: Dict[str, DelegationSpec] = {}
111114
for (partition_id, match_set) in enumerate(match_sets):
112115
delegation_tag = f"tag{partition_id}"
113116
for node in match_set:
@@ -124,10 +127,10 @@ def partition(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
124127
)
125128
break
126129
node.meta["delegation_tag"] = delegation_tag
127-
self.partition_tags[delegation_tag] = self.delegation_spec
130+
partition_tags[delegation_tag] = self.delegation_spec
128131
tag_mapping[delegation_tag] = match_set
129132

130-
return graph_module
133+
return PartitionResult(tagged_graph=graph_module, partition_tags=partition_tags)
131134

132135

133136
class QnnpackPartitioner(_SingleOpDelegatePartitioner):

backends/xnnpack/partition/xnnpack_partitioner.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,11 @@
2626
generate_partitions_from_list_of_nodes,
2727
)
2828

29-
from executorch.exir.backend.partitioner import DelegationSpec, Partitioner
29+
from executorch.exir.backend.partitioner import (
30+
DelegationSpec,
31+
Partitioner,
32+
PartitionResult,
33+
)
3034
from executorch.exir.dialects._ops import ops as exir_ops
3135
from torch.fx.passes.infra.partitioner import Partition
3236
from torch.fx.passes.operator_support import OperatorSupportBase
@@ -313,7 +317,6 @@ def __init__(
313317
self.supported_ops = set(supported_ops or [])
314318

315319
self.delegation_spec = DelegationSpec(XnnpackBackend.__name__, [])
316-
self.partition_tags: Dict[str, DelegationSpec] = {}
317320

318321
@staticmethod
319322
def check_partitions(partitions: Union[dict, list]) -> bool:
@@ -386,27 +389,30 @@ def generate_partitions(self, graph_module: torch.fx.GraphModule) -> List[Any]:
386389
),
387390
)
388391

389-
def tag_nodes(self, partitions: List[Partition]) -> None:
392+
def tag_nodes(self, partitions: List[Partition]) -> Dict[str, DelegationSpec]:
390393
"""
391394
Tag each partition in the list with its delegation tag.
392395
"""
396+
partition_tags: Dict[str, DelegationSpec] = {}
393397
for partition in partitions:
394398
# Add delegation tags
395399
for node in partition.nodes:
396400
delegation_tag = f"tag{partition.id}"
397401
node.meta["delegation_tag"] = delegation_tag
398-
self.partition_tags[delegation_tag] = self.delegation_spec
402+
partition_tags[delegation_tag] = self.delegation_spec
403+
return partition_tags
399404

400405
# override
401-
def partition(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
406+
def partition(self, graph_module: torch.fx.GraphModule) -> PartitionResult:
402407
"""
403408
Run the partitioner on the given graph module, then tag each partition
404409
with its delegation tag (and partition id)
405410
"""
406411
partitions = self.generate_partitions(graph_module)
412+
partition_tags: Dict[str, DelegationSpec] = {}
407413
if self.check_partitions(partitions):
408-
self.tag_nodes(partitions)
409-
return graph_module
414+
partition_tags = self.tag_nodes(partitions)
415+
return PartitionResult(tagged_graph=graph_module, partition_tags=partition_tags)
410416

411417

412418
# TODO: Merge XnnpackQuantizedPartitioner and XnnpackFloatingPointPartitioner
@@ -761,10 +767,11 @@ def generate_partitions(
761767
XnnpackOperatorSupport(supported_ops=list(self.get_supported_ops(quant))),
762768
)
763769

764-
def tag_nodes(self, partitions: List[Partition]) -> None:
770+
def tag_nodes(self, partitions: List[Partition]) -> Dict[str, DelegationSpec]:
765771
"""
766772
Tag each partition in the list with its delegation tag.
767773
"""
774+
partition_tags: Dict[str, DelegationSpec] = {}
768775
for partition in partitions:
769776
# Add delegation tags
770777
skip = False
@@ -776,23 +783,25 @@ def tag_nodes(self, partitions: List[Partition]) -> None:
776783
for node in partition.nodes:
777784
delegation_tag = f"tag{partition.id}"
778785
node.meta["delegation_tag"] = delegation_tag
779-
self.partition_tags[delegation_tag] = self.delegation_spec
786+
partition_tags[delegation_tag] = self.delegation_spec
787+
return partition_tags
780788

781789
# override
782790
def _partition(
783791
self, graph_module: torch.fx.GraphModule, quant: Optional[bool]
784-
) -> torch.fx.GraphModule:
792+
) -> PartitionResult:
785793
"""
786794
Run the partitioner on the given graph module, then tag each partition
787795
with its delegation tag (and partition id)
788796
"""
789797
partitions = self.generate_partitions(graph_module, quant)
798+
partition_tags: Dict[str, DelegationSpec] = {}
790799
if self.check_partitions(partitions):
791-
self.tag_nodes(partitions)
792-
return graph_module
800+
partition_tags = self.tag_nodes(partitions)
801+
return PartitionResult(tagged_graph=graph_module, partition_tags=partition_tags)
793802

794-
def partition(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
795-
ret = self._partition(graph_module, self.quant)
803+
def partition(self, graph_module: torch.fx.GraphModule) -> PartitionResult:
804+
ret: PartitionResult = self._partition(graph_module, self.quant)
796805
return ret
797806

798807

@@ -805,7 +814,7 @@ def __init__(
805814
super().__init__(supported_modules, supported_ops)
806815

807816
# override
808-
def partition(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
817+
def partition(self, graph_module: torch.fx.GraphModule) -> PartitionResult:
809818
"""
810819
Run the partitioner on the given graph module, then tag each partition with its delegegation tag (and partition id)
811820
@@ -819,7 +828,8 @@ def partition(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
819828
)
820829
for match in self.get_module_partitions(graph_module)
821830
]
831+
partition_tags: Dict[str, DelegationSpec] = {}
822832

823833
if self.check_partitions(partitions):
824-
self.tag_nodes(partitions)
825-
return graph_module
834+
partition_tags = self.tag_nodes(partitions)
835+
return PartitionResult(tagged_graph=graph_module, partition_tags=partition_tags)

docs/website/docs/tutorials/backend_delegate.md

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def to_backend(
205205
```
206206

207207
This function takes in a `Partitioner` which adds a tag to all the nodes that
208-
are meant to be lowered. It will also contain a `partition_tags` mapping tags to
208+
are meant to be lowered. The `Partitioner.partition` function will return both the tagged graph module and the `partition_tags` mapping tags to
209209
backend names and module compile specs. The tagged nodes will then be
210210
partitioned and lowered to their mapped backends using Flow 1's process.
211211
Available helper partitioner are documented [here](./passes.md#partitioner). These
@@ -371,24 +371,25 @@ class Backend_1_2_Partitioner(Partitioner):
371371
def __init__(self) -> None:
372372
self.delegation_spec_1 = DelegationSpec("Backend1", [])
373373
self.delegation_spec_2 = DelegationSpec("Backend2", [])
374-
self.partition_tags = {}
375374

376375
def partition(
377376
self, edge_graph_module: torch.fx.GraphModule
378-
) -> torch.fx.GraphModule:
379-
377+
) -> PartitionResult:
378+
partition_tags: Dict[str, DelegationSpec] = {}
380379
# Tag all nodes in the first partiton to backend 1
381380
node_to_backend_1 = ... # some logic to select the nodes from the graph
382381
delegation_tag = f"backend2_tag{partitioner_1.id}"
383382
node.meta["delegation_tag"] = delegation_tag
384-
self.partition_tags[delegation_tag] = self.delegation_spec_1
383+
partition_tags[delegation_tag] = self.delegation_spec_1
385384

386385
# Tag all nodes in the first partiton to backend 2
387386
node_to_backend_2 = ... # some logic to select the nodes from the graph
388387
delegation_tag = f"backend2_tag{partitioner_2.id}"
389388
node.meta["delegation_tag"] = delegation_tag
390-
self.partition_tags[delegation_tag] = self.delegation_spec_2
391-
return edge_graph_module
389+
partition_tags[delegation_tag] = self.delegation_spec_2
390+
return PartitionResult(
391+
tagged_graph=edge_graph_module, partition_tags=partition_tags
392+
)
392393
```
393394

394395
6. Is there an easy way to write partitioner?

exir/backend/backend_api.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@
1616
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
1717
from executorch.exir.backend.compile_spec_schema import CompileSpec
1818

19-
from executorch.exir.backend.partitioner import Partitioner, TPartitioner
19+
from executorch.exir.backend.partitioner import (
20+
Partitioner,
21+
PartitionResult,
22+
TPartitioner,
23+
)
2024
from executorch.exir.backend.utils import is_identical_graph
2125

2226
from executorch.exir.delegate import executorch_call_delegate, get_lowered_module_name
@@ -142,10 +146,10 @@ def validation_disabled() -> Generator[None, None, None]:
142146

143147
def _partition_and_lower(
144148
tagged_graph_module: torch.fx.GraphModule,
145-
partitioner_instance: Partitioner,
149+
partition_result: PartitionResult,
146150
owning_program: ExportedProgram,
147151
) -> torch.fx.GraphModule:
148-
for tag, delegation_spec in partitioner_instance.partition_tags.items():
152+
for tag, delegation_spec in partition_result.partition_tags.items():
149153
# Create partition with nodes containing this tag. There should only be
150154
# one contained submodule per tag
151155
node_list = []
@@ -231,7 +235,7 @@ def _partition_and_lower(
231235
# Recursively partition and lower for submodules
232236
for name, submod, _node in get_control_flow_submodules(tagged_graph_module):
233237
partitioned_submodule = _partition_and_lower(
234-
submod, partitioner_instance, owning_program
238+
submod, partition_result, owning_program
235239
)
236240
tagged_graph_module.add_module(name, partitioned_submodule)
237241

@@ -281,7 +285,8 @@ def to_backend(
281285
copied_graph_module = copy.deepcopy(edge_graph_module)
282286
# Call the partitioner on the given graph module
283287
partitioner_instance: Partitioner = partitioner()
284-
tagged_graph_module = partitioner_instance(copied_graph_module)
288+
partitioner_result = partitioner_instance(copied_graph_module)
289+
tagged_graph_module = partitioner_result.tagged_graph
285290

286291
# Check that the partitioner did not modify the original graph
287292
if _ENABLE_VALIDATION:
@@ -293,12 +298,11 @@ def to_backend(
293298
logging.warning("Disabled validating the partitioner.")
294299

295300
assert (
296-
hasattr(partitioner_instance, "partition_tags")
297-
and partitioner_instance.partition_tags is not None
301+
partitioner_result.partition_tags is not None
298302
), f"Partitioner {partitioner} needs a `partition_tags` field containing a mapping of tags to delegate spec"
299303

300304
tagged_graph_module = _partition_and_lower(
301-
tagged_graph_module, partitioner_instance, edge_program
305+
tagged_graph_module, partitioner_result, edge_program
302306
)
303307

304308
# TODO(angelayi): Update this signature in a less manual way (maybe through

exir/backend/partitioner.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
from abc import ABC, abstractmethod
8+
from dataclasses import dataclass
89
from typing import Dict, List, NamedTuple, TypeVar
910

1011
import torch.fx as fx
@@ -18,14 +19,27 @@ class DelegationSpec(NamedTuple):
1819
compile_specs: List[CompileSpec]
1920

2021

22+
@dataclass
23+
class PartitionResult:
24+
"""
25+
tagged_graph: the graph with nodes that intend to be delegated containing a "DelegationSpec" metadata
26+
partition_tags: A dictionary that will be used to keep track of the tags and it's corresponding DelegationSpec. The tag is defined by users and used
27+
in the node.meta.
28+
"""
29+
30+
tagged_graph: fx.GraphModule
31+
partition_tags: Dict[str, DelegationSpec]
32+
33+
2134
class Partitioner(ABC):
2235
"""
23-
Defines a callable interface for partitioning an exported Module (i.e. a program) for
36+
Defines a callable interface for partitioning an exported module (i.e. a program) for
2437
backend delegation.
25-
A partitioner implementation would receive an exported Module, determine what portions of
38+
A partitioner implementation would receive an exported module, determine what portions of
2639
the it can be delegated to certain backend (though a partitioner can target multiple
27-
backends as well), and return the same input Module with specific nodes in
28-
the input graph tagged for delegation.
40+
backends as well), and return the PartitionResult including:
41+
- the same input module with specific nodes in the input graph tagged for delegation
42+
- the "partition_tags" to indicate how the tag is mapped to Delegation Spec.
2943
3044
The nodes that intend to be delegated must be tagged (by setting
3145
node.meta["delegation_tag"]) and this tag must be provided in the
@@ -40,14 +54,12 @@ class Partitioner(ABC):
4054
edge_graph_module: A module in Edge dialect to be partitioned for backend delegation.
4155
"""
4256

43-
partition_tags: Dict[str, DelegationSpec]
44-
45-
def __call__(self, edge_graph_module: fx.GraphModule) -> fx.GraphModule:
57+
def __call__(self, edge_graph_module: fx.GraphModule) -> PartitionResult:
4658
return self.partition(edge_graph_module)
4759

4860
@enforcedmethod
4961
@abstractmethod
50-
def partition(self, edge_graph_module: fx.GraphModule) -> fx.GraphModule:
62+
def partition(self, edge_graph_module: fx.GraphModule) -> PartitionResult:
5163
"""
5264
Returns the input exported program with newly created sub-Modules encapsulating
5365
specific portions of the input "tagged" for delegation.
@@ -66,8 +78,7 @@ def partition(self, edge_graph_module: fx.GraphModule) -> fx.GraphModule:
6678
edge_graph_module: A module in Edge dialect to be partitioned for backend delegation.
6779
6880
Returns:
69-
GraphModule: Returns the input exported program with nodes that
70-
intend to be delegated containing a "delegate_spec" metadata
81+
PartitionResult: includes the tagged graph and the delegation spec to indicate what backend_id and compile_spec is used for each node and the tag created by the backend developers.
7182
"""
7283
pass
7384

exir/backend/test/demos/rpc/executor_backend_partitioner.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@
1111
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
1212
generate_pattern_op_partitions,
1313
)
14-
from executorch.exir.backend.partitioner import DelegationSpec, Partitioner
14+
from executorch.exir.backend.partitioner import (
15+
DelegationSpec,
16+
Partitioner,
17+
PartitionResult,
18+
)
1519
from executorch.exir.backend.test.backend_with_compiler_demo import (
1620
BackendWithCompilerDemo,
1721
)
@@ -45,23 +49,23 @@ def __init__(self) -> None:
4549
self.op_support = any_chain(AnyOperatorSupport(), AnyDelegateSupport())
4650
self.delegation_spec = DelegationSpec("ExecutorBackend", [])
4751

48-
self.partition_tags = {}
49-
50-
def partition(
51-
self, edge_graph_module: torch.fx.GraphModule
52-
) -> torch.fx.GraphModule:
52+
def partition(self, edge_graph_module: torch.fx.GraphModule) -> PartitionResult:
53+
partition_tags = {}
5354
partition_list = generate_pattern_op_partitions(
5455
edge_graph_module, op_support=self.op_support
5556
)
5657
for partition in partition_list:
5758
for node in partition.nodes:
5859
delegation_tag = f"tag{partition.id}"
5960
node.meta["delegation_tag"] = delegation_tag
60-
self.partition_tags[delegation_tag] = self.delegation_spec
61+
partition_tags[delegation_tag] = self.delegation_spec
6162

6263
# Tag the delegate submodules
6364
# pyre-ignore Undefined attribute [16]: Item `None` of `typing.Union[None, typing.Dict[str, typing.Any], typing.List[typing.Any], bool, complex, float, int, range, slice, str, torch._C.device, torch._C.dtype, torch._C.layout, torch._C.memory_format, torch._tensor.Tensor, torch.fx.node.Node, typing.Tuple[typing.Any, ...]]` has no attribute `op`.Pyre
6465
if node.args[0].op == "get_attr":
6566
# pyre-ignore Undefined attribute [16]: Item `None` of `typing.Union[None, typing.Dict[str, typing.Any], typing.List[typing.Any], bool, complex, float, int, range, slice, str, torch._C.device, torch._C.dtype, torch._C.layout, torch._C.memory_format, torch._tensor.Tensor, torch.fx.node.Node, typing.Tuple[typing.Any, ...]]` has no attribute `op`.Pyre
6667
node.args[0].meta["delegation_tag"] = delegation_tag
67-
return edge_graph_module
68+
69+
return PartitionResult(
70+
tagged_graph=edge_graph_module, partition_tags=partition_tags
71+
)

0 commit comments

Comments
 (0)