Skip to content

Commit ee8bf40

Browse files
committed
extract exir update from aoti-cuda-export
1 parent a548635 commit ee8bf40

File tree

4 files changed

+19
-2
lines changed

4 files changed

+19
-2
lines changed

exir/backend/backend_api.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,9 @@ def _partition_and_lower_one_graph_module(
268268
"""
269269
Partitioned and lowered the graph module based on the partition tag, this is to handle one graph module.
270270
"""
271-
for tag, delegation_spec in partition_result.partition_tags.items():
271+
for idx, (tag, delegation_spec) in enumerate(
272+
partition_result.partition_tags.items()
273+
):
272274
# Create partition with nodes containing this tag. There should only be
273275
# one contained submodule per tag
274276
node_list = _get_node_list_with_same_tag(
@@ -311,6 +313,7 @@ def _partition_and_lower_one_graph_module(
311313
tag,
312314
call_module_node,
313315
is_submodule,
316+
idx == 0,
314317
)
315318

316319
lowered_submodule = to_backend(
@@ -452,7 +455,9 @@ def _create_partitions_in_graph_module(
452455
is_submodule: bool,
453456
) -> Dict[str, List[torch.fx.Node]]:
454457
backend_id_to_submodule_name = {}
455-
for tag, delegation_spec in partition_result.partition_tags.items():
458+
for idx, (tag, delegation_spec) in enumerate(
459+
partition_result.partition_tags.items()
460+
):
456461
# Create partition with nodes containing this tag. There should only be
457462
# one contained submodule per tag
458463
node_list = _get_node_list_with_same_tag(
@@ -492,6 +497,7 @@ def _create_partitions_in_graph_module(
492497
tag,
493498
call_module_node,
494499
is_submodule,
500+
idx == 0,
495501
)
496502
call_module_node.meta["backend_id"] = delegation_spec.backend_id
497503
call_module_node.meta["compile_spec"] = delegation_spec.compile_specs
@@ -720,6 +726,8 @@ def to_backend(
720726
fake_edge_program = copy.deepcopy(edge_program)
721727
partitioner_result = partitioner_instance(fake_edge_program)
722728
tagged_exported_program = partitioner_result.tagged_exported_program
729+
tagged_exported_program.example_inputs = edge_program.example_inputs
730+
723731
method_to_tagged_exported_program[method_name] = tagged_exported_program
724732

725733
# Check that the partitioner did not modify the original graph

exir/emit/_emit_program.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ def emit_program(
176176
)
177177

178178
emitter.run()
179+
179180
plans.append(emitter.plan())
180181

181182
debug_handle_map[name] = emitter.debug_handle_map

exir/lowered_backend_module.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -682,6 +682,7 @@ def create_exported_program_from_submodule(
682682
tag: str,
683683
call_module_node: torch.fx.Node,
684684
is_submodule: bool,
685+
is_first_partition: bool = False,
685686
) -> Tuple[ExportedProgram, Dict[str, InputSpec], Dict[str, OutputSpec]]:
686687
"""
687688
Creates an ExportedProgram from the given submodule using the parameters and buffers
@@ -720,6 +721,11 @@ def create_exported_program_from_submodule(
720721
in_spec = pytree.tree_flatten((tuple(subgraph_signature.user_inputs), {}))[1]
721722
out_spec = pytree.tree_flatten(subgraph_signature.user_outputs)[1]
722723

724+
# only the example inputs of first parition equals to the example inputs of the owning program
725+
submodule_exmaple_inputs = (
726+
owning_program.example_inputs if is_first_partition else None
727+
)
728+
723729
return (
724730
ExportedProgram(
725731
root=submodule,
@@ -735,6 +741,7 @@ def create_exported_program_from_submodule(
735741
),
736742
)
737743
],
744+
example_inputs=submodule_exmaple_inputs,
738745
constants=subgraph_constants,
739746
verifiers=[owning_program.verifier],
740747
),

exir/program/_program.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1697,6 +1697,7 @@ def to_executorch( # noqa (FLAKE8) C901
16971697
after it has been transformed to the ExecuTorch backend.
16981698
"""
16991699
config = config if config else ExecutorchBackendConfig()
1700+
17001701
execution_programs: Dict[str, ExportedProgram] = {}
17011702
for name, program in self._edge_programs.items():
17021703
if config.do_quant_fusion_and_const_prop:

0 commit comments

Comments
 (0)