Skip to content

Commit 66e7c4a

Browse files
[Executorch][Export][2/N] Add to_edge and to_backend stages
Address (6) in the rfc: #12660 1. Adds stage implementations for `to_edge` and `to_backend` 2. Adds unit tests for the two stages 3. Adds these two stages in the validation pipeline. Fixes #12932 Differential Revision: [D79120576](https://our.internmc.facebook.com/intern/diff/D79120576/) ghstack-source-id: 299106075 Pull Request resolved: #12937
1 parent 5db1bfb commit 66e7c4a

File tree

3 files changed

+217
-1
lines changed

3 files changed

+217
-1
lines changed

export/export.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
SourceTransformStage,
2626
Stage,
2727
StageType,
28+
ToBackendStage,
29+
ToEdgeStage,
2830
TorchExportStage,
2931
)
3032

@@ -151,9 +153,12 @@ def __init__(
151153
StageType.SOURCE_TRANSFORM: [StageType.QUANTIZE, StageType.TORCH_EXPORT],
152154
StageType.QUANTIZE: [StageType.TORCH_EXPORT],
153155
StageType.TORCH_EXPORT: [
156+
StageType.TO_EDGE,
154157
StageType.TO_EDGE_TRANSFORM_AND_LOWER,
155158
],
156159
StageType.TO_EDGE_TRANSFORM_AND_LOWER: [StageType.TO_EXECUTORCH],
160+
StageType.TO_EDGE: [StageType.TO_BACKEND],
161+
StageType.TO_BACKEND: [StageType.TO_EXECUTORCH],
157162
StageType.TO_EXECUTORCH: [],
158163
}
159164

@@ -199,6 +204,15 @@ def _build_pipeline_from_stages(self, stage_types: List[StageType]) -> List[Stag
199204
transform_passes=self._export_recipe.edge_transform_passes,
200205
compile_config=self._export_recipe.edge_compile_config,
201206
)
207+
elif stage_type == StageType.TO_EDGE:
208+
stage = ToEdgeStage(
209+
edge_compile_config=self._export_recipe.edge_compile_config
210+
)
211+
elif stage_type == StageType.TO_BACKEND:
212+
stage = ToBackendStage(
213+
partitioners=self._export_recipe.partitioners,
214+
transform_passes=self._export_recipe.edge_transform_passes,
215+
)
202216
elif stage_type == StageType.TO_EXECUTORCH:
203217
stage = ExecutorchStage(self._export_recipe.executorch_backend_config)
204218
else:

export/stages.py

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@
1111

1212
import torch
1313
from executorch.devtools.backend_debug import get_delegation_info
14+
from executorch.exir import EdgeCompileConfig
1415
from executorch.exir.backend.backend_api import validation_disabled
15-
from executorch.exir.program import to_edge_transform_and_lower
16+
from executorch.exir.program import to_edge, to_edge_transform_and_lower
1617
from executorch.exir.program._program import _transform
1718
from executorch.export.recipe import QuantizationRecipe
1819
from torch import nn
@@ -29,6 +30,8 @@ class StageType(str, Enum):
2930
TORCH_EXPORT = "torch_export"
3031
TO_EDGE_TRANSFORM_AND_LOWER = "to_edge_transform_and_lower"
3132
TO_EXECUTORCH = "to_executorch"
33+
TO_EDGE = "to_edge"
34+
TO_BACKEND = "to_backend"
3235

3336

3437
class PipelineArtifact:
@@ -306,3 +309,98 @@ def run(self, artifact: PipelineArtifact) -> None:
306309
quantized_models[method_name] = quantized_model
307310

