@@ -238,7 +238,6 @@ def generate_debug_handle(ep: ExportedProgram) -> int:
238238 call_delegate_node .meta ["val" ] = submodule_output_node .meta ["val" ]
239239 call_submodule_node .replace_all_uses_with (call_delegate_node )
240240 owning_graph_module .graph .erase_node (call_submodule_node )
241-
242241 if is_submodule :
243242 assert len (toplevel_input_specs_to_delete ) == 0
244243 assert len (toplevel_output_specs_to_delete ) == 0
@@ -574,26 +573,29 @@ def lower_all_submodules_to_backend(
574573 # The created exported program for the submodules are in the call_module node's meta data
575574 # We just map the method_to_submodule_nodes directly to the method_to_partitioned_exported_programs
576575 method_to_partitioned_program = {
577- method_name : [node .meta ["submodule_program" ] for node in call_submodule_nodes ]
576+ method_name : [
577+ copy .deepcopy (node .meta ["submodule_program" ])
578+ for node in call_submodule_nodes
579+ ]
578580 for method_name , call_submodule_nodes in method_to_submodules_nodes .items ()
579581 }
580582 method_to_compile_specs = {
581583 method_name : [node .meta ["compile_spec" ] for node in call_submodule_nodes ]
582584 for method_name , call_submodule_nodes in method_to_submodules_nodes .items ()
583585 }
584- backend_found = False
585- for cls in BackendDetails .__subclasses__ ():
586- if backend_id == cls .__name__ :
587- method_to_preprocess_result : dict [str , List [PreprocessResult ]] = (
588- cls .preprocess_multimethod (
589- method_to_partitioned_program , method_to_compile_specs
590- )
591- )
592- backend_found = True
593586
594- if not backend_found :
587+ backend_name_to_subclass = {
588+ subclass .__name__ : subclass for subclass in BackendDetails .__subclasses__ ()
589+ }
590+ if backend_id not in backend_name_to_subclass :
595591 raise NotImplementedError (f"Backend { backend_id } was not found." )
596592
593+ method_to_preprocess_result : dict [str , List [PreprocessResult ]] = (
594+ backend_name_to_subclass [backend_id ].preprocess_multimethod (
595+ method_to_partitioned_program , method_to_compile_specs
596+ )
597+ )
598+
597599 for method_name in method_to_preprocess_result .keys ():
598600 owning_program = method_to_tagged_edge_program [method_name ]
599601 list_of_preprocess_results = method_to_preprocess_result [method_name ]
@@ -612,6 +614,9 @@ def lower_all_submodules_to_backend(
612614 compile_specs = compile_spec ,
613615 named_data_store_output = preprocess_result .data_store_output ,
614616 )
617+ lowered_module .meta = {
618+ "debug_handle_map" : preprocess_result .debug_handle_map ,
619+ }
615620 is_submodule = call_submodule_node .meta ["is_submodule" ]
616621 toplevel_input_specs_to_delete = call_submodule_node .meta [
617622 "toplevel_input_specs_to_delete"
@@ -633,6 +638,20 @@ def lower_all_submodules_to_backend(
633638 )
634639
635640
641+ def remove_used_metadata (graph : torch .fx .Graph ) -> None :
642+ """
643+ Remove the used metadata from the graph.
644+ """
645+ for node in graph .nodes :
646+ node .meta .pop ("delegation_tag" , None )
647+ node .meta .pop ("backend_id" , None )
648+ node .meta .pop ("submodule_program" , None )
649+ node .meta .pop ("toplevel_input_specs_to_delete" , None )
650+ node .meta .pop ("toplevel_output_specs_to_delete" , None )
651+ node .meta .pop ("is_submodule" , None )
652+ node .meta .pop ("submodule_output_node" , None )
653+
654+
636655@dataclass
637656class MethodProgramsPartitionerSpec :
638657 """
@@ -748,6 +767,7 @@ def to_backend(
748767 if method_name in method_to_tagged_exported_program :
749768 tagged_exported_program = method_to_tagged_exported_program [method_name ]
750769 tagged_exported_program ._validate ()
770+ remove_used_metadata (tagged_exported_program .graph_module .graph )
751771 partitioned_and_lowered_exported_programs [method_name ] = ExportedProgram (
752772 root = tagged_exported_program .graph_module ,
753773 graph = tagged_exported_program .graph_module .graph ,
0 commit comments