
Commit c35ef31

remove unnecessary export code
1 parent 31aac73 commit c35ef31

File tree

2 files changed (+17 / -79 lines)


backends/aoti/aoti_backend.py

Lines changed: 3 additions & 23 deletions
@@ -73,31 +73,15 @@ def preprocess(
         compile_specs: List[CompileSpec],
     ) -> PreprocessResult:
 
-        print("entering the lowerable parts in AotiBackend.preprocess....")
         named_data_store = NamedDataStore()
 
-        # print("here", edge_program.example_inputs)
-        copy_edge_program = copy.deepcopy(edge_program)
+        # copy_edge_program = copy.deepcopy(edge_program)
 
         # Move the edge_program from CPU to CUDA for aoti compile
-        cuda_edge_program = move_to_device_pass(copy_edge_program, "cuda")
+        cuda_edge_program = move_to_device_pass(edge_program, "cuda")
 
         edge_program_module = cuda_edge_program.module()
-        args, kwargs = copy_edge_program.example_inputs
-
-        # # Deep copy args and move tensors to CUDA for aot_compile
-        # def move_to_cuda(obj):
-        #     if isinstance(obj, torch.Tensor):
-        #         return obj.cuda()
-        #     elif isinstance(obj, (list, tuple)):
-        #         return type(obj)(move_to_cuda(item) for item in obj)
-        #     elif isinstance(obj, dict):
-        #         return {key: move_to_cuda(value) for key, value in obj.items()}
-        #     else:
-        #         return obj
-
-        # args = move_to_cuda(copy.deepcopy(args))
-        # kwargs = move_to_cuda(copy.deepcopy(kwargs))
+        args, kwargs = cuda_edge_program.example_inputs
 
         output_path = os.path.join(os.getcwd(), "aoti.so")
 
@@ -122,10 +106,6 @@ def preprocess(
                 "Please add them to the AOTI backend."
            )
 
-        assert so_path == output_path, f"Expected {output_path} but got {so_path}"
-
-        print("so_path", so_path)
-
         with open(so_path, "rb") as f:
             so_data = f.read()
 
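The net effect in aoti_backend.py: preprocess no longer deep-copies the edge program or keeps the commented-out CUDA-moving helper; the incoming program goes straight through the device-move pass, and example inputs are read from the moved program. Below is a minimal, self-contained sketch of that shape using only torch.export. TinyModel, the stand-in for move_to_device_pass, and aot_compile_to_shared_library are illustrative assumptions, not the backend's actual API.

import os

import torch
from torch.export import export


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) * 2


# Export once; the ExportedProgram carries its own example inputs.
edge_program = export(TinyModel(), (torch.randn(4),))

# The commit drops the deepcopy: the real backend hands edge_program directly
# to move_to_device_pass(edge_program, "cuda") and reads example inputs from
# the moved program instead of from a CPU copy.
cuda_edge_program = edge_program  # stand-in for move_to_device_pass(edge_program, "cuda")
args, kwargs = cuda_edge_program.example_inputs
edge_program_module = cuda_edge_program.module()

output_path = os.path.join(os.getcwd(), "aoti.so")
# so_path = aot_compile_to_shared_library(edge_program_module, args, kwargs, output_path)  # hypothetical compile step
# with open(so_path, "rb") as f:
#     so_data = f.read()
print(edge_program_module(*args, **(kwargs or {})))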
backends/aoti/aoti_partitioner.py

Lines changed: 14 additions & 56 deletions
@@ -6,8 +6,7 @@
 
 # pyre-unsafe
 
-import operator
-from typing import Callable, cast, Dict, final, List, Optional, Set, Tuple
+from typing import Callable, Dict, final, List, Optional, Tuple
 
 import torch
 from executorch.backends.aoti.aoti_backend import AotiBackend  # usort: skip
@@ -18,65 +17,26 @@
     PartitionResult,
 )
 from executorch.exir.backend.utils import tag_constant_data
-from executorch.exir.dialects._ops import ops as exir_ops
 from torch.export.exported_program import ExportedProgram
-from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
-
-from torch.fx.passes.operator_support import OperatorSupportBase
-
-
-class AOTISupportedOperators(OperatorSupportBase):
-    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
-        # supported = node.op == "call_function" and (
-        #     node.target == operator.getitem
-        #     or str(node.target._op) not in inductor_fallback_ops
-        #     or str(node.target._op) in supported_fallback_operators
-        # )
-
-        supported = node.op == "call_function"
-
-        return supported
-
-    def is_node_supported_custom(self, node: torch.fx.Node) -> bool:
-        if node.target == exir_ops.edge.aten.mean.dim:
-            keep_dim = node.args[2] if len(node.args) > 2 else False
-            return cast(bool, keep_dim)
-        if node.target == exir_ops.edge.aten.var.correction:
-            keep_dim = node.kwargs.get("keepdim", False)
-            return cast(bool, keep_dim)
-        return True
 
 
 @final
 class AotiPartitioner(Partitioner):
     def __init__(self, compile_spec: List[CompileSpec]) -> None:
         self.delegation_spec = DelegationSpec(AotiBackend.__name__, compile_spec)
-        print(self.delegation_spec)
 
     def partition(self, exported_program: ExportedProgram) -> PartitionResult:
-        # Run the CapabilityBasedPartitioner to return the largest possible
-        # subgraphs containing the nodes with the tags
-        # logger.info("AotiPartitioner::partition")
-        print("entering partitioner...")
-
-        partition_tags = {}
-
-        capability_partitioner = CapabilityBasedPartitioner(
-            exported_program.graph_module,
-            AOTISupportedOperators(),
-            allows_single_node_partition=True,
-        )
-        partition_list = capability_partitioner.propose_partitions()
-
-        assert len(partition_list) == 1, "Graph break is not supported yet"
-
-        print(f"graph breaks into {len(partition_list)} parts")
+        """
+        Fully delegate the graph to AOTInductor by tagging all nodes as a single partition.
+        """
 
-        for partition in partition_list:
-            for node in partition.nodes:
-                tag = f"tag{partition.id}"
-                node.meta["delegation_tag"] = tag
-                partition_tags[tag] = self.delegation_spec
+        partition_tags: Dict[str, DelegationSpec] = {}
+        for node in exported_program.graph.nodes:
+            if node.op != "call_function":
+                continue
+            tag = f"tag0"
+            node.meta["delegation_tag"] = tag
+            partition_tags[tag] = self.delegation_spec
 
         tag_constant_data(exported_program)
 
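The rewritten partition() replaces the CapabilityBasedPartitioner with a direct walk over the graph: every call_function node receives the same delegation tag, so the whole graph is delegated as a single partition and graph breaks cannot occur. A minimal sketch of that tagging pattern on a plain torch.fx trace follows; AddRelu and the placeholder string standing in for DelegationSpec are illustrative assumptions.

import torch
from torch import fx


class AddRelu(torch.nn.Module):
    def forward(self, x, y):
        return torch.relu(x + y)


gm = fx.symbolic_trace(AddRelu())

partition_tags = {}
for node in gm.graph.nodes:
    if node.op != "call_function":  # skip placeholder/get_attr/output nodes
        continue
    node.meta["delegation_tag"] = "tag0"  # one tag => one partition
    partition_tags["tag0"] = "delegation_spec_placeholder"  # real code stores a DelegationSpec

print(partition_tags)
print([(n.name, n.meta.get("delegation_tag")) for n in gm.graph.nodes])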
@@ -89,15 +49,13 @@ def ops_to_not_decompose(
     ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
         """
         Return a list of operations that should not be decomposed and let the AOT compiler handle them.
+        Currently we skip decomposing all ops and let the AOT compiler handle them.
         """
         do_not_decompose = set()
-        op_support = AOTISupportedOperators()
 
         for node in ep.graph.nodes:
-            if (
-                node.op == "call_function"
-                and isinstance(node.target, torch._ops.OpOverload)
-                and op_support.is_node_supported(None, node)
+            if node.op == "call_function" and isinstance(
+                node.target, torch._ops.OpOverload
             ):
                 do_not_decompose.add(node.target)
         return list(do_not_decompose), None
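With the AOTISupportedOperators filter removed, ops_to_not_decompose simply collects the OpOverload target of every call_function node, so no op is decomposed before the AOT compiler sees it. A self-contained sketch of that collection over an exported program's graph; MeanVar is an illustrative model, not from the repository.

import torch
from torch.export import export


class MeanVar(torch.nn.Module):
    def forward(self, x):
        return x.mean(dim=-1, keepdim=True) + x.var(dim=-1, keepdim=True)


ep = export(MeanVar(), (torch.randn(2, 8),))

do_not_decompose = set()
for node in ep.graph.nodes:
    # Keep every ATen OpOverload reached by a call_function node undecomposed.
    if node.op == "call_function" and isinstance(node.target, torch._ops.OpOverload):
        do_not_decompose.add(node.target)

print(sorted(str(op) for op in do_not_decompose))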