308311
self._artifact = artifact.copy_with_new_data(quantized_models)
312+
313+
314+
class ToEdgeStage(Stage):
315+
"""
316+
Stage: Convert ExportedProgram to EdgeProgramManager.
317+
"""
318+
319+
def __init__(
320+
self,
321+
edge_compile_config: Optional[EdgeCompileConfig] = None, # pyre-ignore
322+
) -> None:
323+
super().__init__()
324+
self._edge_compile_config = edge_compile_config
325+
326+
@property
327+
def stage_type(self) -> str:
328+
return StageType.TO_EDGE
329+
330+
def run(self, artifact: PipelineArtifact) -> None:
331+
"""
332+
Convert ExportedProgram to EdgeProgramManager.
333+
334+
Args:
335+
artifact: Contains exported programs and context
336+
"""
337+
exported_programs = artifact.data
338+
constant_methods = artifact.get_context("constant_methods")
339+
340+
# Convert to edge program manager
341+
edge_program_manager = to_edge(
342+
exported_programs,
343+
constant_methods=constant_methods,
344+
compile_config=self._edge_compile_config,
345+
)
346+
347+
self._artifact = artifact.copy_with_new_data(edge_program_manager)
348+
349+
350+
class ToBackendStage(Stage):
351+
"""
352+
Stage: Apply transformations and partitioning to EdgeProgramManager.
353+
"""
354+
355+
def __init__(
356+
self,
357+
partitioners: Optional[List[Any]] = None,
358+
transform_passes: Optional[Sequence[Callable[[Any], Optional[Any]]]] = None,
359+
) -> None:
360+
super().__init__()
361+
self._partitioners = partitioners
362+
self._transform_passes = transform_passes
363+
364+
@property
365+
def stage_type(self) -> str:
366+
return StageType.TO_BACKEND
367+
368+
def run(self, artifact: PipelineArtifact) -> None:
369+
"""
370+
Apply transformations and partitioning to EdgeProgramManager.
371+
372+
Args:
373+
artifact: Contains edge program manager and context
374+
"""
375+
edge_program_manager = artifact.data
376+
377+
if edge_program_manager is None:
378+
raise RuntimeError("Edge program manager is not set.")
379+
380+
# Apply transform passes if available
381+
if self._transform_passes:
382+
edge_program_manager = edge_program_manager.transform(
383+
self._transform_passes
384+
)
385+
386+
# Apply partitioners if available
387+
if self._partitioners is not None and len(self._partitioners) > 0:
388+
with validation_disabled():
389+
# pyre-ignore
390+
for partitioner in self._partitioners:
391+
edge_program_manager = edge_program_manager.to_backend(partitioner)
392+
393+
# Get delegation info
394+
delegation_info = get_delegation_info(
395+
edge_program_manager.exported_program().graph_module
396+
)
397+
398+
self._artifact = artifact.copy_with_new_data(edge_program_manager)
399+
self._artifact.add_context("delegation_info", delegation_info)
400+
401+
@property
402+
def delegation_info(self) -> Any:
403+
"""
404+
Returns the delegation info.
405+
"""
406+
return self._artifact.get_context("delegation_info")

