Skip to content

Commit 10c502c

Browse files
[Executorch][Export][2/N] Add to_edge and to_backend stages
Pull Request resolved: #12937 Address (6) in the rfc: #12660 1. Adds stage implementations for `to_edge` and `to_backend` 2. Adds unit tests for the two stages 3. Adds these two stages in the validation pipeline. Fixes #12932 ghstack-source-id: 299959405 @exported-using-ghexport Differential Revision: [D79120576](https://our.internmc.facebook.com/intern/diff/D79120576/)
1 parent 5f35064 commit 10c502c

File tree

5 files changed

+260
-10
lines changed

5 files changed

+260
-10
lines changed

export/export.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
QuantizeStage,
2525
SourceTransformStage,
2626
Stage,
27+
ToBackendStage,
28+
ToEdgeStage,
2729
TorchExportStage,
2830
)
2931
from .types import StageType
@@ -147,7 +149,9 @@ def __init__(
147149
)
148150

149151
# Stage registry: map of StageType to Stage instances
150-
self._stage_registry: Dict[StageType, Stage] = self._build_default_stages()
152+
self._stage_registry: Dict[StageType, Stage] = self._build_stages(
153+
self._pipeline_stages
154+
)
151155

152156
# Intialize run context
153157
self._run_context: Dict[str, Any] = {
@@ -170,10 +174,12 @@ def _get_default_pipeline(self) -> List[StageType]:
170174
StageType.TO_EXECUTORCH,
171175
]
172176

173-
def _build_default_stages(self) -> Dict[StageType, Stage]:
177+
def _build_stages(self, stages: List[StageType]) -> Dict[StageType, Stage]:
178+
"""Build the stage registry from the given stages."""
174179
stage_registry: Dict[StageType, Stage] = {}
175180

176-
for stage_type in self._get_default_pipeline():
181+
stage = None
182+
for stage_type in stages or self._get_default_pipeline():
177183
if stage_type == StageType.SOURCE_TRANSFORM:
178184
stage = SourceTransformStage(self._quant_recipe)
179185
elif stage_type == StageType.QUANTIZE:
@@ -191,12 +197,24 @@ def _build_default_stages(self) -> Dict[StageType, Stage]:
191197
transform_passes=self._export_recipe.edge_transform_passes,
192198
compile_config=self._export_recipe.edge_compile_config,
193199
)
200+
elif stage_type == StageType.TO_EDGE:
201+
stage = ToEdgeStage(
202+
edge_compile_config=self._export_recipe.edge_compile_config
203+
)
204+
elif stage_type == StageType.TO_BACKEND:
205+
stage = ToBackendStage(
206+
partitioners=self._export_recipe.partitioners,
207+
transform_passes=self._export_recipe.edge_transform_passes,
208+
)
194209
elif stage_type == StageType.TO_EXECUTORCH:
195210
stage = ExecutorchStage(self._export_recipe.executorch_backend_config)
196211
else:
197-
raise ValueError(f"Unknown stage type: {stage_type}")
212+
logging.info(
213+
f"{stage_type} is unknown, you have to register it before executing export()"
214+
)
198215

199-
stage_registry[stage_type] = stage
216+
if stage:
217+
stage_registry[stage_type] = stage
200218
return stage_registry
201219

