diff --git a/backends/xnnpack/recipes/xnnpack_recipe_provider.py b/backends/xnnpack/recipes/xnnpack_recipe_provider.py index 9d00c3c9c98..8fba58c12c3 100644 --- a/backends/xnnpack/recipes/xnnpack_recipe_provider.py +++ b/backends/xnnpack/recipes/xnnpack_recipe_provider.py @@ -27,6 +27,7 @@ from executorch.export import ( BackendRecipeProvider, ExportRecipe, + LoweringRecipe, QuantizationRecipe, RecipeType, ) @@ -88,12 +89,19 @@ def create_recipe( ) return None + def _get_xnnpack_lowering_recipe( + self, precision_type: Optional[ConfigPrecisionType] = None + ) -> LoweringRecipe: + return LoweringRecipe( + partitioners=[XnnpackPartitioner(precision_type=precision_type)], + edge_compile_config=get_xnnpack_edge_compile_config(), + ) + def _build_fp32_recipe(self, recipe_type: RecipeType) -> ExportRecipe: return ExportRecipe( name=recipe_type.value, - edge_compile_config=get_xnnpack_edge_compile_config(), + lowering_recipe=self._get_xnnpack_lowering_recipe(), executorch_backend_config=get_xnnpack_executorch_backend_config(), - partitioners=[XnnpackPartitioner()], ) def _build_quantized_recipe( @@ -120,9 +128,8 @@ def _build_quantized_recipe( return ExportRecipe( name=recipe_type.value, quantization_recipe=quant_recipe, - edge_compile_config=get_xnnpack_edge_compile_config(), + lowering_recipe=self._get_xnnpack_lowering_recipe(precision_type), executorch_backend_config=get_xnnpack_executorch_backend_config(), - partitioners=[XnnpackPartitioner(config_precision=precision_type)], ) def _build_int8da_intx_weight_recipe( @@ -150,9 +157,8 @@ def _build_int8da_intx_weight_recipe( return ExportRecipe( name=recipe_type.value, quantization_recipe=quant_recipe, - edge_compile_config=get_xnnpack_edge_compile_config(), + lowering_recipe=self._get_xnnpack_lowering_recipe(), executorch_backend_config=get_xnnpack_executorch_backend_config(), - partitioners=[XnnpackPartitioner()], ) def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> None: diff --git a/export/TARGETS b/export/TARGETS index defb508b33a..816a3a1a289 100644 --- a/export/TARGETS +++ b/export/TARGETS @@ -15,7 +15,6 @@ runtime.python_library( "//caffe2:torch", "//executorch/exir/backend:backend_api", "//executorch/exir:pass_manager", - "//executorch/devtools/backend_debug:delegation_info", "//executorch/extension/export_util:export_util", ] ) @@ -31,11 +30,35 @@ runtime.python_library( ], deps = [ ":recipe", + ":stages", + ":types", "//executorch/runtime:runtime", ":recipe_registry" ] ) + +runtime.python_library( + name = "stages", + srcs = [ + "stages.py", + ], + visibility = [ + "//executorch/...", + "@EXECUTORCH_CLIENTS", + ], + deps = [ + ":recipe", + ":types", + "//executorch/devtools/backend_debug:delegation_info", + "//executorch/exir/backend:backend_api", + "//executorch/exir:pass_manager", + "//caffe2:torch", + "//executorch/devtools/backend_debug:delegation_info", + ] +) + + runtime.python_library( name = "lib", srcs = [ @@ -48,8 +71,10 @@ runtime.python_library( deps = [ ":export", ":recipe", + ":stages", ":recipe_registry", - ":recipe_provider" + ":recipe_provider", + ":types", ], ) @@ -78,3 +103,10 @@ runtime.python_library( ":recipe", ] ) + +runtime.python_library( + name = "types", + srcs = [ + "types.py", + ], +) diff --git a/export/__init__.py b/export/__init__.py index a39f7b86a53..d5f3826ab90 100644 --- a/export/__init__.py +++ b/export/__init__.py @@ -15,13 +15,15 @@ """ from .export import export, ExportSession -from .recipe import ExportRecipe, QuantizationRecipe, RecipeType +from .recipe import ExportRecipe, LoweringRecipe, QuantizationRecipe, RecipeType from .recipe_provider import BackendRecipeProvider from .recipe_registry import recipe_registry - +from .types import StageType __all__ = [ + "StageType", "ExportRecipe", + "LoweringRecipe", "QuantizationRecipe", "ExportSession", "export", diff --git a/export/export.py b/export/export.py index 0246a375493..e5c3b793ccd 100644 --- a/export/export.py +++ b/export/export.py @@ -5,428 +5,30 @@ # LICENSE file in the root directory of this source tree. import logging -from abc import ABC, abstractmethod from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import torch -from executorch.devtools.backend_debug import get_delegation_info from executorch.exir._warnings import experimental -from executorch.exir.backend.backend_api import validation_disabled -from executorch.exir.program import ( - EdgeProgramManager, - ExecutorchProgramManager, - to_edge_transform_and_lower, -) -from executorch.exir.program._program import _transform +from executorch.exir.program import ExecutorchProgramManager from executorch.exir.schema import Program -from executorch.export.recipe import QuantizationRecipe from executorch.extension.export_util.utils import save_pte_program from executorch.runtime import Runtime, Verification from tabulate import tabulate from torch import nn -from torch._export.pass_base import PassType -from torch.export import ExportedProgram -from torchao.quantization import quantize_ -from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e - -from torchao.quantization.pt2e.quantizer import ComposableQuantizer -from torchao.utils import unwrap_tensor_subclass - -from .recipe import ExportRecipe - - -class Stage(ABC): - """ - Interface for a Stage in the ExecuTorch export pipeline. - - Each stage can be connected to other stages to form a pipeline. - Stages have clear run and get_outputs functions to make the data flow explicit. - Each stage implements its own run method with specific parameter names. - """ - - def __init__(self) -> None: - """ - Initialize the stage. - """ - self._next_stage = None - - @property - @abstractmethod - def name(self) -> str: - """ - Returns the name of this stage. - """ - pass - - @abstractmethod - def run(self, **kwargs) -> None: - """ - Executes this stage with the given inputs. - - Each concrete stage class implements this method with specific parameter names. - """ - pass - - @abstractmethod - def get_artifacts(self) -> Any: - """ - Returns the artifacts generated by this stage. - - Returns: - The artifacts of this stage, to be used as inputs for the next stage - """ - pass - - def set_next_stage(self, next_stage: "Stage") -> None: - """ - Set the next stage in the pipeline. - - Args: - next_stage: The next stage to execute after this one - """ - self._next_stage = next_stage - - @property - def next_stage(self) -> Optional["Stage"]: - """ - Get the next stage in the pipeline. - - Returns: - The next stage, or None if this is the last stage - """ - return self._next_stage - - -class ExportStage(Stage): - """ - First stage: Export PyTorch model to ExportedProgram. - """ - - def __init__( - self, - pre_edge_transform_passes: Optional[List[PassType]] = None, - ) -> None: - self._exported_program: Dict[str, ExportedProgram] = {} - self._pre_edge_transform_passes = pre_edge_transform_passes - self._model_dict: Dict[str, nn.Module] = {} - self._example_inputs_dict: Dict[str, List[tuple[torch.Tensor, ...]]] = {} - self._dynamic_shapes_dict: Dict[str, Any] = {} - - @property - def name(self) -> str: - return "export" - - def run( - self, - models: Dict[str, Any], - export_config: Optional[Dict[str, Any]] = None, - **kwargs, - ) -> None: - """ - Export PyTorch model to ExportedProgram. - - Args: - models: Dictionary mapping method names to PyTorch models - export_config: Configuration containing example inputs and dynamic shapes - **kwargs: Additional keyword arguments (not used) - """ - # Store inputs - self._model_dict = models.get("model", {}) - - if export_config is not None: - self._example_inputs_dict = export_config.get("example_inputs", {}) - self._dynamic_shapes_dict = export_config.get("dynamic_shapes", {}) - - # Process inputs - with torch.no_grad(): - for method_name, model in self._model_dict.items(): - # Check if method_name exists in example_inputs - if method_name not in self._example_inputs_dict: - raise ValueError( - f"Example inputs for method {method_name} not found." - ) - - # Get dynamic shapes if available - dynamic_shapes = None - if method_name in self._dynamic_shapes_dict: - dynamic_shapes = self._dynamic_shapes_dict[method_name] - - # Export the model - self._exported_program[method_name] = torch.export.export( - model, - self._example_inputs_dict[method_name][0], - dynamic_shapes=dynamic_shapes, - strict=True, - ) - - # Apply pre-edge transform passes if available - if pre_edge_transform_passes := self._pre_edge_transform_passes or []: - for pass_ in pre_edge_transform_passes: - self._exported_program[method_name] = _transform( - self._exported_program[method_name], pass_ - ) - - def get_artifacts(self) -> Dict[str, ExportedProgram]: - """ - Returns the exported program dictionary. - - Returns: - Dictionary mapping method names to exported programs - """ - return self._exported_program - - -class EdgeTransformAndLowerStage(Stage): - """ - Second stage: Transform and lower to EdgeProgramManager. - """ - - def __init__( - self, - partitioners: Optional[List[Any]] = None, - transform_passes: Optional[Sequence[Callable[[Any], Optional[Any]]]] = None, - compile_config: Optional[Any] = None, - ) -> None: - self._partitioners = partitioners - self._transform_passes = transform_passes - self._compile_config = compile_config - self._edge_program_manager: Optional[EdgeProgramManager] = None - self._delegation_info = None - self._exported_program: Dict[str, ExportedProgram] = {} - self._constant_methods = None - - @property - def name(self) -> str: - return "edge_transform_and_lower" - - def run( - self, - exported_programs: Dict[str, ExportedProgram], - transform_config: Optional[Dict[str, Any]] = None, - **kwargs, - ) -> None: - """ - Transform and lower to EdgeProgramManager. - - Args: - exported_programs: Dictionary mapping method names to exported programs - transform_config: Configuration containing constant methods - **kwargs: Additional keyword arguments (not used) - """ - # Store inputs - self._exported_program = exported_programs - - self._constant_methods = None - if transform_config is not None: - self._constant_methods = transform_config.get("constant_methods", None) - - # Process inputs - with validation_disabled(): - self._edge_program_manager = to_edge_transform_and_lower( - self._exported_program, - partitioner=self._partitioners, - transform_passes=self._transform_passes, - constant_methods=self._constant_methods, - compile_config=self._compile_config, - ) - self._delegation_info = get_delegation_info( - self._edge_program_manager.exported_program().graph_module - ) - - def get_artifacts(self) -> EdgeProgramManager: - """ - Returns the edge program manager. - - Returns: - The edge program manager - - Raises: - RuntimeError: If the edge program manager is not initialized - """ - if self._edge_program_manager is None: - raise RuntimeError("Edge program manager is not initialized.") - return self._edge_program_manager - - @property - def delegation_info(self) -> Any: - """ - Returns the delegation info. - """ - return self._delegation_info - - -class ExecutorchStage(Stage): - """ - Third stage: Convert to ExecutorchProgramManager. - """ - - def __init__(self, backend_config: Any) -> None: - self._backend_config = backend_config - self._executorch_program_manager: Optional[ExecutorchProgramManager] = None - self._edge_program_manager: Optional[EdgeProgramManager] = None - - @property - def name(self) -> str: - return "executorch" - - def run( - self, - edge_program: EdgeProgramManager, - backend_options: Optional[Dict[str, Any]] = None, - **kwargs, - ) -> None: - """ - Convert to ExecutorchProgramManager. - - Args: - edge_program: Edge program manager containing the lowered program - backend_options: Additional backend-specific options (not used in this stage) - **kwargs: Additional keyword arguments (not used) - """ - # Store inputs - self._edge_program_manager = edge_program - - # Process inputs - if self._edge_program_manager is None: - raise RuntimeError("Edge program manager is not set.") - - self._executorch_program_manager = self._edge_program_manager.to_executorch( - self._backend_config - ) - - def get_artifacts(self) -> ExecutorchProgramManager: - """ - Returns the executorch program manager. - - Returns: - The executorch program manager - - Raises: - RuntimeError: If the executorch program manager is not initialized - """ - if self._executorch_program_manager is None: - raise RuntimeError("Executorch program manager is not initialized.") - return self._executorch_program_manager - - -class SourceTransformStage(Stage): - """ - Source transform stage: Apply source transformations to the model. - """ - - def __init__(self, quantization_recipe: Any) -> None: - self._quantization_recipe = quantization_recipe - self._transformed_models: Dict[str, nn.Module] = {} - - @property - def name(self) -> str: - return "source_transform" - - def run(self, models: Dict[str, nn.Module], *args, **kwargs) -> None: - """ - Apply source transformations to the model. - - Args: - models: Dictionary mapping method names to PyTorch models - **kwargs: Additional keyword arguments (not used) - """ - # Store the original models - self._transformed_models = models - - # Check if there's a quantization recipe with ao_base_config - if self._quantization_recipe and self._quantization_recipe.ao_base_config: - # Apply torchao quantize_ to each model - for method_name, model in models.items(): - for config in self._quantization_recipe.ao_base_config: - quantize_(model, config) - unwrap_tensor_subclass(model) - self._transformed_models[method_name] = model - - def get_artifacts(self) -> Dict[str, nn.Module]: - """ - Returns the transformed models. - - Returns: - Dictionary mapping method names to transformed models - """ - return self._transformed_models - - -class QuantizeStage(Stage): - """ - Optional stage: Perform post-training quantization on the model. - """ - - def __init__(self, quantizers: Any) -> None: - self._quantizers = quantizers - self._quantized_models: Dict[str, nn.Module] = {} - self._exported_programs: Dict[str, ExportedProgram] = {} - self._model_dict: Dict[str, nn.Module] = {} - self._example_inputs_dict: Dict[str, List[tuple[torch.Tensor, ...]]] = {} - - @property - def name(self) -> str: - return "quantize" - - def run( - self, - models: Dict[str, nn.Module], - calibration_config: Optional[Dict[str, Any]] = None, - **kwargs, - ) -> None: - """ - Perform post-training quantization on the model. - - Args: - models: Dictionary containing models to quantize - calibration_config: Configuration containing example inputs for calibration - **kwargs: Additional keyword arguments (not used) - """ - # Store inputs - self._model_dict = models - - # Initialize with empty dictionaries - self._example_inputs_dict = {} - - if calibration_config is not None: - self._example_inputs_dict = calibration_config.get("example_inputs", {}) - - # Process inputs - for method_name, model in self._model_dict.items(): - # Check if method_name exists in example_inputs and has at least one element - if ( - method_name not in self._example_inputs_dict - or not self._example_inputs_dict[method_name] - ): - raise ValueError( - f"Example inputs for method {method_name} not found or empty." - ) - - # Export the model for training to get a captured graph - inputs = self._example_inputs_dict[method_name][0] - captured_graph = torch.export.export(model, inputs, strict=True).module() - - # Prepare the model for quantization - composed_quantizer = ComposableQuantizer(self._quantizers) - prepared_model = prepare_pt2e(captured_graph, composed_quantizer) # type: ignore - - # Calibrate the model with the provided calibration data - for calibration_input in self._example_inputs_dict[method_name]: # type: ignore - prepared_model(*calibration_input) - - # Convert the prepared model to a quantized model - quantized_model = convert_pt2e(prepared_model) - self._quantized_models[method_name] = quantized_model - - def get_artifacts(self) -> Dict[str, nn.Module]: - """ - Returns the quantized models. - - Returns: - Dictionary mapping method names to quantized models - """ - return self._quantized_models +from .recipe import ExportRecipe, LoweringRecipe, QuantizationRecipe +from .stages import ( + EdgeTransformAndLowerStage, + ExecutorchStage, + PipelineArtifact, + QuantizeStage, + SourceTransformStage, + Stage, + ToBackendStage, + ToEdgeStage, + TorchExportStage, +) +from .types import StageType @experimental( @@ -535,106 +137,171 @@ def __init__( else: self._dynamic_shapes = {"forward": dynamic_shapes} - self._name = name - self._constant_methods = constant_methods - self._artifact_dir = artifact_dir self._export_recipe = export_recipe self._quant_recipe: Optional[QuantizationRecipe] = ( self._export_recipe.quantization_recipe ) - # Initialize pipeline as a list of stages - self._pipeline = [] + self._lowering_recipe: Optional[LoweringRecipe] = ( + self._export_recipe.lowering_recipe + ) - # Create the source transform stage if a quantization recipe is provided - if self._quant_recipe is not None and self._quant_recipe.ao_base_config: - source_transform_stage = SourceTransformStage( - quantization_recipe=self._export_recipe.quantization_recipe - ) - self._pipeline.append(source_transform_stage) + # Stages to run + self._pipeline_stages = ( + self._export_recipe.pipeline_stages or self._get_default_pipeline() + ) - enable_quantize_stage = ( - self._quant_recipe is not None and self._quant_recipe.quantizers + # Stage registry: map of StageType to Stage instances + self._stage_registry: Dict[StageType, Stage] = self._build_stages( + self._pipeline_stages ) - # Create the quantize stage if a quantizer is provided - if enable_quantize_stage: - # pyre-ignore - if quantizers := self._quant_recipe.quantizers: - quantize_stage = QuantizeStage(quantizers=quantizers) - self._pipeline.append(quantize_stage) + # Intialize run context + self._run_context: Dict[str, Any] = { + "example_inputs": self._example_inputs, + "dynamic_shapes": self._dynamic_shapes, + "constant_methods": constant_methods, + "export_recipe": self._export_recipe, + "session_name": name, + "artifact_dir": artifact_dir, + } + + self._stage_to_artifacts: Dict[StageType, PipelineArtifact] = {} + + def _get_default_pipeline(self) -> List[StageType]: + return [ + StageType.SOURCE_TRANSFORM, # Optional stage, returns original model if quant recipe is invalid + StageType.QUANTIZE, # Optional stage, returns original model if quant recipe is invalid + StageType.TORCH_EXPORT, + StageType.TO_EDGE_TRANSFORM_AND_LOWER, + StageType.TO_EXECUTORCH, + ] + + def _build_stages(self, stages: List[StageType]) -> Dict[StageType, Stage]: + """Build the stage registry from the given stages.""" + stage_registry: Dict[StageType, Stage] = {} + + stage = None + for stage_type in stages or self._get_default_pipeline(): + if stage_type == StageType.SOURCE_TRANSFORM: + stage = SourceTransformStage(self._quant_recipe) + elif stage_type == StageType.QUANTIZE: + stage = QuantizeStage(self._quant_recipe) + elif stage_type == StageType.TORCH_EXPORT: + pre_edge_passes = None + if self._export_recipe.pre_edge_transform_passes is not None: + pre_edge_passes = list( + self._export_recipe.pre_edge_transform_passes + ) + stage = TorchExportStage(pre_edge_passes) + elif stage_type == StageType.TO_EDGE_TRANSFORM_AND_LOWER: + stage = EdgeTransformAndLowerStage.from_recipe(self._lowering_recipe) + elif stage_type == StageType.TO_EDGE: + stage = ToEdgeStage.from_recipe(self._lowering_recipe) + elif stage_type == StageType.TO_BACKEND: + stage = ToBackendStage.from_recipe(self._lowering_recipe) + elif stage_type == StageType.TO_EXECUTORCH: + stage = ExecutorchStage(self._export_recipe.executorch_backend_config) + else: + logging.info( + f"{stage_type} is unknown, you have to register it before executing export()" + ) - # Create the export stage - export_stage = ExportStage( - pre_edge_transform_passes=self._export_recipe.pre_edge_transform_passes, - ) - self._pipeline.append(export_stage) + if stage: + stage_registry[stage_type] = stage + return stage_registry - # Create the edge transform and lower stage - edge_transform_and_lower_stage = EdgeTransformAndLowerStage( - partitioners=self._export_recipe.partitioners, - transform_passes=self._export_recipe.edge_transform_passes, - compile_config=self._export_recipe.edge_compile_config, - ) - self._pipeline.append(edge_transform_and_lower_stage) + def register_stage(self, stage_type: StageType, stage: Stage) -> None: + """ + Register a new stage or override an existing stage implementation. - # Create the executorch stage - executorch_stage = ExecutorchStage( - backend_config=self._export_recipe.executorch_backend_config - ) - self._pipeline.append(executorch_stage) + Args: + stage_type: The type of stage to register + stage: The stage instance to register + """ + self._stage_registry[stage_type] = stage - # Initialize stage artifacts - self._exported_models: Dict[str, nn.Module] = {} + def get_registered_stage(self, stage_type: StageType) -> Optional[Stage]: + """ + Get a registered stage by its type. - # Initialize stage artifacts - self._exported_program: Dict[str, ExportedProgram] = {} - self._edge_program_manager: Optional[EdgeProgramManager] = None - self._executorch_program_manager: Optional[ExecutorchProgramManager] = None - self._delegation_info = None + Args: + stage_type: The type of stage to retrieve - def _run_pipeline(self) -> None: + Returns: + The registered stage instance, or None if not found + """ + return self._stage_registry.get(stage_type) + + def get_all_registered_stages(self) -> Dict[StageType, Stage]: """ - Run the pipeline from the beginning. + Get all registered stages. - This method cascades through the pipeline of stages, executing each stage in order. - Each stage directly configures the inputs for the next stage when it completes. + Returns: + Dictionary mapping stage types to stage instances """ - # Process each stage in the pipeline - for stage in self._pipeline: - stage_name = stage.name - logging.info(f"Executing stage: {stage_name}") - # Configure inputs for the current stage - if stage_name == "source_transform": - # Run the source transform stage - stage.run(self._model, {}) - self._model = stage.get_artifacts() - elif stage_name == "quantize": - # Run the quantize stage - config_params = {"example_inputs": self._example_inputs} - stage.run(self._model, config_params) - self._model = stage.get_artifacts() - elif stage_name == "export": - # Run the export stage - models = {"model": self._model} - config_params = { - "example_inputs": self._example_inputs, - "dynamic_shapes": self._dynamic_shapes, - } - stage.run(models, config_params) - self._exported_program = stage.get_artifacts() - elif stage_name == "edge_transform_and_lower": - # Run the edge transform and lower stage - stage.run( - self._exported_program, {"constant_methods": self._constant_methods} + return self._stage_registry + + def _validate_pipeline_sequence( + self, + stages: List[StageType], + ) -> None: + if not stages: + raise ValueError("Pipeline stages cannot be empty") + + # Validate that the first stage can start a pipeline + first_stage = stages[0] + first_stage_instance = self._stage_registry.get(first_stage) + if first_stage_instance is None: + raise ValueError( + f"Stage {first_stage} not found in registry, register it using session.register_stage()" + ) + + if not first_stage_instance.can_start_pipeline: + raise ValueError(f"Stage {first_stage} cannot start a pipeline. ") + + # Validate stage transitions + for i in range(1, len(stages)): + current_stage = stages[i] + previous_stage = stages[i - 1] + + # Get the stage instance to check its valid predecessors + stage_instance = self._stage_registry.get(current_stage) + if stage_instance is None: + raise ValueError( + f"Stage {current_stage} not found in registry, , register it using session.register_stage()" + ) + + valid_predecessors = stage_instance.valid_predecessor_stages + + # Check if the previous stage is valid for the current stage + if valid_predecessors and previous_stage not in valid_predecessors: + raise ValueError( + f"Invalid transition from {previous_stage} to {current_stage}. " + f"Valid predecessors for {current_stage}: {valid_predecessors}" ) - self._edge_program_manager = stage.get_artifacts() - self._delegation_info = stage.delegation_info - elif stage_name == "executorch": - # Run the executorch stage - stage.run(self._edge_program_manager, {}) - self._executorch_program_manager = stage.get_artifacts() + + def _run_pipeline(self) -> None: + # Validate if given stage sequence is valid + self._validate_pipeline_sequence( + stages=self._pipeline_stages, + ) + + current_artifact = PipelineArtifact(data=self._model, context=self._run_context) + + # Execute stages from registry in the order specified by pipeline_stages + for stage_type in self._pipeline_stages: + stage = self._stage_registry.get(stage_type) + if stage is None: + raise ValueError(f"Stage {stage_type} not found in registry") + + logging.info(f"Executing stage: {stage_type}") + + stage.run(current_artifact) + current_artifact = stage.get_artifacts() + + self._stage_to_artifacts[stage_type] = current_artifact def export(self) -> None: """ @@ -649,6 +316,9 @@ def export(self) -> None: # Run the pipeline from the beginning self._run_pipeline() + def get_stage_artifacts(self) -> Dict[StageType, PipelineArtifact]: + return self._stage_to_artifacts + def save_pte_file(self, path: str) -> None: """ Save the exported program to a PTE file. @@ -659,11 +329,7 @@ def save_pte_file(self, path: str) -> None: Raises: RuntimeError: If the executorch program manager is not initialized """ - if self._executorch_program_manager is None: - raise RuntimeError( - "Executorch program manager is not initialized. Run export() first." - ) - self._executorch_program_manager.save(path) + self.get_executorch_program_manager().save(path) def get_executorch_program(self) -> Program: """ @@ -675,11 +341,7 @@ def get_executorch_program(self) -> Program: Raises: RuntimeError: If the executorch program manager is not initialized """ - if self._executorch_program_manager is None: - raise RuntimeError( - "Executorch program manager is not initialized. Run export() first." - ) - return self._executorch_program_manager.executorch_program + return self.get_executorch_program_manager().executorch_program def get_executorch_program_manager(self) -> ExecutorchProgramManager: """ @@ -691,11 +353,12 @@ def get_executorch_program_manager(self) -> ExecutorchProgramManager: Raises: RuntimeError: If the executorch program manager is not initialized """ - if self._executorch_program_manager is None: + artifact = self._stage_to_artifacts.get(StageType.TO_EXECUTORCH) + if artifact is None or artifact.data is None: raise RuntimeError( - "Executorch program manager is not initialized. Run export() first." + "Executorch program manager is not initialized. Run Executorch Stage first." ) - return self._executorch_program_manager + return artifact.data def get_pte_buffer(self) -> bytes: """ @@ -707,11 +370,7 @@ def get_pte_buffer(self) -> bytes: Raises: RuntimeError: If the executorch program manager is not initialized """ - if self._executorch_program_manager is None: - raise RuntimeError( - "Executorch program manager is not initialized. Run export() first." - ) - return self._executorch_program_manager.buffer + return self.get_executorch_program_manager().buffer def save_to_pte(self, output_name: str) -> None: """ @@ -721,11 +380,7 @@ def save_to_pte(self, output_name: str) -> None: output_name (Optional[str]): The name of the .pte file. """ assert output_name, "Need a valid output name" - if self._executorch_program_manager is None: - raise RuntimeError( - "Executorch program manager is not initialized. Run export() first." - ) - save_pte_program(self._executorch_program_manager, output_name) + save_pte_program(self.get_executorch_program_manager(), output_name) def get_example_input( self, method_name: str = "forward" @@ -791,6 +446,10 @@ def print_delegation_info(self) -> None: """ Print delegation information for the exported program. """ - print(self._delegation_info.get_summary()) - df = self._delegation_info.get_operator_delegation_dataframe() - print(tabulate(df, headers="keys", tablefmt="fancy_grid")) + delegation_info = self._run_context.get("delegation_info", None) + if delegation_info: + logging.info(delegation_info.get_summary()) + df = delegation_info.get_operator_delegation_dataframe() + logging.info(tabulate(df, headers="keys", tablefmt="fancy_grid")) + else: + logging.info("No delegation info available") diff --git a/export/recipe.py b/export/recipe.py index d95c4e77696..8f7251cd419 100644 --- a/export/recipe.py +++ b/export/recipe.py @@ -16,6 +16,8 @@ from torchao.core.config import AOBaseConfig from torchao.quantization.pt2e.quantizer import Quantizer +from .types import StageType + """ Export recipe definitions for ExecuTorch. @@ -70,7 +72,8 @@ class QuantizationRecipe: This class holds the configuration parameters for quantizing a model. Attributes: - quantizer: Optional quantizer for model quantization + quantizers: Optional list of quantizers for model quantization + ao_base_config: Optional list of AO base configurations """ quantizers: Optional[List[Quantizer]] = None @@ -78,14 +81,34 @@ class QuantizationRecipe: def get_quantizers(self) -> Optional[List[Quantizer]]: """ - Get the quantizer associated with this recipe. + Get the quantizers associated with this recipe. Returns: - The quantizer if one is set, otherwise None + The quantizers if any are set, otherwise None """ return self.quantizers +@dataclass +class LoweringRecipe: + """ + Configuration recipe for lowering and partitioning. + + This class holds the configuration parameters for lowering a model + to backend-specific representations. + + Attributes: + partitioners: Optional list of partitioners for model partitioning + edge_transform_passes: Optional sequence of transformation passes to apply + edge_compile_config: Optional edge compilation configuration + """ + + partitioners: Optional[List[Partitioner]] = None + edge_transform_passes: Optional[Sequence[PassType]] = None + # pyre-ignore[11]: Type not defined + edge_compile_config: Optional[EdgeCompileConfig] = None + + @experimental( "This API and all of its related functionality such as ExportSession and ExportRecipe are experimental." ) @@ -100,27 +123,21 @@ class ExportRecipe: Attributes: name: Optional name for the recipe quantization_recipe: Optional quantization recipe for model quantization - edge_compile_config: Optional edge compilation configuration pre_edge_transform_passes: Optional function to apply transformation passes before edge lowering - edge_transform_passes: Optional sequence of transformation passes to apply - during edge lowering - transform_check_ir_validity: Whether to check IR validity during transformation - partitioners: Optional list of partitioners for model partitioning + lowering_recipe: Optional lowering recipe for model lowering and partitioning executorch_backend_config: Optional backend configuration for ExecuTorch + pipeline_stages: Optional list of stages to execute, defaults to a standard pipeline. mode: Export mode (debug or release) """ name: Optional[str] = None quantization_recipe: Optional[QuantizationRecipe] = None - # pyre-ignore[11]: Type not defined - edge_compile_config: Optional[EdgeCompileConfig] = None pre_edge_transform_passes: Optional[Sequence[PassType]] = None - edge_transform_passes: Optional[Sequence[PassType]] = None - transform_check_ir_validity: bool = True - partitioners: Optional[List[Partitioner]] = None + lowering_recipe: Optional[LoweringRecipe] = None # pyre-ignore[11]: Type not defined executorch_backend_config: Optional[ExecutorchBackendConfig] = None + pipeline_stages: Optional[List[StageType]] = None mode: Mode = Mode.RELEASE @classmethod diff --git a/export/stages.py b/export/stages.py new file mode 100644 index 00000000000..dd22155e929 --- /dev/null +++ b/export/stages.py @@ -0,0 +1,502 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, List, Optional, Sequence + +import torch +from executorch.devtools.backend_debug import get_delegation_info +from executorch.exir import EdgeCompileConfig +from executorch.exir.backend.backend_api import validation_disabled +from executorch.exir.program import to_edge, to_edge_transform_and_lower +from executorch.exir.program._program import _transform +from executorch.export.recipe import LoweringRecipe, QuantizationRecipe +from executorch.export.types import StageType +from torch import nn +from torch._export.pass_base import PassType +from torchao.quantization import quantize_ +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantizer import ComposableQuantizer +from torchao.utils import unwrap_tensor_subclass + + +class PipelineArtifact: + def __init__( + self, + data: Any, + context: Dict[str, Any], + ) -> None: + self.data = data + self.context = context + + def add_context(self, key: str, value: Any) -> None: + self.context[key] = value + + def get_context(self, key: str, default: Any = None) -> Any: + return self.context.get(key, default) + + def copy_with_new_data(self, new_data: Any) -> "PipelineArtifact": + return PipelineArtifact(data=new_data, context=self.context.copy()) + + +class Stage(ABC): + """ + Interface for a Stage in the ExecuTorch export pipeline. + + Each stage can be connected to other stages to form a pipeline. + Each stage implements its own run method with specific parameter names. + """ + + def __init__(self) -> None: + """ + Initialize the stage. + """ + self._artifact = None + + @property + @abstractmethod + def stage_type(self) -> "StageType": + """ + Returns the type of this stage. + """ + pass + + @property + @abstractmethod + def valid_predecessor_stages(self) -> List["StageType"]: + """ + Returns the list of stage types that can come before this stage. + """ + pass + + @property + @abstractmethod + def can_start_pipeline(self) -> bool: + """ + Returns whether this stage can be the first stage in a pipeline. + """ + pass + + @abstractmethod + def run(self, artifact: PipelineArtifact) -> None: + """ + Executes this stage with the given inputs. + + Each concrete stage class implements this method with specific parameter names. + """ + pass + + def get_artifacts(self) -> "PipelineArtifact": + if self._artifact is None: + raise RuntimeError(f"Stage: {self.__class__.__name__} not executed") + return self._artifact + + +class TorchExportStage(Stage): + """ + Purpose: Export PyTorch model to ExportedProgram. + """ + + def __init__( + self, + pre_edge_transform_passes: Optional[List[PassType]] = None, + ) -> None: + super().__init__() + self._pre_edge_transform_passes = pre_edge_transform_passes + + @property + def stage_type(self) -> str: + return StageType.TORCH_EXPORT + + @property + def valid_predecessor_stages(self) -> List["StageType"]: + return [StageType.SOURCE_TRANSFORM, StageType.QUANTIZE] + + @property + def can_start_pipeline(self) -> bool: + return True + + def run(self, artifact: PipelineArtifact) -> None: + models = artifact.data + example_inputs = artifact.get_context("example_inputs") + dynamic_shapes = artifact.get_context("dynamic_shapes", {}) + + exported_programs = {} + + with torch.no_grad(): + for method_name, model in models.items(): + if method_name not in example_inputs: + raise ValueError( + f"Example inputs for method {method_name} not found." + ) + + method_dynamic_shapes = dynamic_shapes.get(method_name) + + # Export the model + exported_programs[method_name] = torch.export.export( + model, + example_inputs[method_name][0], + dynamic_shapes=method_dynamic_shapes, + strict=True, + ) + + # Apply pre-edge transform passes if available + for pass_ in self._pre_edge_transform_passes or []: + exported_programs[method_name] = _transform( + exported_programs[method_name], pass_ + ) + + self._artifact = artifact.copy_with_new_data(exported_programs) + + +class EdgeTransformAndLowerStage(Stage): + """ + Second stage: Transform and lower to EdgeProgramManager. + """ + + def __init__( + self, + partitioners: Optional[List[Any]] = None, + transform_passes: Optional[Sequence[Callable[[Any], Optional[Any]]]] = None, + compile_config: Optional[Any] = None, + ) -> None: + self._partitioners = partitioners + self._transform_passes = transform_passes + self._compile_config = compile_config + + @classmethod + def from_recipe( + cls, lowering_recipe: Optional["LoweringRecipe"] + ) -> "EdgeTransformAndLowerStage": + if lowering_recipe is None: + return cls() + + return cls( + partitioners=lowering_recipe.partitioners, + transform_passes=lowering_recipe.edge_transform_passes, + compile_config=lowering_recipe.edge_compile_config, + ) + + @property + def stage_type(self) -> str: + return StageType.TO_EDGE_TRANSFORM_AND_LOWER + + @property + def valid_predecessor_stages(self) -> List["StageType"]: + return [StageType.TORCH_EXPORT] + + @property + def can_start_pipeline(self) -> bool: + return False + + def run(self, artifact: PipelineArtifact) -> None: + """ + Transform and lower to EdgeProgramManager. + """ + exported_programs = artifact.data + constant_methods = artifact.get_context("constant_methods") + + with validation_disabled(): + edge_program_manager = to_edge_transform_and_lower( + exported_programs, + partitioner=self._partitioners, + transform_passes=self._transform_passes, + constant_methods=constant_methods, + compile_config=self._compile_config, + ) + + delegation_info = get_delegation_info( + edge_program_manager.exported_program().graph_module + ) + self._artifact = artifact.copy_with_new_data(edge_program_manager) + self._artifact.add_context("delegation_info", delegation_info) + + @property + def delegation_info(self) -> Any: + """ + Returns the delegation info. + """ + return self._artifact.get_context("delegation_info") + + +class ExecutorchStage(Stage): + """ + Convert to ExecutorchProgramManager. + """ + + def __init__(self, backend_config: Any) -> None: + self._backend_config = backend_config + + @property + def stage_type(self) -> str: + return StageType.TO_EXECUTORCH + + @property + def valid_predecessor_stages(self) -> List["StageType"]: + return [StageType.TO_EDGE_TRANSFORM_AND_LOWER, StageType.TO_BACKEND] + + @property + def can_start_pipeline(self) -> bool: + return False + + def run(self, artifact: PipelineArtifact) -> None: + """ + Convert to ExecutorchProgramManager. + """ + edge_program_manager = artifact.data + + # Process inputs + if edge_program_manager is None: + raise RuntimeError("Edge program manager is not set.") + + # Convert to ExecutorchProgramManager + executorch_program_manager = edge_program_manager.to_executorch( + self._backend_config + ) + self._artifact = artifact.copy_with_new_data(executorch_program_manager) + + +class SourceTransformStage(Stage): + """ + Optional stage: Source transform stage: Apply source transformations to the model. + """ + + def __init__(self, quantization_recipe: Optional[QuantizationRecipe]) -> None: + self._quantization_recipe = quantization_recipe + self._transformed_models: Dict[str, nn.Module] = {} + + @property + def stage_type(self) -> str: + return StageType.SOURCE_TRANSFORM + + @property + def valid_predecessor_stages(self) -> List["StageType"]: + return [] + + @property + def can_start_pipeline(self) -> bool: + return True + + def run(self, artifact: PipelineArtifact) -> None: + """ + Apply source transformations to the model. + """ + if ( + not self._quantization_recipe + or not self._quantization_recipe.ao_base_config + ): + logging.info( + "Quantization recipe is invalid to run SourceTransform, returning original artifact" + ) + self._artifact = artifact + return + + assert isinstance(artifact.data, dict) + + # Store the original models + self._transformed_models = artifact.data + + # Apply torchao quantize_ to each model + for method_name, model in artifact.data.items(): + # pyre-ignore + for config in self._quantization_recipe.ao_base_config: + quantize_(model, config) + unwrap_tensor_subclass(model) + self._transformed_models[method_name] = model + + self._artifact = artifact.copy_with_new_data(self._transformed_models) + + +class QuantizeStage(Stage): + """ + Optional stage: Perform post-training quantization on the model. + """ + + def __init__(self, quantization_recipe: Optional[QuantizationRecipe]) -> None: + self._quantization_recipe = quantization_recipe + + @property + def stage_type(self) -> str: + return StageType.QUANTIZE + + @property + def valid_predecessor_stages(self) -> List["StageType"]: + return [StageType.SOURCE_TRANSFORM] + + @property + def can_start_pipeline(self) -> bool: + return True + + def run(self, artifact: PipelineArtifact) -> None: + if not self._quantization_recipe or not self._quantization_recipe.quantizers: + logging.info( + "Quantization recipe is invalid to run QunatizeStage, returning original model" + ) + self._artifact = artifact + return + + assert isinstance(artifact.data, dict) + + models = artifact.data + example_inputs = artifact.get_context("example_inputs") + + quantized_models = {} + + for method_name, model in models.items(): + if method_name not in example_inputs or not example_inputs[method_name]: + raise ValueError( + f"Example inputs for method {method_name} not found or empty." + ) + + inputs = example_inputs[method_name][0] + captured_graph = torch.export.export(model, inputs, strict=True).module() + + composed_quantizer = ComposableQuantizer( + # pyre-ignore + self._quantization_recipe.quantizers + ) + prepared_model = prepare_pt2e(captured_graph, composed_quantizer) + + for calibration_input in example_inputs[method_name]: + prepared_model(*calibration_input) + + quantized_model = convert_pt2e(prepared_model) + quantized_models[method_name] = quantized_model + + self._artifact = artifact.copy_with_new_data(quantized_models) + + +class ToEdgeStage(Stage): + """ + Stage: Convert ExportedProgram to EdgeProgramManager. + """ + + def __init__( + self, + edge_compile_config: Optional[EdgeCompileConfig] = None, # pyre-ignore + ) -> None: + super().__init__() + self._edge_compile_config = edge_compile_config + + @classmethod + def from_recipe(cls, lowering_recipe: Optional["LoweringRecipe"]) -> "ToEdgeStage": + if lowering_recipe is None: + return cls() + + return cls( + edge_compile_config=lowering_recipe.edge_compile_config, + ) + + @property + def stage_type(self) -> str: + return StageType.TO_EDGE + + @property + def valid_predecessor_stages(self) -> List["StageType"]: + return [StageType.TORCH_EXPORT] + + @property + def can_start_pipeline(self) -> bool: + return False + + def run(self, artifact: PipelineArtifact) -> None: + """ + Convert ExportedProgram to EdgeProgramManager. + + Args: + artifact: Contains exported programs and context + """ + exported_programs = artifact.data + constant_methods = artifact.get_context("constant_methods") + + # Convert to edge program manager + edge_program_manager = to_edge( + exported_programs, + constant_methods=constant_methods, + compile_config=self._edge_compile_config, + ) + + self._artifact = artifact.copy_with_new_data(edge_program_manager) + + +class ToBackendStage(Stage): + """ + Stage: Apply transformations and partitioning to EdgeProgramManager. + """ + + def __init__( + self, + partitioners: Optional[List[Any]] = None, + transform_passes: Optional[Sequence[Callable[[Any], Optional[Any]]]] = None, + ) -> None: + super().__init__() + self._partitioners = partitioners + self._transform_passes = transform_passes + + @classmethod + def from_recipe( + cls, lowering_recipe: Optional["LoweringRecipe"] + ) -> "ToBackendStage": + if lowering_recipe is None: + return cls() + + return cls( + partitioners=lowering_recipe.partitioners, + transform_passes=lowering_recipe.edge_transform_passes, + ) + + @property + def stage_type(self) -> str: + return StageType.TO_BACKEND + + @property + def valid_predecessor_stages(self) -> List["StageType"]: + return [StageType.TO_EDGE] + + @property + def can_start_pipeline(self) -> bool: + return False + + def run(self, artifact: PipelineArtifact) -> None: + """ + Apply transformations and partitioning to EdgeProgramManager. + + Args: + artifact: Contains edge program manager and context + """ + edge_program_manager = artifact.data + + if edge_program_manager is None: + raise RuntimeError("Edge program manager is not set.") + + # Apply transform passes if available + if self._transform_passes: + edge_program_manager = edge_program_manager.transform( + self._transform_passes + ) + + # Apply partitioners if available + if self._partitioners is not None and len(self._partitioners) > 0: + with validation_disabled(): + # pyre-ignore + for partitioner in self._partitioners: + edge_program_manager = edge_program_manager.to_backend(partitioner) + + # Get delegation info + delegation_info = get_delegation_info( + edge_program_manager.exported_program().graph_module + ) + + self._artifact = artifact.copy_with_new_data(edge_program_manager) + self._artifact.add_context("delegation_info", delegation_info) + + @property + def delegation_info(self) -> Any: + """ + Returns the delegation info. + """ + return self._artifact.get_context("delegation_info") diff --git a/export/tests/TARGETS b/export/tests/TARGETS index 50751c552e5..068c3436b6a 100644 --- a/export/tests/TARGETS +++ b/export/tests/TARGETS @@ -21,6 +21,7 @@ runtime.python_test( "test_recipe_provider.py", "test_recipe_registry.py", "test_export_recipe.py", + "test_export_session.py", "test_export_stages.py", ], deps = [ diff --git a/export/tests/test_export_session.py b/export/tests/test_export_session.py new file mode 100644 index 00000000000..92aeebb7304 --- /dev/null +++ b/export/tests/test_export_session.py @@ -0,0 +1,482 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest +from typing import List +from unittest.mock import Mock + +import torch +from executorch.export import ExportRecipe, ExportSession +from executorch.export.recipe import LoweringRecipe, QuantizationRecipe +from executorch.export.stages import PipelineArtifact +from executorch.export.types import StageType + + +class SimpleTestModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(10, 5) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + +class TestExportSessionCoreFlow(unittest.TestCase): + """Test core export flow and pipeline execution.""" + + def setUp(self) -> None: + self.model = SimpleTestModel() + self.example_inputs = [(torch.randn(2, 10),)] + self.recipe = ExportRecipe(name="test") + + def _create_mock_stage(self, stage_type: StageType) -> Mock: + mock_stage = Mock() + mock_artifact = Mock(spec=PipelineArtifact) + mock_artifact.data = Mock() + mock_artifact.context = {} + mock_stage.get_artifacts.return_value = mock_artifact + mock_stage.stage_type = stage_type + + # Add the new properties required by the Stage interface + if stage_type == StageType.SOURCE_TRANSFORM: + mock_stage.valid_predecessor_stages = [] + mock_stage.can_start_pipeline = True + elif stage_type == StageType.QUANTIZE: + mock_stage.valid_predecessor_stages = [StageType.SOURCE_TRANSFORM] + mock_stage.can_start_pipeline = True + elif stage_type == StageType.TORCH_EXPORT: + mock_stage.valid_predecessor_stages = [ + StageType.SOURCE_TRANSFORM, + StageType.QUANTIZE, + ] + mock_stage.can_start_pipeline = True + elif stage_type == StageType.TO_EDGE_TRANSFORM_AND_LOWER: + mock_stage.valid_predecessor_stages = [StageType.TORCH_EXPORT] + mock_stage.can_start_pipeline = False + elif stage_type == StageType.TO_EXECUTORCH: + mock_stage.valid_predecessor_stages = [ + StageType.TO_EDGE_TRANSFORM_AND_LOWER + ] + mock_stage.can_start_pipeline = True + else: + mock_stage.valid_predecessor_stages = [] + mock_stage.can_start_pipeline = True + + return mock_stage + + def test_default_pipeline_execution_order(self) -> None: + # Test that pipeline stages are executed in the correct order + stage_types = [ + StageType.SOURCE_TRANSFORM, + StageType.QUANTIZE, + StageType.TORCH_EXPORT, + StageType.TO_EDGE_TRANSFORM_AND_LOWER, + StageType.TO_EXECUTORCH, + ] + mock_stages = [ + self._create_mock_stage(stage_type) for stage_type in stage_types + ] + + session = ExportSession( + model=self.model, + example_inputs=self.example_inputs, + export_recipe=self.recipe, + ) + + # Replace the stages in the registry with our mocked stages + for stage_type, mock_stage in zip(stage_types, mock_stages): + session.register_stage(stage_type, mock_stage) + + session.export() + + # Verify all stages were called + for stage in mock_stages: + stage.run.assert_called_once() + + # Verify artifacts were stored for each stage + self.assertEqual(len(session._stage_to_artifacts), 5) + self.assertEqual(set(session._stage_to_artifacts.keys()), set(stage_types)) + + def test_overriden_pipeline_execution_order(self) -> None: + # Test when pipeline stages that are passed through recipe + stage_types = [ + StageType.SOURCE_TRANSFORM, + StageType.TORCH_EXPORT, + StageType.TO_EDGE_TRANSFORM_AND_LOWER, + StageType.TO_EXECUTORCH, + ] + mock_stages = [ + self._create_mock_stage(stage_type) for stage_type in stage_types + ] + + self.recipe.pipeline_stages = stage_types + session = ExportSession( + model=self.model, + example_inputs=self.example_inputs, + export_recipe=self.recipe, + ) + + # Replace the stages in the registry with our mocked stages + for stage_type, mock_stage in zip(stage_types, mock_stages): + session.register_stage(stage_type, mock_stage) + session.export() + + # Verify all stages were called + for stage in mock_stages: + stage.run.assert_called_once() + + # Verify artifacts were stored for each stage + self.assertEqual(len(session._stage_to_artifacts), 4) + self.assertEqual(set(session._stage_to_artifacts.keys()), set(stage_types)) + + def test_model_standardization_single_to_dict(self) -> None: + session = ExportSession( + model=self.model, + example_inputs=self.example_inputs, + export_recipe=self.recipe, + ) + + self.assertIsInstance(session._model, dict) + self.assertIn("forward", session._model) + self.assertEqual(session._model["forward"], self.model) + + self.assertIsInstance(session._example_inputs, dict) + self.assertIn("forward", session._example_inputs) + self.assertEqual(session._example_inputs["forward"], self.example_inputs) + + def test_model_standardization_preserves_dict(self) -> None: + # Test that dictionary models are preserved as-is. + model_dict = {"method1": self.model, "method2": SimpleTestModel()} + inputs_dict = { + "method1": self.example_inputs, + "method2": [(torch.randn(1, 10),)], + } + + session = ExportSession( + model=model_dict, # pyre-ignore[6] + example_inputs=inputs_dict, + export_recipe=self.recipe, + ) + + self.assertEqual(session._model, model_dict) + self.assertEqual(session._example_inputs, inputs_dict) + + def test_context_propagation_through_pipeline(self) -> None: + # Test that context is properly propagated through the pipeline + session = ExportSession( + model=self.model, + example_inputs=self.example_inputs, + export_recipe=self.recipe, + name="test_session", + constant_methods={"const_method": lambda: torch.tensor([1, 2, 3])}, + ) + + # Check that initial context is set up correctly + expected_context_keys = { + "example_inputs", + "dynamic_shapes", + "constant_methods", + "export_recipe", + "session_name", + "artifact_dir", + } + self.assertEqual(set(session._run_context.keys()), expected_context_keys) + self.assertEqual(session._run_context["session_name"], "test_session") + self.assertIsNotNone(session._run_context["constant_methods"]) + + def test_stage_registry_unknown_stage_type(self) -> None: + # Test error handling for unknown stage types in pipeline + unknown_stage_type = Mock() + unknown_stage_type.name = "UNKNOWN_STAGE" + recipe = ExportRecipe(name="test", pipeline_stages=[unknown_stage_type]) + + with self.assertRaises(ValueError) as cm: + ExportSession( + model=self.model, + example_inputs=self.example_inputs, + export_recipe=recipe, + )._run_pipeline() + self.assertIn("not found in registry", str(cm.exception)) + + def test_multi_method_model_export(self) -> None: + # Test export with multi-method models + model_dict = { + "forward": self.model, + "inference": SimpleTestModel(), + } + inputs_dict = { + "forward": self.example_inputs, + "inference": [(torch.randn(1, 10),)], + } + + session = ExportSession( + model=model_dict, # pyre-ignore[6] + example_inputs=inputs_dict, + export_recipe=ExportRecipe(name="multi_method_test"), + ) + + # Verify proper initialization + self.assertEqual(session._model, model_dict) + self.assertEqual(session._example_inputs, inputs_dict) + + # Test getting example inputs for different methods + forward_input = session.get_example_input("forward") + inference_input = session.get_example_input("inference") + + self.assertEqual(forward_input, self.example_inputs[0]) + self.assertEqual(inference_input, inputs_dict["inference"][0]) + + +class TestPipelineValidation(unittest.TestCase): + def setUp(self) -> None: + self.model = SimpleTestModel() + self.example_inputs = [(torch.randn(2, 10),)] + self.recipe = ExportRecipe(name="test") + + # pyre-ignore + def _get_export_session(self, stages: List[StageType]): + self.recipe.pipeline_stages = stages + return ExportSession( + model=self.model, + example_inputs=self.example_inputs, + export_recipe=self.recipe, + ) + + def test_valid_pipeline_sequences(self) -> None: + """Test various valid pipeline sequences.""" + valid_sequences = [ + # Full pipeline with to_edge_transform_lower + [ + StageType.SOURCE_TRANSFORM, + StageType.QUANTIZE, + StageType.TORCH_EXPORT, + StageType.TO_EDGE_TRANSFORM_AND_LOWER, + StageType.TO_EXECUTORCH, + ], + # Full pipeline with to_edge, to_backend + [ + StageType.SOURCE_TRANSFORM, + StageType.QUANTIZE, + StageType.TORCH_EXPORT, + StageType.TO_EDGE, + StageType.TO_BACKEND, + StageType.TO_EXECUTORCH, + ], + # Skip quantize + [ + StageType.SOURCE_TRANSFORM, + StageType.TORCH_EXPORT, + StageType.TO_EDGE_TRANSFORM_AND_LOWER, + StageType.TO_EXECUTORCH, + ], + # Skip source transform and tart with quantize + [ + StageType.QUANTIZE, + StageType.TORCH_EXPORT, + StageType.TO_EDGE_TRANSFORM_AND_LOWER, + StageType.TO_EXECUTORCH, + ], + # Start with torch export + [ + StageType.TORCH_EXPORT, + StageType.TO_EDGE_TRANSFORM_AND_LOWER, + StageType.TO_EXECUTORCH, + ], + ] + + for i, stages in enumerate(valid_sequences): + with self.subTest(sequence=i, stages=[s.name for s in stages]): + session = self._get_export_session(stages) + # Should not raise any exception + try: + session._validate_pipeline_sequence(stages) + except Exception as e: + self.fail(f"Valid sequence {[s.name for s in stages]} raised {e}") + + def test_invalid_pipeline_start_stages(self) -> None: + """Test stages that cannot start a pipeline.""" + invalid_stage_sequence = [ + # Edge stage cannot start pipeline + [StageType.TO_EDGE_TRANSFORM_AND_LOWER], + [StageType.TO_EDGE_TRANSFORM_AND_LOWER, StageType.TO_EXECUTORCH], + ] + + for i, stages in enumerate(invalid_stage_sequence): + with self.subTest(sequence=i, stages=[s.name for s in stages]): + session = self._get_export_session(stages) + with self.assertRaises(ValueError) as cm: + session._validate_pipeline_sequence(stages) + self.assertIn("cannot start a pipeline", str(cm.exception)) + + def test_pipeline_transitions(self) -> None: + """Test both valid and invalid pipeline transitions""" + test_cases = [ + # Valid cases + ([StageType.SOURCE_TRANSFORM, StageType.QUANTIZE], True), + ([StageType.QUANTIZE, StageType.TORCH_EXPORT], True), + ([StageType.SOURCE_TRANSFORM, StageType.TORCH_EXPORT], True), + ([StageType.TORCH_EXPORT, StageType.TO_EDGE_TRANSFORM_AND_LOWER], True), + # Invalid cases - transitions + ([StageType.QUANTIZE, StageType.TO_EDGE_TRANSFORM_AND_LOWER], False), + ( + [StageType.SOURCE_TRANSFORM, StageType.TO_EDGE_TRANSFORM_AND_LOWER], + False, + ), + ( + [ + StageType.TORCH_EXPORT, + StageType.TO_EDGE_TRANSFORM_AND_LOWER, + StageType.QUANTIZE, + ], + False, + ), + ([StageType.TO_EXECUTORCH, StageType.TORCH_EXPORT], False), + ] + + for i, (stages, should_pass) in enumerate(test_cases): + with self.subTest( + sequence=i, stages=[s.name for s in stages], should_pass=should_pass + ): + session = self._get_export_session(stages) + if should_pass: + try: + session._validate_pipeline_sequence(stages) + except Exception as e: + self.fail( + f"Expected valid sequence {[s.name for s in stages]} but got {e}" + ) + else: + with self.assertRaises(ValueError): + session._validate_pipeline_sequence(stages) + + def test_empty_pipeline_sequence(self) -> None: + """Test empty pipeline sequence.""" + session = self._get_export_session([]) + with self.assertRaises(ValueError) as cm: + session._validate_pipeline_sequence([]) + self.assertIn("Pipeline stages cannot be empty", str(cm.exception)) + + +class TestExportSessionErrorHandling(unittest.TestCase): + """Test error handling in export session.""" + + def setUp(self) -> None: + self.model = SimpleTestModel() + self.example_inputs = [(torch.randn(2, 10),)] + self.recipe = ExportRecipe(name="test") + + def test_access_results_before_export(self) -> None: + """Test that accessing results before export raises appropriate errors.""" + session = ExportSession( + model=self.model, + example_inputs=self.example_inputs, + export_recipe=self.recipe, + ) + + with self.assertRaises(RuntimeError) as cm: + session.get_executorch_program_manager() + self.assertIn( + "Executorch program manager is not initialized", str(cm.exception) + ) + + with self.assertRaises(RuntimeError) as cm: + session.get_executorch_program() + self.assertIn( + "Executorch program manager is not initialized", str(cm.exception) + ) + + with self.assertRaises(RuntimeError) as cm: + session.get_pte_buffer() + self.assertIn( + "Executorch program manager is not initialized", str(cm.exception) + ) + + def test_invalid_method_name_in_example_inputs(self) -> None: + """Test error handling for invalid method names.""" + session = ExportSession( + model=self.model, + example_inputs=self.example_inputs, + export_recipe=self.recipe, + ) + + with self.assertRaises(KeyError) as cm: + session.get_example_input("nonexistent_method") + self.assertIn("Method name 'nonexistent_method' not found", str(cm.exception)) + + def test_empty_example_inputs_list(self) -> None: + """Test error handling for empty example inputs.""" + session = ExportSession( + model={"forward": self.model}, + example_inputs={"forward": []}, + export_recipe=self.recipe, + ) + + with self.assertRaises(ValueError) as cm: + session.get_example_input("forward") + self.assertIn( + "Example inputs list for method forward is empty", str(cm.exception) + ) + + def test_save_to_pte_invalid_name(self) -> None: + """Test save_to_pte with invalid output name.""" + session = ExportSession( + model=self.model, + example_inputs=self.example_inputs, + export_recipe=self.recipe, + ) + + with self.assertRaises(AssertionError): + session.save_to_pte("") + + with self.assertRaises(AssertionError): + session.save_to_pte(None) # pyre-ignore + + +class TestExportSessionPipelineBuilding(unittest.TestCase): + """Test pipeline building and stage configuration.""" + + def setUp(self) -> None: + self.model = SimpleTestModel() + self.example_inputs = [(torch.randn(2, 10),)] + + def test_pipeline_building_with_all_recipes(self) -> None: + """Test pipeline building with quantization and lowering recipes.""" + # Create comprehensive recipes + quant_recipe = QuantizationRecipe( + ao_base_config=[Mock()], + quantizers=[Mock()], + ) + lowering_recipe = LoweringRecipe( + partitioners=[Mock()], + edge_transform_passes=[Mock()], + edge_compile_config=Mock(), + ) + recipe = ExportRecipe( + name="comprehensive_test", + quantization_recipe=quant_recipe, + lowering_recipe=lowering_recipe, + executorch_backend_config=Mock(), + ) + + session = ExportSession( + model=self.model, + example_inputs=self.example_inputs, + export_recipe=recipe, + ) + + registered_stages = session.get_all_registered_stages() + + self.assertEqual(len(registered_stages), 5) + expected_types = [ + StageType.SOURCE_TRANSFORM, + StageType.QUANTIZE, + StageType.TORCH_EXPORT, + StageType.TO_EDGE_TRANSFORM_AND_LOWER, + StageType.TO_EXECUTORCH, + ] + self.assertListEqual(list(registered_stages.keys()), expected_types) diff --git a/export/tests/test_export_stages.py b/export/tests/test_export_stages.py index 7e6fddbf231..2b3e533723a 100644 --- a/export/tests/test_export_stages.py +++ b/export/tests/test_export_stages.py @@ -11,18 +11,19 @@ import torch from executorch.exir.program import EdgeProgramManager, ExecutorchProgramManager -from executorch.export import ExportRecipe, QuantizationRecipe -from executorch.export.export import ( +from executorch.export import QuantizationRecipe +from executorch.export.stages import ( EdgeTransformAndLowerStage, ExecutorchStage, - ExportSession, - ExportStage, + PipelineArtifact, QuantizeStage, SourceTransformStage, + StageType, + ToBackendStage, + ToEdgeStage, + TorchExportStage, ) from torch.export import ExportedProgram -from torchao.quantization.granularity import PerAxis -from torchao.quantization.quant_api import Int8DynamicActivationIntxWeightConfig class SimpleTestModel(torch.nn.Module): @@ -34,12 +35,28 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.linear(x) -class TestExportStage(unittest.TestCase): +class TestPipelineArtifact(unittest.TestCase): + + def test_copy_with_new_data(self) -> None: + original_data = {"original": "data"} + context = {"key": "value"} + artifact = PipelineArtifact(data=original_data, context=context) + + new_data = {"new": "data"} + new_artifact = artifact.copy_with_new_data(new_data) + + self.assertEqual(new_artifact.data, new_data) + self.assertEqual(new_artifact.context, context) + # Ensure original is unchanged + self.assertEqual(artifact.data, original_data) + + +class TestTorchExportStage(unittest.TestCase): def setUp(self) -> None: self.model = SimpleTestModel() self.example_inputs = [(torch.randn(2, 10),)] self.models_dict = {"forward": self.model} - self.export_config = { + self.context = { "example_inputs": {"forward": self.example_inputs}, "dynamic_shapes": {}, } @@ -49,8 +66,10 @@ def test_export_stage_run_success(self, mock_torch_export: Mock) -> None: mock_exported_program = Mock(spec=ExportedProgram) mock_torch_export.return_value = mock_exported_program - stage = ExportStage() - stage.run({"model": self.models_dict}, self.export_config) + stage = TorchExportStage() + artifact = PipelineArtifact(data=self.models_dict, context=self.context) + + stage.run(artifact) mock_torch_export.assert_called_once_with( self.model, @@ -60,43 +79,50 @@ def test_export_stage_run_success(self, mock_torch_export: Mock) -> None: ) # Verify artifacts - artifacts = stage.get_artifacts() - self.assertIn("forward", artifacts) - self.assertEqual(artifacts["forward"], mock_exported_program) + artifact = stage.get_artifacts() + self.assertIn("forward", artifact.data) + self.assertEqual(artifact.data["forward"], mock_exported_program) def test_export_stage_missing_example_inputs(self) -> None: - stage = ExportStage() - with self.assertRaises(ValueError) as context: - stage.run({"model": self.models_dict}, {"example_inputs": {}}) - self.assertIn( - "Example inputs for method forward not found", str(context.exception) - ) + stage = TorchExportStage() + context = {"example_inputs": {}} + artifact = PipelineArtifact(data=self.models_dict, context=context) + + with self.assertRaises(ValueError) as cm: + stage.run(artifact) + self.assertIn("Example inputs for method forward not found", str(cm.exception)) + + def test_get_artifacts_before_run(self) -> None: + """Test error when getting artifacts before running stage.""" + stage = TorchExportStage() + with self.assertRaises(RuntimeError) as cm: + stage.get_artifacts() + self.assertIn("Stage: TorchExportStage not executed", str(cm.exception)) class TestEdgeTransformAndLowerStage(unittest.TestCase): def setUp(self) -> None: self.mock_exported_program = Mock(spec=ExportedProgram) self.exported_programs = {"forward": self.mock_exported_program} + self.context = {"constant_methods": None} + + def test_run_with_partitioners_and_config(self) -> None: + """Test execution with partitioners and compile config""" + mock_partitioners = [Mock()] + mock_transform_passes = [Mock()] + mock_compile_config = Mock() + + stage = EdgeTransformAndLowerStage( + partitioners=mock_partitioners, + transform_passes=mock_transform_passes, + compile_config=mock_compile_config, + ) - def test_edge_transform_stage_with_partitioners(self) -> None: - """Test that EdgeTransformAndLowerStage can be initialized with partitioners.""" - mock_partitioner = Mock() - stage = EdgeTransformAndLowerStage(partitioners=[mock_partitioner]) - self.assertEqual(stage.name, "edge_transform_and_lower") - self.assertEqual(stage._partitioners, [mock_partitioner]) - - def test_edge_transform_stage_with_config(self) -> None: - """Test that EdgeTransformAndLowerStage can be initialized with compile config.""" - mock_config = Mock() - stage = EdgeTransformAndLowerStage(compile_config=mock_config) - self.assertEqual(stage.name, "edge_transform_and_lower") - self.assertEqual(stage._compile_config, mock_config) - - def test_edge_transform_stage_get_artifacts_not_initialized(self) -> None: - stage = EdgeTransformAndLowerStage() - with self.assertRaises(RuntimeError) as context: - stage.get_artifacts() - self.assertIn("Edge program manager is not initialized", str(context.exception)) + # Test that the stage has the right configuration + self.assertEqual(stage.stage_type, StageType.TO_EDGE_TRANSFORM_AND_LOWER) + self.assertEqual(stage._partitioners, mock_partitioners) + self.assertEqual(stage._transform_passes, mock_transform_passes) + self.assertEqual(stage._compile_config, mock_compile_config) class TestExecutorchStage(unittest.TestCase): @@ -109,7 +135,8 @@ def test_executorch_stage_run_success(self) -> None: self.mock_edge_manager.to_executorch.return_value = mock_executorch_manager stage = ExecutorchStage(self.mock_backend_config) - stage.run(self.mock_edge_manager, {}) + artifact = PipelineArtifact(data=self.mock_edge_manager, context={}) + stage.run(artifact) # Verify to_executorch was called self.mock_edge_manager.to_executorch.assert_called_once_with( @@ -118,15 +145,15 @@ def test_executorch_stage_run_success(self) -> None: # Verify artifacts artifacts = stage.get_artifacts() - self.assertEqual(artifacts, mock_executorch_manager) + self.assertEqual(artifacts.data, mock_executorch_manager) def test_executorch_stage_get_artifacts_not_initialized(self) -> None: stage = ExecutorchStage(self.mock_backend_config) - with self.assertRaises(RuntimeError) as context: - stage.get_artifacts() - self.assertIn( - "Executorch program manager is not initialized", str(context.exception) - ) + artifact = PipelineArtifact(data=None, context={}) + + with self.assertRaises(RuntimeError) as cm: + stage.run(artifact) + self.assertIn("Edge program manager is not set", str(cm.exception)) class TestSourceTransformStage(unittest.TestCase): @@ -135,370 +162,227 @@ def setUp(self) -> None: self.models_dict = {"forward": self.model} def test_source_transform_stage_no_quantization(self) -> None: - stage = SourceTransformStage(None) - stage.run(self.models_dict) + mock_recipe = Mock(spec=QuantizationRecipe) + mock_recipe.ao_base_config = None + stage = SourceTransformStage(mock_recipe) + artifact = PipelineArtifact(data=self.models_dict, context={}) - artifacts = stage.get_artifacts() - self.assertEqual(artifacts, self.models_dict) + stage.run(artifact) + result_artifact = stage.get_artifacts() + self.assertEqual(result_artifact.data, self.models_dict) -class TestQuantizeStage(unittest.TestCase): - def setUp(self) -> None: - self.model = SimpleTestModel() - self.models_dict = {"forward": self.model} - self.example_inputs = [(torch.randn(2, 10),)] - self.calibration_config = {"example_inputs": {"forward": self.example_inputs}} + @patch("executorch.export.stages.quantize_") + @patch("executorch.export.stages.unwrap_tensor_subclass") + def test_run_with_ao_base_config( + self, mock_unwrap: Mock, mock_quantize: Mock + ) -> None: + mock_config = Mock() + mock_recipe = Mock(spec=QuantizationRecipe) + mock_recipe.ao_base_config = [mock_config] - def test_quantize_stage_missing_example_inputs(self) -> None: - mock_quantizers = [Mock()] - stage = QuantizeStage(mock_quantizers) + stage = SourceTransformStage(mock_recipe) - with self.assertRaises(ValueError) as context: - stage.run(self.models_dict, {"example_inputs": {}}) - self.assertIn( - "Example inputs for method forward not found or empty", - str(context.exception), - ) + models_dict = {"forward": self.model} + artifact = PipelineArtifact(data=models_dict, context={}) + stage.run(artifact) + + # Verify quantize_ was called with the model and config + mock_quantize.assert_called_once_with(self.model, mock_config) + + # Verify unwrap_tensor_subclass was called with the model + mock_unwrap.assert_called_once_with(self.model) -class TestExportSession(unittest.TestCase): +class TestQuantizeStage(unittest.TestCase): def setUp(self) -> None: self.model = SimpleTestModel() + self.models_dict = {"forward": self.model} self.example_inputs = [(torch.randn(2, 10),)] - - def test_export_session_fp32_pipeline(self) -> None: - """Test that FP32 export creates the expected pipeline stages.""" - recipe = ExportRecipe(name="test_fp32") - session = ExportSession( - model=self.model, - example_inputs=self.example_inputs, - export_recipe=recipe, - ) - - # Verify pipeline stages for FP32 - expected_stages = ["export", "edge_transform_and_lower", "executorch"] - actual_stages = [stage.name for stage in session._pipeline] - self.assertEqual(actual_stages, expected_stages) - - def test_export_session_quantized_pipeline_with_quantizers(self) -> None: - """Test that quantized export with quantizers creates the expected pipeline stages.""" + self.context = {"example_inputs": {"forward": self.example_inputs}} + + def test_run_no_quantizers(self) -> None: + """Test execution with no quantizers.""" + mock_recipe = Mock(spec=QuantizationRecipe) + mock_recipe.quantizers = None + stage = QuantizeStage(mock_recipe) + artifact = PipelineArtifact(data=self.models_dict, context=self.context) + stage.run(artifact) + + result_artifact = stage.get_artifacts() + self.assertEqual(result_artifact, artifact) + + @patch("executorch.export.stages.convert_pt2e") + @patch("executorch.export.stages.prepare_pt2e") + @patch("executorch.export.stages.ComposableQuantizer") + @patch("torch.export.export") + def test_run_with_quantizers( + self, + mock_torch_export: Mock, + mock_composable_quantizer: Mock, + mock_prepare_pt2e: Mock, + mock_convert_pt2e: Mock, + ) -> None: + """Test execution with quantizers""" mock_quantizer = Mock() - quant_recipe = QuantizationRecipe(quantizers=[mock_quantizer]) - recipe = ExportRecipe(name="test_quantized", quantization_recipe=quant_recipe) + mock_recipe = Mock(spec=QuantizationRecipe) + mock_recipe.quantizers = [mock_quantizer] + stage = QuantizeStage(mock_recipe) - session = ExportSession( - model=self.model, - example_inputs=self.example_inputs, - export_recipe=recipe, - ) + # Mock the torch.export.export chain + mock_exported_program = Mock(spec=ExportedProgram) + mock_captured_graph = Mock() + mock_exported_program.module.return_value = mock_captured_graph + mock_torch_export.return_value = mock_exported_program - # Verify pipeline stages for quantized export with quantizers - # The quantize stage is followed by a re-export stage - expected_stages = [ - "quantize", - "export", - "edge_transform_and_lower", - "executorch", - ] - actual_stages = [stage.name for stage in session._pipeline] - self.assertEqual(actual_stages, expected_stages) - - def test_export_session_source_transform_pipeline(self) -> None: - """Test that source transform creates the expected pipeline stages.""" - config = Int8DynamicActivationIntxWeightConfig( - weight_dtype=torch.int4, - weight_granularity=PerAxis(axis=0), - ) - quant_recipe = QuantizationRecipe(ao_base_config=[config]) - recipe = ExportRecipe( - name="test_source_transform", quantization_recipe=quant_recipe - ) + # Mock the quantization chain + mock_composed_quantizer = Mock() + mock_composable_quantizer.return_value = mock_composed_quantizer + mock_prepared_model = Mock() + mock_prepare_pt2e.return_value = mock_prepared_model + mock_quantized_model = Mock() + mock_convert_pt2e.return_value = mock_quantized_model - session = ExportSession( - model=self.model, - example_inputs=self.example_inputs, - export_recipe=recipe, - ) + artifact = PipelineArtifact(data=self.models_dict, context=self.context) + stage.run(artifact) - # Verify pipeline stages for source transform - expected_stages = [ - "source_transform", - "export", - "edge_transform_and_lower", - "executorch", - ] - actual_stages = [stage.name for stage in session._pipeline] - self.assertEqual(actual_stages, expected_stages) - - def test_export_session_full_quantization_pipeline(self) -> None: - """Test that full quantization (source transform + quantizers) creates the expected pipeline stages.""" - mock_quantizer = Mock() - config = Int8DynamicActivationIntxWeightConfig( - weight_dtype=torch.int4, - weight_granularity=PerAxis(axis=0), - ) - quant_recipe = QuantizationRecipe( - quantizers=[mock_quantizer], - ao_base_config=[config], - ) - recipe = ExportRecipe( - name="test_full_quantization", quantization_recipe=quant_recipe + # Verify torch.export.export was called + mock_torch_export.assert_called_once_with( + self.model, self.example_inputs[0], strict=True ) - session = ExportSession( - model=self.model, - example_inputs=self.example_inputs, - export_recipe=recipe, - ) + # Verify ComposableQuantizer was created with the quantizers + mock_composable_quantizer.assert_called_once_with([mock_quantizer]) - # Verify pipeline stages for full quantization - # The quantize stage is followed by a re-export stage - expected_stages = [ - "source_transform", - "quantize", - "export", - "edge_transform_and_lower", - "executorch", - ] - actual_stages = [stage.name for stage in session._pipeline] - self.assertEqual(actual_stages, expected_stages) - - @patch("executorch.export.export.ExportSession._run_pipeline") - def test_export_session_export_calls_pipeline( - self, mock_run_pipeline: Mock - ) -> None: - """Test that export() method calls the pipeline.""" - recipe = ExportRecipe(name="test") - session = ExportSession( - model=self.model, - example_inputs=self.example_inputs, - export_recipe=recipe, + # Verify prepare_pt2e was called + mock_prepare_pt2e.assert_called_once_with( + mock_captured_graph, mock_composed_quantizer ) - session.export() - mock_run_pipeline.assert_called_once() - - def test_export_session_standardize_inputs(self) -> None: - """Test that inputs are properly standardized to dictionary format.""" - recipe = ExportRecipe(name="test") + # Verify calibration was performed (prepared model called with example inputs) + mock_prepared_model.assert_called_once_with(*self.example_inputs[0]) - # Test single model and example_inputs - session = ExportSession( - model=self.model, - example_inputs=self.example_inputs, - export_recipe=recipe, - ) + # Verify convert_pt2e was called + mock_convert_pt2e.assert_called_once_with(mock_prepared_model) - self.assertIsInstance(session._model, dict) - self.assertIn("forward", session._model) - self.assertEqual(session._model["forward"], self.model) - - self.assertIsInstance(session._example_inputs, dict) - self.assertIn("forward", session._example_inputs) - self.assertEqual(session._example_inputs["forward"], self.example_inputs) - - def test_export_session_dict_inputs(self) -> None: - """Test that dictionary inputs are preserved.""" - recipe = ExportRecipe(name="test") - model_dict = {"method1": self.model, "method2": SimpleTestModel()} - example_inputs_dict = { - "method1": self.example_inputs, - "method2": [(torch.randn(1, 10),)], - } + # Verify artifacts are returned correctly + result_artifact = stage.get_artifacts() + self.assertIn("forward", result_artifact.data) + self.assertEqual(result_artifact.data["forward"], mock_quantized_model) - session = ExportSession( - model=model_dict, - example_inputs=example_inputs_dict, - export_recipe=recipe, + def test_run_empty_example_inputs(self) -> None: + """Test error when example inputs list is empty.""" + mock_quantizer = Mock() + mock_recipe = Mock(spec=QuantizationRecipe) + mock_recipe.quantizers = [mock_quantizer] + stage = QuantizeStage(mock_recipe) + context = {"example_inputs": {"forward": []}} + artifact = PipelineArtifact(data=self.models_dict, context=context) + + with self.assertRaises(ValueError) as cm: + stage.run(artifact) + self.assertIn( + "Example inputs for method forward not found or empty", str(cm.exception) ) - self.assertEqual(session._model, model_dict) - self.assertEqual(session._example_inputs, example_inputs_dict) - def test_export_session_get_example_input(self) -> None: - """Test getting example input for a method.""" - recipe = ExportRecipe(name="test") - session = ExportSession( - model=self.model, - example_inputs=self.example_inputs, - export_recipe=recipe, - ) +class TestToEdgeStage(unittest.TestCase): + def setUp(self) -> None: + self.mock_exported_program = Mock(spec=ExportedProgram) + self.exported_programs = {"forward": self.mock_exported_program} + self.context = {"constant_methods": None} - example_input = session.get_example_input("forward") - self.assertEqual(example_input, self.example_inputs[0]) + @patch("executorch.export.stages.to_edge") + def test_run_success(self, mock_to_edge: Mock) -> None: + mock_edge_manager = Mock(spec=EdgeProgramManager) + mock_to_edge.return_value = mock_edge_manager + mock_config = Mock() - def test_export_session_get_example_input_missing_method(self) -> None: - """Test error when getting example input for non-existent method.""" - recipe = ExportRecipe(name="test") - session = ExportSession( - model=self.model, - example_inputs=self.example_inputs, - export_recipe=recipe, - ) + stage = ToEdgeStage(edge_compile_config=mock_config) + artifact = PipelineArtifact(data=self.exported_programs, context=self.context) + stage.run(artifact) - with self.assertRaises(KeyError) as context: - session.get_example_input("nonexistent") - self.assertIn("Method name 'nonexistent' not found", str(context.exception)) - - def test_export_session_runtime_errors_before_export(self) -> None: - """Test that runtime errors are raised when accessing results before export.""" - recipe = ExportRecipe(name="test") - session = ExportSession( - model=self.model, - example_inputs=self.example_inputs, - export_recipe=recipe, + # Verify to_edge was called with correct parameters + mock_to_edge.assert_called_once_with( + self.exported_programs, + constant_methods=None, + compile_config=mock_config, ) - with self.assertRaises(RuntimeError): - session.get_executorch_program() - - with self.assertRaises(RuntimeError): - session.get_executorch_program_manager() - - with self.assertRaises(RuntimeError): - session.get_pte_buffer() - - with self.assertRaises(RuntimeError): - session.save_to_pte("test.pte") + # Verify artifacts are set correctly + result_artifact = stage.get_artifacts() + self.assertEqual(result_artifact.data, mock_edge_manager) -class TestExportSessionPipelineExecution(unittest.TestCase): - """Test the actual pipeline execution with mocked stages.""" - +class TestToBackendStage(unittest.TestCase): def setUp(self) -> None: - self.model = SimpleTestModel() - self.example_inputs = [(torch.randn(2, 10),)] + self.mock_edge_manager = Mock(spec=EdgeProgramManager) + self.context = {} - @patch("executorch.export.export.ExecutorchStage") - @patch("executorch.export.export.EdgeTransformAndLowerStage") - @patch("executorch.export.export.ExportStage") - def test_pipeline_execution_order_fp32( - self, - mock_export_stage_class: Mock, - mock_edge_stage_class: Mock, - mock_executorch_stage_class: Mock, + @patch("executorch.export.stages.get_delegation_info") + def test_run_success_no_transforms_or_partitioners( + self, mock_get_delegation_info: Mock ) -> None: - """Test that stages are executed in the correct order for FP32.""" - # Create mock stages - mock_export_stage = Mock() - mock_export_stage.name = "export" - mock_export_stage.get_artifacts.return_value = {"forward": Mock()} - - mock_edge_stage = Mock() - mock_edge_stage.name = "edge_transform_and_lower" - mock_edge_stage.get_artifacts.return_value = Mock() - mock_edge_stage.delegation_info = Mock() - - mock_executorch_stage = Mock() - mock_executorch_stage.name = "executorch" - mock_executorch_stage.get_artifacts.return_value = Mock() - - # Configure the mock classes to return our mock instances - mock_export_stage_class.return_value = mock_export_stage - mock_edge_stage_class.return_value = mock_edge_stage - mock_executorch_stage_class.return_value = mock_executorch_stage - - recipe = ExportRecipe(name="test_fp32") - session = ExportSession( - model=self.model, - example_inputs=self.example_inputs, - export_recipe=recipe, + # Test successful execution without transforms or partitioners + mock_delegation_info = {"delegation": "info"} + mock_get_delegation_info.return_value = mock_delegation_info + mock_exported_program = Mock() + mock_graph_module = Mock() + mock_exported_program.graph_module = mock_graph_module + self.mock_edge_manager.exported_program.return_value = mock_exported_program + + stage = ToBackendStage() + artifact = PipelineArtifact(data=self.mock_edge_manager, context=self.context) + stage.run(artifact) + + # Verify get_delegation_info was called + mock_get_delegation_info.assert_called_once_with(mock_graph_module) + + # Verify artifacts are set correctly + result_artifact = stage.get_artifacts() + self.assertEqual(result_artifact.data, self.mock_edge_manager) + self.assertEqual( + result_artifact.get_context("delegation_info"), mock_delegation_info ) - session.export() - - # Verify stages were called in the correct order - mock_export_stage.run.assert_called_once() - mock_edge_stage.run.assert_called_once() - mock_executorch_stage.run.assert_called_once() - - @patch("executorch.export.export.ExecutorchStage") - @patch("executorch.export.export.EdgeTransformAndLowerStage") - @patch("executorch.export.export.ExportStage") - @patch("executorch.export.export.QuantizeStage") - def test_pipeline_execution_order_quantized( - self, - mock_quantize_stage_class: Mock, - mock_export_stage_class: Mock, - mock_edge_stage_class: Mock, - mock_executorch_stage_class: Mock, + @patch("executorch.export.stages.get_delegation_info") + def test_run_with_partitioners_and_passes( + self, mock_get_delegation_info: Mock ) -> None: - """Test that stages are executed in the correct order for quantized export.""" - # Create mock stages - mock_quantize_stage = Mock() - mock_quantize_stage.name = "quantize" - mock_quantize_stage.get_artifacts.return_value = {"forward": Mock()} - - mock_export_stage = Mock() - mock_export_stage.name = "export" - mock_export_stage.get_artifacts.return_value = {"forward": Mock()} - - mock_edge_stage = Mock() - mock_edge_stage.name = "edge_transform_and_lower" - mock_edge_stage.get_artifacts.return_value = Mock() - mock_edge_stage.delegation_info = Mock() - - mock_executorch_stage = Mock() - mock_executorch_stage.name = "executorch" - mock_executorch_stage.get_artifacts.return_value = Mock() - - # Configure the mock classes to return our mock instances - mock_quantize_stage_class.return_value = mock_quantize_stage - mock_export_stage_class.return_value = mock_export_stage - mock_edge_stage_class.return_value = mock_edge_stage - mock_executorch_stage_class.return_value = mock_executorch_stage + mock_delegation_info = {"delegation": "info"} + mock_get_delegation_info.return_value = mock_delegation_info + mock_exported_program = Mock() + mock_graph_module = Mock() + mock_exported_program.graph_module = mock_graph_module - mock_quantizer = Mock() - quant_recipe = QuantizationRecipe(quantizers=[mock_quantizer]) - recipe = ExportRecipe(name="test_quantized", quantization_recipe=quant_recipe) + mock_edge_program_manager = Mock(spec=EdgeProgramManager) + mock_edge_program_manager.transform.return_value = mock_edge_program_manager + mock_edge_program_manager.to_backend.return_value = mock_edge_program_manager - session = ExportSession( - model=self.model, - example_inputs=self.example_inputs, - export_recipe=recipe, + mock_partitioner = Mock() + mock_transform_passes = [Mock(), Mock()] + stage = ToBackendStage( + partitioners=[mock_partitioner], transform_passes=mock_transform_passes ) + artifact = PipelineArtifact( + data=mock_edge_program_manager, context=self.context + ) + stage.run(artifact) - session.export() - - # Verify stages were called in the correct order - mock_quantize_stage.run.assert_called_once() - mock_export_stage.run.assert_called_once() - mock_edge_stage.run.assert_called_once() - mock_executorch_stage.run.assert_called_once() - + # Verify transform and to_backend called correctly + mock_edge_program_manager.transform.assert_called_once_with( + mock_transform_passes + ) + mock_edge_program_manager.to_backend.assert_called_once_with(mock_partitioner) -class TestExportFunction(unittest.TestCase): - """Test the top-level export function.""" + # Verify artifacts contain the backend manager + result_artifact = stage.get_artifacts() + self.assertEqual(result_artifact.data, mock_edge_program_manager) - def setUp(self) -> None: - self.model = SimpleTestModel() - self.example_inputs = [(torch.randn(2, 10),)] + def test_run_edge_manager_none(self) -> None: + stage = ToBackendStage() + artifact = PipelineArtifact(data=None, context=self.context) - @patch("executorch.export.export.ExportSession") - def test_export_function_creates_session_and_exports( - self, mock_session_class: Mock - ) -> None: - """Test that export function creates session and calls export.""" - mock_session = Mock() - mock_session_class.return_value = mock_session - - recipe = ExportRecipe(name="test") - from executorch.export import export - - result = export( - model=self.model, - example_inputs=self.example_inputs, - export_recipe=recipe, - name="test_export", - ) - mock_session_class.assert_called_once_with( - model=self.model, - example_inputs=self.example_inputs, - export_recipe=recipe, - name="test_export", - dynamic_shapes=None, - constant_methods=None, - artifact_dir=None, - ) - mock_session.export.assert_called_once() - self.assertEqual(result, mock_session) + with self.assertRaises(RuntimeError) as cm: + stage.run(artifact) + self.assertIn("Edge program manager is not set", str(cm.exception)) diff --git a/export/types.py b/export/types.py new file mode 100644 index 00000000000..760f8461d41 --- /dev/null +++ b/export/types.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from enum import Enum + + +class StageType(str, Enum): + """ + Enum representing the different stages in the ExecuTorch export pipeline. + """ + + SOURCE_TRANSFORM = "source_transform" + QUANTIZE = "quantize" + TORCH_EXPORT = "torch_export" + TO_EDGE_TRANSFORM_AND_LOWER = "to_edge_transform_and_lower" + TO_EDGE = "to_edge" + TO_BACKEND = "to_backend" + TO_EXECUTORCH = "to_executorch"