| 
10 | 10 | 
 
  | 
11 | 11 | import executorch.exir as exir  | 
12 | 12 | import torch  | 
 | 13 | +from executorch.exir import to_edge  | 
13 | 14 | from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend  | 
 | 15 | +from executorch.exir.backend.canonical_partitioners.all_node_partitioner import (  | 
 | 16 | +    AllNodePartitioner,  | 
 | 17 | +)  | 
14 | 18 | from executorch.exir.backend.compile_spec_schema import CompileSpec  | 
15 | 19 | from executorch.exir.backend.partitioner import (  | 
16 | 20 |     DelegationSpec,  | 
@@ -1266,3 +1270,178 @@ def forward(self, x: List[torch.Tensor]):  | 
1266 | 1270 | 
 
  | 
1267 | 1271 |         gm = exir.capture(ComposedM(), inputs, exir.CaptureConfig()).to_edge()  | 
1268 | 1272 |         gm(*inputs)  | 
 | 1273 | + | 
 | 1274 | +    def test_to_backend_delegation_spec(self):  | 
 | 1275 | +        class SinModule(torch.nn.Module):  | 
 | 1276 | +            def __init__(self):  | 
 | 1277 | +                super().__init__()  | 
 | 1278 | + | 
 | 1279 | +            def forward(self, x):  | 
 | 1280 | +                return [torch.sin(x)]  | 
 | 1281 | + | 
 | 1282 | +        sin_module = SinModule()  | 
 | 1283 | +        model_inputs = (torch.ones(1),)  | 
 | 1284 | +        max_value = model_inputs[0].shape[0]  | 
 | 1285 | + | 
 | 1286 | +        partitioner = AllNodePartitioner(  | 
 | 1287 | +            "BackendWithCompilerDemo", [CompileSpec("max_value", bytes([max_value]))]  | 
 | 1288 | +        )  | 
 | 1289 | + | 
 | 1290 | +        edgeir_m = to_edge(torch.export.export(sin_module, model_inputs))  | 
 | 1291 | +        edgeir_m = edgeir_m.to_backend(partitioner)  | 
 | 1292 | +        exec_prog = edgeir_m.to_executorch()  | 
 | 1293 | +        graph_module = exec_prog.exported_program().graph_module  | 
 | 1294 | +        # Check that there is not an aten.sin node.  | 
 | 1295 | +        self.assertTrue(  | 
 | 1296 | +            exir_ops.edge.aten.sin  | 
 | 1297 | +            not in {node.target for node in graph_module.graph.nodes}  | 
 | 1298 | +        )  | 
 | 1299 | + | 
 | 1300 | +        # Check that there exists a call_delegate, representing the call to the  | 
 | 1301 | +        # delegated function  | 
 | 1302 | +        FileCheck().check("torch.ops.higher_order.executorch_call_delegate").run(  | 
 | 1303 | +            graph_module.code  | 
 | 1304 | +        )  | 
 | 1305 | +        lowered_submodules = get_lowered_submodules(graph_module)  | 
 | 1306 | +        self.assertEqual(len(lowered_submodules), 1)  | 
 | 1307 | + | 
 | 1308 | +        for node in graph_module.graph.nodes:  | 
 | 1309 | +            if node.op == "call_function" and node.target == executorch_call_delegate:  | 
 | 1310 | +                # Check that first arg is lowered_module_{unique_id}  | 
 | 1311 | +                self.assertEqual(node.args[0].target, "lowered_module_0")  | 
 | 1312 | + | 
 | 1313 | +        program = exec_prog.executorch_program  | 
 | 1314 | + | 
 | 1315 | +        # Check the program can be printed  | 
 | 1316 | +        print_program(program)  | 
 | 1317 | + | 
 | 1318 | +        # Check the backend delegate  | 
 | 1319 | +        self.check_backend_delegate(  | 
 | 1320 | +            program=program,  | 
 | 1321 | +            delegate=program.execution_plan[0].delegates[0],  | 
 | 1322 | +            expected_id=BackendWithCompilerDemo.__name__,  | 
 | 1323 | +            expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>2#",  | 
 | 1324 | +        )  | 
 | 1325 | + | 
 | 1326 | +        # Check the delegate instruction  | 
 | 1327 | +        self.assertTrue(  | 
 | 1328 | +            isinstance(  | 
 | 1329 | +                program.execution_plan[0].chains[0].instructions[0].instr_args,  | 
 | 1330 | +                DelegateCall,  | 
 | 1331 | +            )  | 
 | 1332 | +        )  | 
 | 1333 | +        buff = exec_prog.buffer  | 
 | 1334 | + | 
 | 1335 | +        executorch_module = _load_for_executorch_from_buffer(buff)  | 
 | 1336 | +        model_inputs = torch.ones(1)  | 
 | 1337 | +        model_outputs = executorch_module.forward([model_inputs])  | 
 | 1338 | +        self.assertEqual(  | 
 | 1339 | +            model_inputs,  | 
 | 1340 | +            torch.ones(1),  | 
 | 1341 | +        )  | 
 | 1342 | +        expected_output = 0.8333 * torch.ones(1)  | 
 | 1343 | + | 
 | 1344 | +        self.assertTrue(  | 
 | 1345 | +            torch.allclose(model_outputs[0], expected_output, atol=1e-03, rtol=1e-03)  | 
 | 1346 | +        )  | 
 | 1347 | + | 
 | 1348 | +    def test_to_backend_multimethod_delegation_spec(self):  | 
 | 1349 | +        class SinModule(torch.nn.Module):  | 
 | 1350 | +            def __init__(self):  | 
 | 1351 | +                super().__init__()  | 
 | 1352 | + | 
 | 1353 | +            def forward(self, x):  | 
 | 1354 | +                return torch.sin(x)  | 
 | 1355 | + | 
 | 1356 | +            def inputs(self):  | 
 | 1357 | +                return (torch.ones(1),)  | 
 | 1358 | + | 
 | 1359 | +        class AddMulModule(torch.nn.Module):  | 
 | 1360 | +            def __init__(self):  | 
 | 1361 | +                super().__init__()  | 
 | 1362 | + | 
 | 1363 | +            def forward(self, a, x, b):  | 
 | 1364 | +                y = torch.mm(a, x)  | 
 | 1365 | +                z = torch.add(y, b)  | 
 | 1366 | +                return z  | 
 | 1367 | + | 
 | 1368 | +            def inputs(self):  | 
 | 1369 | +                return (torch.ones(2, 2), 2 * torch.ones(2, 2), 3 * torch.ones(2, 2))  | 
 | 1370 | + | 
 | 1371 | +        sin_module = SinModule()  | 
 | 1372 | +        max_value_sin = sin_module.inputs()[0].shape[0]  | 
 | 1373 | +        sin_partitioner = AllNodePartitioner(  | 
 | 1374 | +            "BackendWithCompilerDemo",  | 
 | 1375 | +            [CompileSpec("max_value", bytes([max_value_sin]))],  | 
 | 1376 | +        )  | 
 | 1377 | + | 
 | 1378 | +        add_mul_module = AddMulModule()  | 
 | 1379 | +        max_value_add_mul = add_mul_module.inputs()[0].shape[0]  | 
 | 1380 | +        add_mul_partitioner = AllNodePartitioner(  | 
 | 1381 | +            "BackendWithCompilerDemo",  | 
 | 1382 | +            [CompileSpec("max_value", bytes([max_value_add_mul]))],  | 
 | 1383 | +        )  | 
 | 1384 | + | 
 | 1385 | +        edgeir_m = to_edge(  | 
 | 1386 | +            {  | 
 | 1387 | +                "sin": torch.export.export(sin_module, sin_module.inputs()),  | 
 | 1388 | +                "add_mul": torch.export.export(add_mul_module, add_mul_module.inputs()),  | 
 | 1389 | +            }  | 
 | 1390 | +        )  | 
 | 1391 | +        edgeir_m = edgeir_m.to_backend(  | 
 | 1392 | +            {  | 
 | 1393 | +                "sin": sin_partitioner,  | 
 | 1394 | +                "add_mul": add_mul_partitioner,  | 
 | 1395 | +            }  | 
 | 1396 | +        )  | 
 | 1397 | +        exec_prog = edgeir_m.to_executorch()  | 
 | 1398 | + | 
 | 1399 | +        for method_name in ["sin", "add_mul"]:  | 
 | 1400 | +            graph_module = exec_prog.exported_program(method_name).graph_module  | 
 | 1401 | +            # Check delegated nodes are gone  | 
 | 1402 | +            self.assertTrue(  | 
 | 1403 | +                exir_ops.edge.aten.sin  | 
 | 1404 | +                not in {node.target for node in graph_module.graph.nodes}  | 
 | 1405 | +            )  | 
 | 1406 | +            self.assertTrue(  | 
 | 1407 | +                exir_ops.edge.aten.add  | 
 | 1408 | +                not in {node.target for node in graph_module.graph.nodes}  | 
 | 1409 | +            )  | 
 | 1410 | +            self.assertTrue(  | 
 | 1411 | +                exir_ops.edge.aten.mm  | 
 | 1412 | +                not in {node.target for node in graph_module.graph.nodes}  | 
 | 1413 | +            )  | 
 | 1414 | +            # Check that there exists a call_delegate, representing the call to the  | 
 | 1415 | +            # delegated function  | 
 | 1416 | +            FileCheck().check("torch.ops.higher_order.executorch_call_delegate").run(  | 
 | 1417 | +                graph_module.code  | 
 | 1418 | +            )  | 
 | 1419 | +            lowered_submodules = get_lowered_submodules(graph_module)  | 
 | 1420 | +            self.assertEqual(len(lowered_submodules), 1)  | 
 | 1421 | + | 
 | 1422 | +        program = exec_prog.executorch_program  | 
 | 1423 | + | 
 | 1424 | +        # Check the program can be printed  | 
 | 1425 | +        print_program(program)  | 
 | 1426 | + | 
 | 1427 | +        buff = exec_prog.buffer  | 
 | 1428 | + | 
 | 1429 | +        executorch_module = _load_for_executorch_from_buffer(buff)  | 
 | 1430 | + | 
 | 1431 | +        for method_name, module in {  | 
 | 1432 | +            "sin": sin_module,  | 
 | 1433 | +            "add_mul": add_mul_module,  | 
 | 1434 | +        }.items():  | 
 | 1435 | +            inputs_flattened, _ = tree_flatten(module.inputs())  | 
 | 1436 | +            model_outputs = executorch_module.run_method(  | 
 | 1437 | +                method_name, tuple(inputs_flattened)  | 
 | 1438 | +            )  | 
 | 1439 | + | 
 | 1440 | +            if method_name == "sin":  | 
 | 1441 | +                # backend with compiler demo does a taylor approximation of sin  | 
 | 1442 | +                ref_output = 0.8333 * torch.ones(1)  | 
 | 1443 | +            else:  | 
 | 1444 | +                ref_output = module(*module.inputs())  | 
 | 1445 | +            self.assertTrue(  | 
 | 1446 | +                torch.allclose(model_outputs[0], ref_output, atol=1e-03, rtol=1e-03)  | 
 | 1447 | +            )  | 
0 commit comments