202220
def register_stage(self, stage_type: StageType, stage: Stage) -> None:
@@ -241,7 +259,9 @@ def _validate_pipeline_sequence(
241259
first_stage = stages[0]
242260
first_stage_instance = self._stage_registry.get(first_stage)
243261
if first_stage_instance is None:
244-
raise ValueError(f"Stage {first_stage} not found in registry")
262+
raise ValueError(
263+
f"Stage {first_stage} not found in registry, register it using session.register_stage()"
264+
)
245265

246266
if not first_stage_instance.can_start_pipeline:
247267
raise ValueError(f"Stage {first_stage} cannot start a pipeline. ")
@@ -254,7 +274,9 @@ def _validate_pipeline_sequence(
254274
# Get the stage instance to check its valid predecessors
255275
stage_instance = self._stage_registry.get(current_stage)
256276
if stage_instance is None:
257-
raise ValueError(f"Stage {current_stage} not found in registry")
277+
raise ValueError(
278+
f"Stage {current_stage} not found in registry, , register it using session.register_stage()"
279+
)
258280

259281
valid_predecessors = stage_instance.valid_predecessor_stages
260282

export/stages.py

Lines changed: 114 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010

1111
import torch
1212
from executorch.devtools.backend_debug import get_delegation_info
13+
from executorch.exir import EdgeCompileConfig
1314
from executorch.exir.backend.backend_api import validation_disabled
14-
from executorch.exir.program import to_edge_transform_and_lower
15+
from executorch.exir.program import to_edge, to_edge_transform_and_lower
1516
from executorch.exir.program._program import _transform
1617
from executorch.export.recipe import QuantizationRecipe
1718
from executorch.export.types import StageType
@@ -223,7 +224,7 @@ def stage_type(self) -> str:
223224

224225
@property
225226
def valid_predecessor_stages(self) -> List["StageType"]:
226-
return [StageType.TO_EDGE_TRANSFORM_AND_LOWER]
227+
return [StageType.TO_EDGE_TRANSFORM_AND_LOWER, StageType.TO_BACKEND]
227228

228229
@property
229230
def can_start_pipeline(self) -> bool:
@@ -354,3 +355,114 @@ def run(self, artifact: PipelineArtifact) -> None:
354355
quantized_models[method_name] = quantized_model
355356

356357
self._artifact = artifact.copy_with_new_data(quantized_models)
358+
359+
360+
class ToEdgeStage(Stage):
361+
"""
362+
Stage: Convert ExportedProgram to EdgeProgramManager.
363+
"""
364+
365+
def __init__(
366+
self,
367+
edge_compile_config: Optional[EdgeCompileConfig] = None, # pyre-ignore
368+
) -> None:
369+
super().__init__()
370+
self._edge_compile_config = edge_compile_config
371+
372+
@property
373+
def stage_type(self) -> str:
374+
return StageType.TO_EDGE
375+
376+
@property
377+
def valid_predecessor_stages(self) -> List["StageType"]:
378+
return [StageType.TORCH_EXPORT]
379+
380+
@property
381+
def can_start_pipeline(self) -> bool:
382+
return False
383+
384+
def run(self, artifact: PipelineArtifact) -> None:
385+
"""
386+
Convert ExportedProgram to EdgeProgramManager.
387+
388+
Args:
389+
artifact: Contains exported programs and context
390+
"""
391+
exported_programs = artifact.data
392+
constant_methods = artifact.get_context("constant_methods")
393+
394+
# Convert to edge program manager
395+
edge_program_manager = to_edge(
396+
exported_programs,
397+
constant_methods=constant_methods,
398+
compile_config=self._edge_compile_config,
399+
)
400+
401+
self._artifact = artifact.copy_with_new_data(edge_program_manager)
402+
403+
404+
class ToBackendStage(Stage):
405+
"""
406+
Stage: Apply transformations and partitioning to EdgeProgramManager.
407+
"""
408+
409+
def __init__(
410+
self,
411+
partitioners: Optional[List[Any]] = None,
412+
transform_passes: Optional[Sequence[Callable[[Any], Optional[Any]]]] = None,
413+
) -> None:
414+
super().__init__()
415+
self._partitioners = partitioners
416+
self._transform_passes = transform_passes
417+
418+
@property
419+
def stage_type(self) -> str:
420+
return StageType.TO_BACKEND
421+
422+
@property
423+
def valid_predecessor_stages(self) -> List["StageType"]:
424+
return [StageType.TO_EDGE]
425+
426+
@property
427+
def can_start_pipeline(self) -> bool:
428+
return False
429+
430+
def run(self, artifact: PipelineArtifact) -> None:
431+
"""
432+
Apply transformations and partitioning to EdgeProgramManager.
433+
434+
Args:
435+
artifact: Contains edge program manager and context
436+
"""
437+
edge_program_manager = artifact.data
438+
439+
if edge_program_manager is None:
440+
raise RuntimeError("Edge program manager is not set.")
441+
442+
# Apply transform passes if available
443+
if self._transform_passes:
444+
edge_program_manager = edge_program_manager.transform(
445+
self._transform_passes
446+
)
447+
448+
# Apply partitioners if available
449+
if self._partitioners is not None and len(self._partitioners) > 0:
450+
with validation_disabled():
451+
# pyre-ignore
452+
for partitioner in self._partitioners:
453+
edge_program_manager = edge_program_manager.to_backend(partitioner)
454+
455+
# Get delegation info
456+
delegation_info = get_delegation_info(
457+
edge_program_manager.exported_program().graph_module
458+
)
459+
460+
self._artifact = artifact.copy_with_new_data(edge_program_manager)
461+
self._artifact.add_context("delegation_info", delegation_info)
462+
463+
@property
464+
def delegation_info(self) -> Any:
465+
"""
466+
Returns the delegation info.
467+
"""
468+
return self._artifact.get_context("delegation_info")

export/tests/test_export_session.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,14 +249,23 @@ def _get_export_session(self, stages: List[StageType]):
249249
def test_valid_pipeline_sequences(self) -> None:
250250
"""Test various valid pipeline sequences."""
251251
valid_sequences = [
252-
# Full pipeline
252+
# Full pipeline with to_edge_transform_lower
253253
[
254254
StageType.SOURCE_TRANSFORM,
255255
StageType.QUANTIZE,
256256
StageType.TORCH_EXPORT,
257257
StageType.TO_EDGE_TRANSFORM_AND_LOWER,
258258
StageType.TO_EXECUTORCH,
259259
],
260+
# Full pipeline with to_edge, to_backend
261+
[
262+
StageType.SOURCE_TRANSFORM,
263+
StageType.QUANTIZE,
264+
StageType.TORCH_EXPORT,
265+
StageType.TO_EDGE,
266+
StageType.TO_BACKEND,
267+
StageType.TO_EXECUTORCH,
268+
],
260269
# Skip quantize
261270
[
262271
StageType.SOURCE_TRANSFORM,

export/tests/test_export_stages.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
QuantizeStage,
2020
SourceTransformStage,
2121
StageType,
22+
ToBackendStage,
23+
ToEdgeStage,
2224
TorchExportStage,
2325
)
2426
from torch.export import ExportedProgram
@@ -282,3 +284,105 @@ def test_run_empty_example_inputs(self) -> None:
282284
self.assertIn(
283285
"Example inputs for method forward not found or empty", str(cm.exception)
284286
)
287+
288+
289+
class TestToEdgeStage(unittest.TestCase):
290+
def setUp(self) -> None:
291+
self.mock_exported_program = Mock(spec=ExportedProgram)
292+
self.exported_programs = {"forward": self.mock_exported_program}
293+
self.context = {"constant_methods": None}
294+
295+
@patch("executorch.export.stages.to_edge")
296+
def test_run_success(self, mock_to_edge: Mock) -> None:
297+
mock_edge_manager = Mock(spec=EdgeProgramManager)
298+
mock_to_edge.return_value = mock_edge_manager
299+
mock_config = Mock()
300+
301+
stage = ToEdgeStage(edge_compile_config=mock_config)
302+
artifact = PipelineArtifact(data=self.exported_programs, context=self.context)
303+
stage.run(artifact)
304+
305+
# Verify to_edge was called with correct parameters
306+
mock_to_edge.assert_called_once_with(
307+
self.exported_programs,
308+
constant_methods=None,
309+
compile_config=mock_config,
310+
)
311+
312+
# Verify artifacts are set correctly
313+
result_artifact = stage.get_artifacts()
314+
self.assertEqual(result_artifact.data, mock_edge_manager)
315+
316+
317+
class TestToBackendStage(unittest.TestCase):
318+
def setUp(self) -> None:
319+
self.mock_edge_manager = Mock(spec=EdgeProgramManager)
320+
self.context = {}
321+
322+
@patch("executorch.export.stages.get_delegation_info")
323+
def test_run_success_no_transforms_or_partitioners(
324+
self, mock_get_delegation_info: Mock
325+
) -> None:
326+
# Test successful execution without transforms or partitioners
327+
mock_delegation_info = {"delegation": "info"}
328+
mock_get_delegation_info.return_value = mock_delegation_info
329+
mock_exported_program = Mock()
330+
mock_graph_module = Mock()
331+
mock_exported_program.graph_module = mock_graph_module
332+
self.mock_edge_manager.exported_program.return_value = mock_exported_program
333+
334+
stage = ToBackendStage()
335+
artifact = PipelineArtifact(data=self.mock_edge_manager, context=self.context)
336+
stage.run(artifact)
337+
338+
# Verify get_delegation_info was called
339+
mock_get_delegation_info.assert_called_once_with(mock_graph_module)
340+
341+
# Verify artifacts are set correctly
342+
result_artifact = stage.get_artifacts()
343+
self.assertEqual(result_artifact.data, self.mock_edge_manager)
344+
self.assertEqual(
345+
result_artifact.get_context("delegation_info"), mock_delegation_info
346+
)
347+
348+
@patch("executorch.export.stages.get_delegation_info")
349+
def test_run_with_partitioners_and_passes(
350+
self, mock_get_delegation_info: Mock
351+
) -> None:
352+
mock_delegation_info = {"delegation": "info"}
353+
mock_get_delegation_info.return_value = mock_delegation_info
354+
mock_exported_program = Mock()
355+
mock_graph_module = Mock()
356+
mock_exported_program.graph_module = mock_graph_module
357+
358+
mock_edge_program_manager = Mock(spec=EdgeProgramManager)
359+
mock_edge_program_manager.transform.return_value = mock_edge_program_manager
360+
mock_edge_program_manager.to_backend.return_value = mock_edge_program_manager
361+
362+
mock_partitioner = Mock()
363+
mock_transform_passes = [Mock(), Mock()]
364+
stage = ToBackendStage(
365+
partitioners=[mock_partitioner], transform_passes=mock_transform_passes
366+
)
367+
artifact = PipelineArtifact(
368+
data=mock_edge_program_manager, context=self.context
369+
)
370+
stage.run(artifact)
371+
372+
# Verify transform and to_backend called correctly
373+
mock_edge_program_manager.transform.assert_called_once_with(
374+
mock_transform_passes
375+
)
376+
mock_edge_program_manager.to_backend.assert_called_once_with(mock_partitioner)
377+
378+
# Verify artifacts contain the backend manager
379+
result_artifact = stage.get_artifacts()
380+
self.assertEqual(result_artifact.data, mock_edge_program_manager)
381+
382+
def test_run_edge_manager_none(self) -> None:
383+
stage = ToBackendStage()
384+
artifact = PipelineArtifact(data=None, context=self.context)
385+
386+
with self.assertRaises(RuntimeError) as cm:
387+
stage.run(artifact)
388+
self.assertIn("Edge program manager is not set", str(cm.exception))

export/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,11 @@ class StageType(str, Enum):
1111
"""
1212
Enum representing the different stages in the ExecuTorch export pipeline.
1313
"""
14+
1415
SOURCE_TRANSFORM = "source_transform"
1516
QUANTIZE = "quantize"
1617
TORCH_EXPORT = "torch_export"
1718
TO_EDGE_TRANSFORM_AND_LOWER = "to_edge_transform_and_lower"
19+
TO_EDGE = "to_edge"
20+
TO_BACKEND = "to_backend"
1821
TO_EXECUTORCH = "to_executorch"

0 commit comments

Comments
 (0)