|
12 | 12 |
|
13 | 13 | import torch |
14 | 14 | from executorch.export import ExportRecipe, ExportSession |
| 15 | +from executorch.export.recipe import LoweringRecipe, QuantizationRecipe |
15 | 16 | from executorch.export.stages import PipelineArtifact |
16 | 17 | from executorch.export.types import StageType |
17 | 18 |
|
@@ -434,3 +435,48 @@ def test_save_to_pte_invalid_name(self) -> None: |
434 | 435 |
|
435 | 436 | with self.assertRaises(AssertionError): |
436 | 437 | session.save_to_pte(None) # pyre-ignore |
| 438 | + |
| 439 | + |
| 440 | +class TestExportSessionPipelineBuilding(unittest.TestCase): |
| 441 | + """Test pipeline building and stage configuration.""" |
| 442 | + |
| 443 | + def setUp(self) -> None: |
| 444 | + self.model = SimpleTestModel() |
| 445 | + self.example_inputs = [(torch.randn(2, 10),)] |
| 446 | + |
| 447 | + def test_pipeline_building_with_all_recipes(self) -> None: |
| 448 | + """Test pipeline building with quantization and lowering recipes.""" |
| 449 | + # Create comprehensive recipes |
| 450 | + quant_recipe = QuantizationRecipe( |
| 451 | + ao_base_config=[Mock()], |
| 452 | + quantizers=[Mock()], |
| 453 | + ) |
| 454 | + lowering_recipe = LoweringRecipe( |
| 455 | + partitioners=[Mock()], |
| 456 | + edge_transform_passes=[Mock()], |
| 457 | + edge_compile_config=Mock(), |
| 458 | + ) |
| 459 | + recipe = ExportRecipe( |
| 460 | + name="comprehensive_test", |
| 461 | + quantization_recipe=quant_recipe, |
| 462 | + lowering_recipe=lowering_recipe, |
| 463 | + executorch_backend_config=Mock(), |
| 464 | + ) |
| 465 | + |
| 466 | + session = ExportSession( |
| 467 | + model=self.model, |
| 468 | + example_inputs=self.example_inputs, |
| 469 | + export_recipe=recipe, |
| 470 | + ) |
| 471 | + |
| 472 | + registered_stages = session.get_all_registered_stages() |
| 473 | + |
| 474 | + self.assertEqual(len(registered_stages), 5) |
| 475 | + expected_types = [ |
| 476 | + StageType.SOURCE_TRANSFORM, |
| 477 | + StageType.QUANTIZE, |
| 478 | + StageType.TORCH_EXPORT, |
| 479 | + StageType.TO_EDGE_TRANSFORM_AND_LOWER, |
| 480 | + StageType.TO_EXECUTORCH, |
| 481 | + ] |
| 482 | + self.assertListEqual(list(registered_stages.keys()), expected_types) |
0 commit comments