Skip to content

Commit f81e834

Browse files
authored
Add strict-flag to ExportSession (pytorch#14588)
**Add strict export option to ExportRecipe** Default is True, mirroring earlier behavior. Also update ExportSession to handle this. Signed-off-by: Erik Lundell <[email protected]>
1 parent 3b16bc1 commit f81e834

File tree

3 files changed

+11
-2
lines changed

3 files changed

+11
-2
lines changed

export/export.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2025 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD-style license found in the
56
# LICENSE file in the root directory of this source tree.
@@ -200,7 +201,9 @@ def _build_stages(self, stages: List[StageType]) -> Dict[StageType, Stage]:
200201
aten_transform_passes = list(
201202
self._export_recipe.aten_transform_passes
202203
)
203-
stage = TorchExportStage(aten_transform_passes)
204+
stage = TorchExportStage(
205+
aten_transform_passes, strict=self._export_recipe.strict
206+
)
204207
elif stage_type == StageType.TO_EDGE_TRANSFORM_AND_LOWER:
205208
stage = EdgeTransformAndLowerStage.from_recipe(self._lowering_recipe)
206209
elif stage_type == StageType.TO_EDGE:

export/recipe.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2025 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD-style license found in the
56
# LICENSE file in the root directory of this source tree.
@@ -151,6 +152,7 @@ class ExportRecipe:
151152
executorch_backend_config: Optional backend configuration for ExecuTorch
152153
pipeline_stages: Optional list of stages to execute, defaults to a standard pipeline.
153154
mode: Export mode (debug or release)
155+
strict: Set the strict flag in the torch export call.
154156
"""
155157

156158
name: Optional[str] = None
@@ -163,6 +165,7 @@ class ExportRecipe:
163165
executorch_backend_config: Optional[ExecutorchBackendConfig] = None
164166
pipeline_stages: Optional[List[StageType]] = None
165167
mode: Mode = Mode.RELEASE
168+
strict: bool = True
166169

167170
@classmethod
168171
def get_recipe(cls, recipe: "RecipeType", **kwargs) -> "ExportRecipe":

export/stages.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2025 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD-style license found in the
56
# LICENSE file in the root directory of this source tree.
@@ -110,9 +111,11 @@ def __init__(
110111
aten_transform_passes: Optional[
111112
List[Callable[[str, ExportedProgram], ExportedProgram]]
112113
] = None,
114+
strict=True,
113115
) -> None:
114116
super().__init__()
115117
self._aten_transform_passes = aten_transform_passes
118+
self.strict = strict
116119

117120
@property
118121
def stage_type(self) -> str:
@@ -147,7 +150,7 @@ def run(self, artifact: PipelineArtifact) -> None:
147150
model,
148151
example_inputs[method_name][0],
149152
dynamic_shapes=method_dynamic_shapes,
150-
strict=True,
153+
strict=self.strict,
151154
)
152155

153156
# Apply pre-edge transform passes if available

0 commit comments

Comments
 (0)