Skip to content

Commit 9430abd

Browse files
[Executorch][Export][2/N] Add to_edge and to_backend stages
Pull Request resolved: #12937 Address (6) in the rfc: #12660 1. Adds stage implementations for `to_edge` and `to_backend` 2. Adds unit tests for the two stages 3. Adds these two stages in the validation pipeline. Fixes #12932 ghstack-source-id: 299637526 @exported-using-ghexport Differential Revision: [D79120576](https://our.internmc.facebook.com/intern/diff/D79120576/)
1 parent 0ff40fd commit 9430abd

File tree

3 files changed

+230
-1
lines changed

3 files changed

+230
-1
lines changed

export/export.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
SourceTransformStage,
2626
Stage,
2727
StageType,
28+
ToBackendStage,
29+
ToEdgeStage,
2830
TorchExportStage,
2931
)
3032

@@ -191,6 +193,15 @@ def _build_default_stages(self) -> Dict[StageType, Stage]:
191193
transform_passes=self._export_recipe.edge_transform_passes,
192194
compile_config=self._export_recipe.edge_compile_config,
193195
)
196+
elif stage_type == StageType.TO_EDGE:
197+
stage = ToEdgeStage(
198+
edge_compile_config=self._export_recipe.edge_compile_config
199+
)
200+
elif stage_type == StageType.TO_BACKEND:
201+
stage = ToBackendStage(
202+
partitioners=self._export_recipe.partitioners,
203+
transform_passes=self._export_recipe.edge_transform_passes,
204+
)
194205
elif stage_type == StageType.TO_EXECUTORCH:
195206
stage = ExecutorchStage(self._export_recipe.executorch_backend_config)
196207
else:

export/stages.py

Lines changed: 115 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@
1111

1212
import torch
1313
from executorch.devtools.backend_debug import get_delegation_info
14+
from executorch.exir import EdgeCompileConfig
1415
from executorch.exir.backend.backend_api import validation_disabled
15-
from executorch.exir.program import to_edge_transform_and_lower
16+
from executorch.exir.program import to_edge, to_edge_transform_and_lower
1617
from executorch.exir.program._program import _transform
1718
from executorch.export.recipe import QuantizationRecipe
1819
from torch import nn
@@ -29,6 +30,8 @@ class StageType(str, Enum):
2930
TORCH_EXPORT = "torch_export"
3031
TO_EDGE_TRANSFORM_AND_LOWER = "to_edge_transform_and_lower"
3132
TO_EXECUTORCH = "to_executorch"
33+
TO_EDGE = "to_edge"
34+
TO_BACKEND = "to_backend"
3235

3336

3437
class PipelineArtifact:
@@ -362,3 +365,114 @@ def run(self, artifact: PipelineArtifact) -> None:
362365
quantized_models[method_name] = quantized_model
363366

364367
self._artifact = artifact.copy_with_new_data(quantized_models)
368+
369+
370+
class ToEdgeStage(Stage):
371+
"""
372+
Stage: Convert ExportedProgram to EdgeProgramManager.
373+
"""
374+
375+
def __init__(
376+
self,
377+
edge_compile_config: Optional[EdgeCompileConfig] = None, # pyre-ignore
378+
) -> None:
379+
super().__init__()
380+
self._edge_compile_config = edge_compile_config
381+
382+
@property
383+
def stage_type(self) -> str:
384+
return StageType.TO_EDGE
385+
386+
@property
387+
def valid_predecessor_stages(self) -> List["StageType"]:
388+
return [StageType.TORCH_EXPORT]
389+
390+
@property
391+
def can_start_pipeline(self) -> bool:
392+
return False
393+
394+
def run(self, artifact: PipelineArtifact) -> None:
395+
"""
396+
Convert ExportedProgram to EdgeProgramManager.
397+
398+
Args:
399+
artifact: Contains exported programs and context
400+
"""
401+
exported_programs = artifact.data
402+
constant_methods = artifact.get_context("constant_methods")
403+
404+
# Convert to edge program manager
405+
edge_program_manager = to_edge(
406+
exported_programs,
407+
constant_methods=constant_methods,
408+
compile_config=self._edge_compile_config,
409+
)
410+
411+
self._artifact = artifact.copy_with_new_data(edge_program_manager)
412+
413+
414+
class ToBackendStage(Stage):
415+
"""
416+
Stage: Apply transformations and partitioning to EdgeProgramManager.
417+
"""
418+
419+
def __init__(
420+
self,
421+
partitioners: Optional[List[Any]] = None,
422+
transform_passes: Optional[Sequence[Callable[[Any], Optional[Any]]]] = None,
423+
) -> None:
424+
super().__init__()
425+
self._partitioners = partitioners
426+
self._transform_passes = transform_passes
427+
428+
@property
429+
def stage_type(self) -> str:
430+
return StageType.TO_BACKEND
431+
432+
@property
433+
def valid_predecessor_stages(self) -> List["StageType"]:
434+
return [StageType.TO_EDGE]
435+
436+
@property
437+
def can_start_pipeline(self) -> bool:
438+
return False
439+
440+
def run(self, artifact: PipelineArtifact) -> None:
441+
"""
442+
Apply transformations and partitioning to EdgeProgramManager.
443+
444+
Args:
445+
artifact: Contains edge program manager and context
446+
"""
447+
edge_program_manager = artifact.data
448+
449+
if edge_program_manager is None:
450+
raise RuntimeError("Edge program manager is not set.")
451+
452+
# Apply transform passes if available
453+
if self._transform_passes:
454+
edge_program_manager = edge_program_manager.transform(
455+
self._transform_passes
456+
)
457+
458+
# Apply partitioners if available
459+
if self._partitioners is not None and len(self._partitioners) > 0:
460+
with validation_disabled():
461+
# pyre-ignore
462+
for partitioner in self._partitioners:
463+
edge_program_manager = edge_program_manager.to_backend(partitioner)
464+
465+
# Get delegation info
466+
delegation_info = get_delegation_info(
467+
edge_program_manager.exported_program().graph_module
468+
)
469+
470+
self._artifact = artifact.copy_with_new_data(edge_program_manager)
471+
self._artifact.add_context("delegation_info", delegation_info)
472+
473+
@property
474+
def delegation_info(self) -> Any:
475+
"""
476+
Returns the delegation info.
477+
"""
478+
return self._artifact.get_context("delegation_info")

