@@ -1264,3 +1264,63 @@ def forward(self, x: List[torch.Tensor]):
1264
1264
1265
1265
gm = to_edge (export (ComposedM (), inputs , strict = True ))
1266
1266
gm .exported_program ().module ()(* inputs )
1267
+
1268
+ def test_delegate_info_full_delegate (self ):
1269
+ """
1270
+ Test that _delegate_info_meta from BackendWithCompilerDemo ends up in the call_delegate node metadata
1271
+ when using full delegation (to_backend directly).
1272
+ """
1273
+
1274
+ class SinModule (torch .nn .Module ):
1275
+ def __init__ (self ):
1276
+ super ().__init__ ()
1277
+
1278
+ def forward (self , x ):
1279
+ return torch .sin (x )
1280
+
1281
+ sin_module = SinModule ()
1282
+ model_inputs = (torch .ones (1 ),)
1283
+ edgeir_m = to_edge (export (sin_module , model_inputs , strict = True ))
1284
+ max_value = model_inputs [0 ].shape [0 ]
1285
+ compile_specs = [CompileSpec ("max_value" , bytes ([max_value ]))]
1286
+ lowered_sin_module = to_backend (
1287
+ "BackendWithCompilerDemo" , edgeir_m .exported_program (), compile_specs
1288
+ )
1289
+
1290
+ # Check that the lowered module has _delegate_info_meta in its meta
1291
+ self .assertIn ("_delegate_info_meta" , lowered_sin_module .meta .keys ())
1292
+ self .assertEqual (lowered_sin_module .meta ["_delegate_info_meta" ], "test" )
1293
+
1294
+ def test_delegate_info_partitioner (self ):
1295
+ """
1296
+ Test that _delegate_info_meta from BackendWithCompilerDemo ends up in the call_delegate node metadata
1297
+ when using partitioner-based delegation.
1298
+ """
1299
+
1300
+ class SinModule (torch .nn .Module ):
1301
+ def __init__ (self ):
1302
+ super ().__init__ ()
1303
+
1304
+ def forward (self , x ):
1305
+ return torch .sin (x )
1306
+
1307
+ sin_module = SinModule ()
1308
+ model_inputs = (torch .ones (1 ),)
1309
+ max_value = model_inputs [0 ].shape [0 ]
1310
+
1311
+ partitioner = AllNodePartitioner (
1312
+ "BackendWithCompilerDemo" , [CompileSpec ("max_value" , bytes ([max_value ]))]
1313
+ )
1314
+
1315
+ edgeir_m = to_edge (export (sin_module , model_inputs , strict = True ))
1316
+ lowered_m = edgeir_m .to_backend (partitioner )
1317
+
1318
+ # Check that the lowered submodule has _delegate_info_meta in its meta
1319
+ lowered_submodules = get_lowered_submodules (
1320
+ lowered_m .exported_program ().graph_module
1321
+ )
1322
+ self .assertEqual (len (lowered_submodules ), 1 )
1323
+
1324
+ lowered_module = lowered_submodules [0 ][1 ]
1325
+ self .assertIn ("_delegate_info_meta" , lowered_module .meta )
1326
+ self .assertEqual (lowered_module .meta ["_delegate_info_meta" ], "test" )
0 commit comments