66
77import copy
88import logging
9- from contextlib import contextmanager
9+ from contextlib import contextmanager , nullcontext
1010from functools import singledispatch
1111from typing import Generator , List
1212
2525
2626from executorch .exir .graph_module import get_control_flow_submodules
2727from executorch .exir .lowered_backend_module import (
28- _get_new_signature ,
28+ _unsafe_adjust_original_program ,
2929 create_exported_program_from_submodule ,
3030 create_submodule_from_nodes ,
3131 LoweredBackendModule ,
3232)
33- from executorch .exir .pass_base import ExportPass
3433from executorch .exir .program ._fake_program import (
3534 get_fake_program ,
3635 update_to_real_program ,
@@ -193,6 +192,7 @@ def _partition_and_lower_one_graph_module(
193192 tagged_graph_module : torch .fx .GraphModule ,
194193 partition_result : PartitionResult ,
195194 owning_program : ExportedProgram ,
195+ is_submodule : bool ,
196196) -> torch .fx .GraphModule :
197197 """
198198 Partitioned and lowered the graph module based on the partition tag, this is to handle one graph module.
@@ -210,21 +210,40 @@ def _partition_and_lower_one_graph_module(
210210
211211 logging .debug (f"For tag { tag } , found nodes { node_list } " )
212212 # Tag the nodes that are params as buffers, so we can order the submodule as (Parms + Buffers) (User Inputs)
213- submodule , call_module_node = create_submodule_from_nodes (
214- tagged_graph_module , node_list , tag
213+
214+ replace_ctx = (
215+ tagged_graph_module ._set_replace_hook (
216+ owning_program .graph_signature .get_replace_hook ()
217+ )
218+ if not is_submodule
219+ else nullcontext ()
215220 )
221+ with replace_ctx :
222+ submodule , call_module_node = create_submodule_from_nodes (
223+ tagged_graph_module , node_list , tag
224+ )
225+
216226 tagged_graph_module_output_node = [
217227 node for node in tagged_graph_module .graph .nodes if node .op == "output"
218- ]
228+ ][ 0 ]
219229 submodule_output_node = [
220230 node for node in submodule .graph .nodes if node .op == "output"
221- ]
222- # Copy the output node meta from the original output node, because create_submodule_from_nodes doesn't cover the meta field
223- submodule_output_node [0 ].meta = tagged_graph_module_output_node [0 ].meta
231+ ][0 ]
232+ # Copy the output node meta from the original output node, because
233+ # create_submodule_from_nodes doesn't cover the meta field
234+ submodule_output_node .meta = tagged_graph_module_output_node .meta
224235 logging .debug (f"Partitioned graph module: { tagged_graph_module } " )
225236
226- submodule_program = create_exported_program_from_submodule (
227- submodule , owning_program , tag
237+ (
238+ submodule_program ,
239+ toplevel_input_specs_to_delete ,
240+ toplevel_output_specs_to_delete ,
241+ ) = create_exported_program_from_submodule (
242+ submodule ,
243+ owning_program ,
244+ tag ,
245+ call_module_node ,
246+ is_submodule ,
228247 )
229248
230249 lowered_submodule = to_backend (
@@ -257,64 +276,48 @@ def _partition_and_lower_one_graph_module(
257276 call_delegate_node .meta ["debug_handle" ] = len (
258277 tagged_graph_module .graph .nodes
259278 )
279+ call_delegate_node .meta ["val" ] = submodule_output_node .meta ["val" ]
260280 call_module_node .replace_all_uses_with (call_delegate_node )
261281 tagged_graph_module .graph .erase_node (call_module_node )
262282
263- # Delete all parameters/buffers consumed by the created exported program
264- toplevel_signature = owning_program .graph_signature
265- for node in tagged_graph_module .graph .nodes :
266- # Find placeholders consumed by the delegate
267- if node .op != "placeholder" or len (node .users ) != 0 :
268- continue
269-
270- if node .name in toplevel_signature .inputs_to_buffers :
271- # Delete the consumed buffers
272- buffer_name = toplevel_signature .inputs_to_buffers .get (node .name )
273- if buffer_name in owning_program .state_dict :
274- owning_program .state_dict .pop (buffer_name )
275- else :
276- owning_program .constants .pop (buffer_name )
277- tagged_graph_module .graph .erase_node (node )
278- elif node .name in toplevel_signature .inputs_to_parameters :
279- # Delete the consumed parameters
280- param_name = toplevel_signature .inputs_to_parameters .get (node .name )
281- owning_program .state_dict .pop (param_name )
282- tagged_graph_module .graph .erase_node (node )
283-
284- tagged_graph_module .recompile ()
283+ if is_submodule :
284+ assert len (toplevel_input_specs_to_delete ) == 0
285+ assert len (toplevel_output_specs_to_delete ) == 0
286+ elif (
287+ len (toplevel_input_specs_to_delete ) > 0
288+ or len (toplevel_output_specs_to_delete ) > 0
289+ ):
290+ _unsafe_adjust_original_program (
291+ owning_program ,
292+ call_delegate_node ,
293+ toplevel_input_specs_to_delete ,
294+ toplevel_output_specs_to_delete ,
295+ )
296+
285297 return tagged_graph_module
286298
287299
288300def _partition_and_lower (
289301 tagged_graph_module : torch .fx .GraphModule ,
290302 partition_result : PartitionResult ,
291303 owning_program : ExportedProgram ,
304+ is_submodule : bool = False ,
292305) -> torch .fx .GraphModule :
293306 """
294307 Partitions the graph module into submodules based on tags, and then lowered the nodes with the same tag as one lowered module, including the submodule from control flow
295308 """
296309
297310 partitioned_module = _partition_and_lower_one_graph_module (
298- tagged_graph_module , partition_result , owning_program
311+ tagged_graph_module , partition_result , owning_program , is_submodule
299312 )
300313
301314 # Recursively partition and lower for submodules
302315 for name , submod , _node in get_control_flow_submodules (partitioned_module ):
303316 partitioned_submodule = _partition_and_lower (
304- submod , partition_result , owning_program
317+ submod , partition_result , owning_program , is_submodule = True
305318 )
306319 tagged_graph_module .add_module (name , partitioned_submodule )
307320
308- # Run the export pass over the graph module so that the call delegate
309- # nodes will match Edge dialect
310- # TODO(angelayi): ExportPass will rerun the graph, however all we need
311- # here is to add metadata to the call delegate nodes to preserve Edge
312- # dialect. There's work going on to generate a random tensor from a
313- # fake tensor and possibly it can help to address the issue.
314- res = ExportPass ()(tagged_graph_module )
315- assert res is not None
316- tagged_graph_module = res .graph_module
317-
318321 return tagged_graph_module
319322
320323
@@ -349,6 +352,8 @@ def to_backend(
349352 Returns:
350353 ExportedProgram: The input program, with some portions targeted for delegation.
351354 """
355+ edge_program ._validate ()
356+
352357 # Use fake program, with FakeTensors in the state dict, to avoid copying large constant values.
353358 # Fall back to deepcopy if no fake mode is found. TODO(T182910699): Remove this fallback.
354359 try :
@@ -377,26 +382,22 @@ def to_backend(
377382 update_to_real_program (tagged_exported_program , edge_program )
378383
379384 for tag , _ in partitioner_result .partition_tags .items ():
380- _maybe_duplicate_constant_nodes (tagged_exported_program , tag , edge_program )
385+ _maybe_duplicate_constant_nodes (tagged_exported_program , tag )
381386
382387 tagged_graph_module = _partition_and_lower (
383- tagged_exported_program .graph_module , partitioner_result , edge_program
388+ tagged_exported_program .graph_module ,
389+ partitioner_result ,
390+ tagged_exported_program ,
384391 )
385392
386- # TODO(angelayi): Update this signature in a less manual way (maybe through
387- # retracing)
388- new_signature , new_state_dict , new_constants = _get_new_signature (
389- edge_program ,
390- tagged_graph_module ,
391- )
392393 return ExportedProgram (
393394 root = tagged_graph_module ,
394395 graph = tagged_graph_module .graph ,
395- graph_signature = new_signature ,
396- state_dict = new_state_dict ,
397- range_constraints = copy .deepcopy (edge_program .range_constraints ),
398- module_call_graph = copy .deepcopy (edge_program .module_call_graph ),
396+ graph_signature = tagged_exported_program . graph_signature ,
397+ state_dict = tagged_exported_program . state_dict ,
398+ range_constraints = copy .deepcopy (tagged_exported_program .range_constraints ),
399+ module_call_graph = copy .deepcopy (tagged_exported_program .module_call_graph ),
399400 example_inputs = None ,
400- constants = new_constants ,
401- verifiers = [edge_program .verifier ],
401+ constants = tagged_exported_program . constants ,
402+ verifiers = [tagged_exported_program .verifier ],
402403 )
0 commit comments