export/tests/test_export_stages.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
QuantizeStage,
2020
SourceTransformStage,
2121
StageType,
22+
ToBackendStage,
23+
ToEdgeStage,
2224
TorchExportStage,
2325
)
2426
from torch.export import ExportedProgram
@@ -282,3 +284,105 @@ def test_run_empty_example_inputs(self) -> None:
282284
self.assertIn(
283285
"Example inputs for method forward not found or empty", str(cm.exception)
284286
)
287+
288+
289+
class TestToEdgeStage(unittest.TestCase):
290+
def setUp(self) -> None:
291+
self.mock_exported_program = Mock(spec=ExportedProgram)
292+
self.exported_programs = {"forward": self.mock_exported_program}
293+
self.context = {"constant_methods": None}
294+
295+
@patch("executorch.export.stages.to_edge")
296+
def test_run_success(self, mock_to_edge: Mock) -> None:
297+
mock_edge_manager = Mock(spec=EdgeProgramManager)
298+
mock_to_edge.return_value = mock_edge_manager
299+
mock_config = Mock()
300+
301+
stage = ToEdgeStage(edge_compile_config=mock_config)
302+
artifact = PipelineArtifact(data=self.exported_programs, context=self.context)
303+
stage.run(artifact)
304+
305+
# Verify to_edge was called with correct parameters
306+
mock_to_edge.assert_called_once_with(
307+
self.exported_programs,
308+
constant_methods=None,
309+
compile_config=mock_config,
310+
)
311+
312+
# Verify artifacts are set correctly
313+
result_artifact = stage.get_artifacts()
314+
self.assertEqual(result_artifact.data, mock_edge_manager)
315+
316+
317+
class TestToBackendStage(unittest.TestCase):
318+
def setUp(self) -> None:
319+
self.mock_edge_manager = Mock(spec=EdgeProgramManager)
320+
self.context = {}
321+
322+
@patch("executorch.export.stages.get_delegation_info")
323+
def test_run_success_no_transforms_or_partitioners(
324+
self, mock_get_delegation_info: Mock
325+
) -> None:
326+
# Test successful execution without transforms or partitioners
327+
mock_delegation_info = {"delegation": "info"}
328+
mock_get_delegation_info.return_value = mock_delegation_info
329+
mock_exported_program = Mock()
330+
mock_graph_module = Mock()
331+
mock_exported_program.graph_module = mock_graph_module
332+
self.mock_edge_manager.exported_program.return_value = mock_exported_program
333+
334+
stage = ToBackendStage()
335+
artifact = PipelineArtifact(data=self.mock_edge_manager, context=self.context)
336+
stage.run(artifact)
337+
338+
# Verify get_delegation_info was called
339+
mock_get_delegation_info.assert_called_once_with(mock_graph_module)
340+
341+
# Verify artifacts are set correctly
342+
result_artifact = stage.get_artifacts()
343+
self.assertEqual(result_artifact.data, self.mock_edge_manager)
344+
self.assertEqual(
345+
result_artifact.get_context("delegation_info"), mock_delegation_info
346+
)
347+
348+
@patch("executorch.export.stages.get_delegation_info")
349+
def test_run_with_partitioners_and_passes(
350+
self, mock_get_delegation_info: Mock
351+
) -> None:
352+
mock_delegation_info = {"delegation": "info"}
353+
mock_get_delegation_info.return_value = mock_delegation_info
354+
mock_exported_program = Mock()
355+
mock_graph_module = Mock()
356+
mock_exported_program.graph_module = mock_graph_module
357+
358+
mock_edge_program_manager = Mock(spec=EdgeProgramManager)
359+
mock_edge_program_manager.transform.return_value = mock_edge_program_manager
360+
mock_edge_program_manager.to_backend.return_value = mock_edge_program_manager
361+
362+
mock_partitioner = Mock()
363+
mock_transform_passes = [Mock(), Mock()]
364+
stage = ToBackendStage(
365+
partitioners=[mock_partitioner], transform_passes=mock_transform_passes
366+
)
367+
artifact = PipelineArtifact(
368+
data=mock_edge_program_manager, context=self.context
369+
)
370+
stage.run(artifact)
371+
372+
# Verify transform and to_backend called correctly
373+
mock_edge_program_manager.transform.assert_called_once_with(
374+
mock_transform_passes
375+
)
376+
mock_edge_program_manager.to_backend.assert_called_once_with(mock_partitioner)
377+
378+
# Verify artifacts contain the backend manager
379+
result_artifact = stage.get_artifacts()
380+
self.assertEqual(result_artifact.data, mock_edge_program_manager)
381+
382+
def test_run_edge_manager_none(self) -> None:
383+
stage = ToBackendStage()
384+
artifact = PipelineArtifact(data=None, context=self.context)
385+
386+
with self.assertRaises(RuntimeError) as cm:
387+
stage.run(artifact)
388+
self.assertIn("Edge program manager is not set", str(cm.exception))

0 commit comments

Comments
 (0)