@@ -268,7 +268,9 @@ def _partition_and_lower_one_graph_module(
268268 """
269269 Partitioned and lowered the graph module based on the partition tag, this is to handle one graph module.
270270 """
271- for tag , delegation_spec in partition_result .partition_tags .items ():
271+ for idx , (tag , delegation_spec ) in enumerate (
272+ partition_result .partition_tags .items ()
273+ ):
272274 # Create partition with nodes containing this tag. There should only be
273275 # one contained submodule per tag
274276 node_list = _get_node_list_with_same_tag (
@@ -311,6 +313,7 @@ def _partition_and_lower_one_graph_module(
311313 tag ,
312314 call_module_node ,
313315 is_submodule ,
316+ idx == 0 ,
314317 )
315318
316319 lowered_submodule = to_backend (
@@ -452,7 +455,9 @@ def _create_partitions_in_graph_module(
452455 is_submodule : bool ,
453456) -> Dict [str , List [torch .fx .Node ]]:
454457 backend_id_to_submodule_name = {}
455- for tag , delegation_spec in partition_result .partition_tags .items ():
458+ for idx , (tag , delegation_spec ) in enumerate (
459+ partition_result .partition_tags .items ()
460+ ):
456461 # Create partition with nodes containing this tag. There should only be
457462 # one contained submodule per tag
458463 node_list = _get_node_list_with_same_tag (
@@ -492,6 +497,7 @@ def _create_partitions_in_graph_module(
492497 tag ,
493498 call_module_node ,
494499 is_submodule ,
500+ idx == 0 ,
495501 )
496502 call_module_node .meta ["backend_id" ] = delegation_spec .backend_id
497503 call_module_node .meta ["compile_spec" ] = delegation_spec .compile_specs
@@ -720,6 +726,8 @@ def to_backend(
720726 fake_edge_program = copy .deepcopy (edge_program )
721727 partitioner_result = partitioner_instance (fake_edge_program )
722728 tagged_exported_program = partitioner_result .tagged_exported_program
729+ tagged_exported_program .example_inputs = edge_program .example_inputs
730+
723731 method_to_tagged_exported_program [method_name ] = tagged_exported_program
724732
725733 # Check that the partitioner did not modify the original graph
0 commit comments