@@ -1264,3 +1264,63 @@ def forward(self, x: List[torch.Tensor]):
12641264
12651265 gm = to_edge (export (ComposedM (), inputs , strict = True ))
12661266 gm .exported_program ().module ()(* inputs )
1267+
1268+ def test_delegate_info_full_delegate (self ):
1269+ """
1270+ Test that _delegate_info_meta from BackendWithCompilerDemo ends up in the call_delegate node metadata
1271+ when using full delegation (to_backend directly).
1272+ """
1273+
1274+ class SinModule (torch .nn .Module ):
1275+ def __init__ (self ):
1276+ super ().__init__ ()
1277+
1278+ def forward (self , x ):
1279+ return torch .sin (x )
1280+
1281+ sin_module = SinModule ()
1282+ model_inputs = (torch .ones (1 ),)
1283+ edgeir_m = to_edge (export (sin_module , model_inputs , strict = True ))
1284+ max_value = model_inputs [0 ].shape [0 ]
1285+ compile_specs = [CompileSpec ("max_value" , bytes ([max_value ]))]
1286+ lowered_sin_module = to_backend (
1287+ "BackendWithCompilerDemo" , edgeir_m .exported_program (), compile_specs
1288+ )
1289+
1290+ # Check that the lowered module has _delegate_info_meta in its meta
1291+ self .assertIn ("_delegate_info_meta" , lowered_sin_module .meta .keys ())
1292+ self .assertEqual (lowered_sin_module .meta ["_delegate_info_meta" ], "test" )
1293+
1294+ def test_delegate_info_partitioner (self ):
1295+ """
1296+ Test that _delegate_info_meta from BackendWithCompilerDemo ends up in the call_delegate node metadata
1297+ when using partitioner-based delegation.
1298+ """
1299+
1300+ class SinModule (torch .nn .Module ):
1301+ def __init__ (self ):
1302+ super ().__init__ ()
1303+
1304+ def forward (self , x ):
1305+ return torch .sin (x )
1306+
1307+ sin_module = SinModule ()
1308+ model_inputs = (torch .ones (1 ),)
1309+ max_value = model_inputs [0 ].shape [0 ]
1310+
1311+ partitioner = AllNodePartitioner (
1312+ "BackendWithCompilerDemo" , [CompileSpec ("max_value" , bytes ([max_value ]))]
1313+ )
1314+
1315+ edgeir_m = to_edge (export (sin_module , model_inputs , strict = True ))
1316+ lowered_m = edgeir_m .to_backend (partitioner )
1317+
1318+ # Check that the lowered submodule has _delegate_info_meta in its meta
1319+ lowered_submodules = get_lowered_submodules (
1320+ lowered_m .exported_program ().graph_module
1321+ )
1322+ self .assertEqual (len (lowered_submodules ), 1 )
1323+
1324+ lowered_module = lowered_submodules [0 ][1 ]
1325+ self .assertIn ("_delegate_info_meta" , lowered_module .meta )
1326+ self .assertEqual (lowered_module .meta ["_delegate_info_meta" ], "test" )
0 commit comments