@@ -1098,6 +1098,33 @@ def _gen_edge_manager_for_partitioners(
10981098 return edge_manager
10991099
11001100
1101+ def collect_named_data_store_from_exported_program (
1102+ exported_program : ExportedProgram ,
1103+ named_data_store : NamedDataStore ,
1104+ ) -> None :
1105+ """
1106+ Collects all the named data store outputs found within the exported program
1107+ and adds them to named_data_store.
1108+ """
1109+
1110+ # collected all the named data into the named data store for deduplication
1111+ def collect_named_data_store_outputs (
1112+ graph_module : torch .fx .GraphModule ,
1113+ ) -> None :
1114+ for node in graph_module .graph .nodes :
1115+ if node .target == executorch_call_delegate :
1116+ lbm = getattr (graph_module , node .args [0 ].target )
1117+ assert is_lowered_module (lbm )
1118+ data_store_output = lbm .named_data_store_output
1119+ if data_store_output is not None :
1120+ named_data_store .merge_named_data_store (data_store_output )
1121+
1122+ for _ , submod , _ in get_control_flow_submodules (graph_module ):
1123+ collect_named_data_store_outputs (submod )
1124+
1125+ collect_named_data_store_outputs (exported_program .graph_module )
1126+
1127+
11011128@et_logger ("to_edge_transform_and_lower" )
11021129def to_edge_transform_and_lower (
11031130 programs : Union [ExportedProgram , Dict [str , ExportedProgram ]],
@@ -1307,7 +1334,6 @@ def __init__(
13071334 constant_methods : Optional [Dict [str , Any ]] = None ,
13081335 compile_config : Optional [EdgeCompileConfig ] = None ,
13091336 ops_set_to_not_decompose : Optional [List [torch ._ops .OpOverload ]] = None ,
1310- named_data_store : Optional [NamedDataStore ] = None ,
13111337 ):
13121338 """
13131339 Should not be called directly by users. User should use :func:'to_edge' instead.
@@ -1331,7 +1357,11 @@ def __init__(
13311357 self ._edge_programs : Dict [str , ExportedProgram ] = edge_programs
13321358 self ._config_methods = constant_methods
13331359
1334- self ._named_data_store = named_data_store or NamedDataStore ()
1360+ self ._named_data_store = NamedDataStore ()
1361+ for _ , program in self ._edge_programs .items ():
1362+ collect_named_data_store_from_exported_program (
1363+ program , self ._named_data_store
1364+ )
13351365
13361366 @property
13371367 def methods (self ) -> Set [str ]:
@@ -1441,30 +1471,11 @@ def to_backend(
14411471 for name , program in self ._edge_programs .items ():
14421472 new_edge_programs [name ] = to_backend (program , partitioner )
14431473
1444- # collected all the named data into the named data store for deduplication
1445- def collect_named_data_store_outputs (
1446- graph_module : torch .fx .GraphModule ,
1447- ) -> None :
1448- for node in graph_module .graph .nodes :
1449- if node .target == executorch_call_delegate :
1450- lbm = getattr (graph_module , node .args [0 ].name )
1451- assert is_lowered_module (lbm )
1452- data_store_output = lbm .named_data_store_output
1453- if data_store_output is not None :
1454- self ._named_data_store .merge_named_data_store (data_store_output )
1455-
1456- for _ , submod , _ in get_control_flow_submodules (graph_module ):
1457- collect_named_data_store_outputs (submod )
1458-
1459- for _ , program in new_edge_programs .items ():
1460- collect_named_data_store_outputs (program .graph_module )
1461-
14621474 config = EdgeCompileConfig (_check_ir_validity = False )
14631475 return EdgeProgramManager (
14641476 new_edge_programs ,
14651477 copy .deepcopy (self ._config_methods ),
14661478 config ,
1467- named_data_store = self ._named_data_store ,
14681479 )
14691480
14701481 @et_logger ("to_executorch" )
0 commit comments