|
9 | 9 | import logging |
10 | 10 | from abc import ABC, abstractmethod |
11 | 11 | from collections import defaultdict |
12 | | -from typing import Any, Callable, Dict, List, Optional |
| 12 | +from typing import Any, Callable, cast, Dict, List, Optional |
13 | 13 |
|
14 | 14 | import torch |
15 | 15 | from executorch.devtools.backend_debug import get_delegation_info |
16 | | -from executorch.exir import EdgeCompileConfig, ExportedProgram |
| 16 | +from executorch.exir import EdgeCompileConfig, EdgeProgramManager, ExportedProgram |
17 | 17 | from executorch.exir.backend.backend_api import validation_disabled |
| 18 | +from executorch.exir.pass_base import ExportPass, RequireExportedProgram |
| 19 | +from executorch.exir.pass_manager import PassManager |
18 | 20 | from executorch.exir.program import to_edge, to_edge_transform_and_lower |
19 | 21 | from executorch.export.recipe import LoweringRecipe, QuantizationRecipe |
20 | 22 | from executorch.export.types import StageType |
@@ -118,7 +120,7 @@ def __init__( |
118 | 120 | self.strict = strict |
119 | 121 |
|
120 | 122 | @property |
121 | | - def stage_type(self) -> str: |
| 123 | + def stage_type(self) -> StageType: |
122 | 124 | return StageType.TORCH_EXPORT |
123 | 125 |
|
124 | 126 | @property |
@@ -197,7 +199,7 @@ def from_recipe( |
197 | 199 | ) |
198 | 200 |
|
199 | 201 | @property |
200 | | - def stage_type(self) -> str: |
| 202 | + def stage_type(self) -> StageType: |
201 | 203 | return StageType.TO_EDGE_TRANSFORM_AND_LOWER |
202 | 204 |
|
203 | 205 | @property |
@@ -266,7 +268,7 @@ def __init__(self, backend_config: Any) -> None: |
266 | 268 | self._backend_config = backend_config |
267 | 269 |
|
268 | 270 | @property |
269 | | - def stage_type(self) -> str: |
| 271 | + def stage_type(self) -> StageType: |
270 | 272 | return StageType.TO_EXECUTORCH |
271 | 273 |
|
272 | 274 | @property |
@@ -304,7 +306,7 @@ def __init__(self, quantization_recipe: Optional[QuantizationRecipe]) -> None: |
304 | 306 | self._transformed_models: Dict[str, nn.Module] = {} |
305 | 307 |
|
306 | 308 | @property |
307 | | - def stage_type(self) -> str: |
| 309 | + def stage_type(self) -> StageType: |
308 | 310 | return StageType.SOURCE_TRANSFORM |
309 | 311 |
|
310 | 312 | @property |
@@ -358,7 +360,7 @@ def __init__(self, quantization_recipe: Optional[QuantizationRecipe]) -> None: |
358 | 360 | self._quantization_recipe = quantization_recipe |
359 | 361 |
|
360 | 362 | @property |
361 | | - def stage_type(self) -> str: |
| 363 | + def stage_type(self) -> StageType: |
362 | 364 | return StageType.QUANTIZE |
363 | 365 |
|
364 | 366 | @property |
@@ -459,7 +461,7 @@ def from_recipe(cls, lowering_recipe: Optional["LoweringRecipe"]) -> "ToEdgeStag |
459 | 461 | ) |
460 | 462 |
|
461 | 463 | @property |
462 | | - def stage_type(self) -> str: |
| 464 | + def stage_type(self) -> StageType: |
463 | 465 | return StageType.TO_EDGE |
464 | 466 |
|
465 | 467 | @property |
@@ -520,7 +522,7 @@ def from_recipe( |
520 | 522 | ) |
521 | 523 |
|
522 | 524 | @property |
523 | | - def stage_type(self) -> str: |
| 525 | + def stage_type(self) -> StageType: |
524 | 526 | return StageType.TO_BACKEND |
525 | 527 |
|
526 | 528 | @property |
@@ -583,3 +585,102 @@ def delegation_info(self) -> Any: |
583 | 585 | Returns the delegation info. |
584 | 586 | """ |
585 | 587 | return self._artifact.get_context("delegation_info") |
| 588 | + |
| 589 | + |
| 590 | +class PostToBackendStage(Stage): |
| 591 | + """ |
| 592 | + Stage: Run passes after all partitioners have done their partitioning. |
| 593 | + """ |
| 594 | + |
| 595 | + def __init__( |
| 596 | + self, |
| 597 | + pass_list_or_manager: ( |
| 598 | + list[PassType | type[ExportPass]] | PassManager | None |
| 599 | + ) = None, |
| 600 | + edge_compile_config: EdgeCompileConfig | None = None, |
| 601 | + ) -> None: |
| 602 | + super().__init__() |
| 603 | + if pass_list_or_manager is None: |
| 604 | + pass_list_or_manager = [] |
| 605 | + |
| 606 | + self._pass_list_or_manager = pass_list_or_manager |
| 607 | + self._edge_compile_config = edge_compile_config |
| 608 | + |
| 609 | + @classmethod |
| 610 | + def from_recipe( |
| 611 | + cls, lowering_recipe: Optional["LoweringRecipe"] |
| 612 | + ) -> "PostToBackendStage": |
| 613 | + if lowering_recipe is None: |
| 614 | + return cls() |
| 615 | + |
| 616 | + return cls( |
| 617 | + pass_list=lowering_recipe.post_to_backend_passes, |
| 618 | + edge_compile_config=lowering_recipe.edge_compile_config, |
| 619 | + ) |
| 620 | + |
| 621 | + @property |
| 622 | + def stage_type(self) -> StageType: |
| 623 | + return StageType.POST_TO_BACKEND |
| 624 | + |
| 625 | + @property |
| 626 | + def valid_predecessor_stages(self) -> List["StageType"]: |
| 627 | + return [StageType.TO_BACKEND, StageType.TO_EDGE_TRANSFORM_AND_LOWER] |
| 628 | + |
| 629 | + @property |
| 630 | + def can_start_pipeline(self) -> bool: |
| 631 | + return False |
| 632 | + |
| 633 | + def run(self, artifact: PipelineArtifact) -> None: |
| 634 | + """ |
| 635 | + Run list of passes using edge_program_manager.transform(). |
| 636 | +
|
| 637 | + Args: |
| 638 | + artifact: PipelineArtifact which's data field is expected to contain an edge_program_manager. |
| 639 | + """ |
| 640 | + |
| 641 | + if self._pass_list_or_manager: |
| 642 | + edge_program_manager = cast(EdgeProgramManager, artifact.data) |
| 643 | + |
| 644 | + if isinstance(self._pass_list_or_manager, PassManager): |
| 645 | + edge_program_manager = edge_program_manager.transform( |
| 646 | + self._pass_list_or_manager, self._edge_compile_config |
| 647 | + ) |
| 648 | + else: |
| 649 | + exported_program = edge_program_manager.exported_program() |
| 650 | + pass_instances: list[PassType] = [] |
| 651 | + for _pass in self._pass_list_or_manager: |
| 652 | + if isinstance(_pass, type): |
| 653 | + if not issubclass(_pass, ExportPass): |
| 654 | + raise RuntimeError( |
| 655 | + f"Pass {_pass} was not subclass of ExportPass." |
| 656 | + ) |
| 657 | + if issubclass(_pass, RequireExportedProgram): |
| 658 | + pass_instance = _pass( |
| 659 | + exported_program=exported_program # type: ignore |
| 660 | + ) |
| 661 | + else: |
| 662 | + pass_instance = _pass() |
| 663 | + pass_instances.append(pass_instance) |
| 664 | + else: |
| 665 | + pass_instances.append(_pass) |
| 666 | + |
| 667 | + edge_program_manager = edge_program_manager.transform( |
| 668 | + pass_instances, self._edge_compile_config |
| 669 | + ) |
| 670 | + # Get delegation info |
| 671 | + delegation_info = get_delegation_info( |
| 672 | + edge_program_manager.exported_program().graph_module |
| 673 | + ) |
| 674 | + |
| 675 | + self._artifact = artifact.copy_with_new_data(edge_program_manager) |
| 676 | + self._artifact.add_context("delegation_info", delegation_info) |
| 677 | + else: |
| 678 | + # If pass_list_or_manager is None or empty list, do nothing. |
| 679 | + self._artifact = artifact |
| 680 | + |
| 681 | + @property |
| 682 | + def delegation_info(self) -> Any: |
| 683 | + """ |
| 684 | + Returns the delegation info. |
| 685 | + """ |
| 686 | + return self._artifact.get_context("delegation_info") |
0 commit comments