Skip to content

Commit 32c14b1

Browse files
committed
set example input to only the very first partition
1 parent 3e2f2b7 commit 32c14b1

File tree

3 files changed

+15
-5
lines changed

3 files changed

+15
-5
lines changed

backends/aoti/aoti_backend.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,6 @@ def preprocess(
7575

7676
named_data_store = NamedDataStore()
7777

78-
# copy_edge_program = copy.deepcopy(edge_program)
79-
8078
# Move the edge_program from CPU to CUDA for aoti compile
8179
cuda_edge_program = move_to_device_pass(edge_program, "cuda")
8280

exir/backend/backend_api.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,9 @@ def _partition_and_lower_one_graph_module(
268268
"""
269269
Partitioned and lowered the graph module based on the partition tag, this is to handle one graph module.
270270
"""
271-
for tag, delegation_spec in partition_result.partition_tags.items():
271+
for idx, (tag, delegation_spec) in enumerate(
272+
partition_result.partition_tags.items()
273+
):
272274
# Create partition with nodes containing this tag. There should only be
273275
# one contained submodule per tag
274276
node_list = _get_node_list_with_same_tag(
@@ -311,6 +313,7 @@ def _partition_and_lower_one_graph_module(
311313
tag,
312314
call_module_node,
313315
is_submodule,
316+
idx == 0,
314317
)
315318

316319
lowered_submodule = to_backend(
@@ -452,7 +455,9 @@ def _create_partitions_in_graph_module(
452455
is_submodule: bool,
453456
) -> Dict[str, List[torch.fx.Node]]:
454457
backend_id_to_submodule_name = {}
455-
for tag, delegation_spec in partition_result.partition_tags.items():
458+
for idx, (tag, delegation_spec) in enumerate(
459+
partition_result.partition_tags.items()
460+
):
456461
# Create partition with nodes containing this tag. There should only be
457462
# one contained submodule per tag
458463
node_list = _get_node_list_with_same_tag(
@@ -492,6 +497,7 @@ def _create_partitions_in_graph_module(
492497
tag,
493498
call_module_node,
494499
is_submodule,
500+
idx == 0,
495501
)
496502
call_module_node.meta["backend_id"] = delegation_spec.backend_id
497503
call_module_node.meta["compile_spec"] = delegation_spec.compile_specs

exir/lowered_backend_module.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -682,6 +682,7 @@ def create_exported_program_from_submodule(
682682
tag: str,
683683
call_module_node: torch.fx.Node,
684684
is_submodule: bool,
685+
is_first_partition: bool = False,
685686
) -> Tuple[ExportedProgram, Dict[str, InputSpec], Dict[str, OutputSpec]]:
686687
"""
687688
Creates an ExportedProgram from the given submodule using the parameters and buffers
@@ -720,6 +721,11 @@ def create_exported_program_from_submodule(
720721
in_spec = pytree.tree_flatten((tuple(subgraph_signature.user_inputs), {}))[1]
721722
out_spec = pytree.tree_flatten(subgraph_signature.user_outputs)[1]
722723

724+
# only the example inputs of first parition equals to the example inputs of the owning program
725+
submodule_exmaple_inputs = (
726+
owning_program.example_inputs if is_first_partition else None
727+
)
728+
723729
return (
724730
ExportedProgram(
725731
root=submodule,
@@ -735,7 +741,7 @@ def create_exported_program_from_submodule(
735741
),
736742
)
737743
],
738-
example_inputs=owning_program.example_inputs,
744+
example_inputs=submodule_exmaple_inputs,
739745
constants=subgraph_constants,
740746
verifiers=[owning_program.verifier],
741747
),

0 commit comments

Comments
 (0)