export/tests/test_export_stages.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
QuantizeStage,
2020
SourceTransformStage,
2121
StageType,
22+
ToBackendStage,
23+
ToEdgeStage,
2224
TorchExportStage,
2325
)
2426
from torch.export import ExportedProgram
@@ -282,3 +284,105 @@ def test_run_empty_example_inputs(self) -> None:
282284
self.assertIn(
283285
"Example inputs for method forward not found or empty", str(cm.exception)
284286
)
287+
288+
289+
class TestToEdgeStage(unittest.TestCase):
290+
def setUp(self) -> None:
291+
self.mock_exported_program = Mock(spec=ExportedProgram)
292+
self.exported_programs = {"forward": self.mock_exported_program}
293+
self.context = {"constant_methods": None}
294+
295+
@patch("executorch.export.stages.to_edge")
296+
def test_run_success(self, mock_to_edge: Mock) -> None:
297+
mock_edge_manager = Mock(spec=EdgeProgramManager)
298+
mock_to_edge.return_value = mock_edge_manager
299+
mock_config = Mock()
300+
301+
stage = ToEdgeStage(edge_compile_config=mock_config)
302+
artifact = PipelineArtifact(data=self.exported_programs, context=self.context)
303+
stage.run(artifact)
304+
305+
# Verify to_edge was called with correct parameters
306+
mock_to_edge.assert_called_once_with(
307+
self.exported_programs,
308+
constant_methods=None,
309+
compile_config=mock_config,
310+
)
311+
312+
# Verify artifacts are set correctly
313+
result_artifact = stage.get_artifacts()
314+
self.assertEqual(result_artifact.data, mock_edge_manager)
315+
316+
317+
class TestToBackendStage(unittest.TestCase):
318+
def setUp(self) -> None:
319+
self.mock_edge_manager = Mock(spec=EdgeProgramManager)
320+
self.context = {}
321+
322+
@patch("executorch.export.stages.get_delegation_info")
323+
def test_run_success_no_transforms_or_partitioners(
324+
self, mock_get_delegation_info: Mock
325+
) -> None:
326+
# Test successful execution without transforms or partitioners
327+
mock_delegation_info = {"delegation": "info"}
328+
mock_get_delegation_info.return_value = mock_delegation_info
329+
mock_exported_program = Mock()
330+
mock_graph_module = Mock()
331+
mock_exported_program.graph_module = mock_graph_module
332+
self.mock_edge_manager.exported_program.return_value = mock_exported_program
333+
334+
stage = ToBackendStage()
335+
artifact = PipelineArtifact(data=self.mock_edge_manager, context=self.context)
336+
stage.run(artifact)
337+
338+
# Verify get_delegation_info was called
339+
mock_get_delegation_info.assert_called_once_with(mock_graph_module)
340+
341+
# Verify artifacts are set correctly
342+
result_artifact = stage.get_artifacts()
343+
self.assertEqual(result_artifact.data, self.mock_edge_manager)
344+
self.assertEqual(
345+
result_artifact.get_context("delegation_info"), mock_delegation_info
346+
)
347+
348+
@patch("executorch.export.stages.get_delegation_info")
349+
def test_run_with_partitioners_and_passes(
350+
self, mock_get_delegation_info: Mock
351+
) -> None:
352+
mock_delegation_info = {"delegation": "info"}
353+
mock_get_delegation_info.return_value = mock_delegation_info
354+
mock_exported_program = Mock()
355+
mock_graph_module = Mock()
356+
mock_exported_program.graph_module = mock_graph_module
357+
358+
mock_edge_program_manager = Mock(spec=EdgeProgramManager)
359+
mock_edge_program_manager.transform.return_value = mock_edge_program_manager
360+
mock_edge_program_manager.to_backend.return_value = mock_edge_program_manager
361+
362+
mock_partitioner = Mock()
363+
mock_transform_passes = [Mock(), Mock()]
364+
stage = ToBackendStage(
365+
partitioners=[mock_partitioner], transform_passes=mock_transform_passes
366+
)
367+
artifact = PipelineArtifact(
368+
data=mock_edge_program_manager, context=self.context
369+
)
370+
stage.run(artifact)
371+
372+
# Verify transform and to_backend called correctly
373+
mock_edge_program_manager.transform.assert_called_once_with(
374+
mock_transform_passes
375+
)
376+
mock_edge_program_manager.to_backend.assert_called_once_with(mock_partitioner)
377+
378+
# Verify artifacts contain the backend manager
379+
result_artifact = stage.get_artifacts()
380+
self.assertEqual(result_artifact.data, mock_edge_program_manager)
381+
382+
def test_run_edge_manager_none(self) -> None:
383+
stage = ToBackendStage()
384+
artifact = PipelineArtifact(data=None, context=self.context)
385+
386+
with self.assertRaises(RuntimeError) as cm:
387+
stage.run(artifact)
388+
self.assertIn("Edge program manager is not set", str(cm.exception))

0 commit comments

Comments
 (0)