|
1 | 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. |
2 | 2 | # All rights reserved. |
| 3 | +# Copyright 2025 Arm Limited and/or its affiliates. |
3 | 4 | # |
4 | 5 | # This source code is licensed under the BSD-style license found in the |
5 | 6 | # LICENSE file in the root directory of this source tree. |
|
10 | 11 | import io |
11 | 12 | import logging |
12 | 13 | import os |
13 | | -from typing import Any, Dict, List, Optional, Sequence, Set, TextIO, Tuple, Union |
| 14 | +from typing import Any, Dict, List, Optional, Sequence, Set, TextIO, Tuple, Type, Union |
14 | 15 |
|
15 | 16 | import torch |
16 | 17 | import torch._export |
|
66 | 67 | ) |
67 | 68 | from executorch.extension.flat_tensor.serialize.serialize import FlatTensorSerializer |
68 | 69 | from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass |
| 70 | +from torch._export.verifier import Verifier |
69 | 71 | from torch.export import ExportedProgram |
70 | 72 | from torch.export._remove_auto_functionalized_pass import ( |
71 | 73 | unsafe_remove_auto_functionalized_pass, |
@@ -213,21 +215,29 @@ def _transform(self, *passes: PassType) -> "ExportedProgram": |
213 | 215 | if transformed_gm is self.graph_module and not res.modified: |
214 | 216 | return self |
215 | 217 |
|
| 218 | + return _update_exported_program_graph_module(self, transformed_gm) |
| 219 | + |
| 220 | + |
| 221 | +def _update_exported_program_graph_module( |
| 222 | + exported_program: ExportedProgram, |
| 223 | + gm: torch.fx.GraphModule, |
| 224 | + override_verifiers: None | list[Type[Verifier]] = None, |
| 225 | +) -> "ExportedProgram": |
216 | 226 | transformed_ep = ExportedProgram( |
217 | | - root=transformed_gm, |
218 | | - graph=transformed_gm.graph, |
| 227 | + root=gm, |
| 228 | + graph=gm.graph, |
219 | 229 | graph_signature=_get_updated_graph_signature( |
220 | | - self.graph_signature, transformed_gm |
| 230 | + exported_program.graph_signature, gm |
221 | 231 | ), |
222 | | - state_dict=self.state_dict, |
223 | | - range_constraints=_get_updated_range_constraints(transformed_gm), |
224 | | - module_call_graph=copy.deepcopy(self._module_call_graph), |
225 | | - example_inputs=self.example_inputs, |
226 | | - constants=self.constants, |
227 | | - verifiers=[self.verifier], |
| 232 | + state_dict=exported_program.state_dict, |
| 233 | + range_constraints=_get_updated_range_constraints(gm), |
| 234 | + module_call_graph=copy.deepcopy(exported_program._module_call_graph), |
| 235 | + example_inputs=exported_program.example_inputs, |
| 236 | + constants=exported_program.constants, |
| 237 | + verifiers=override_verifiers or [exported_program.verifier], |
228 | 238 | ) |
229 | | - transformed_ep.graph_module.meta.update(self.graph_module.meta) |
230 | | - transformed_ep.graph_module.meta.update(res.graph_module.meta) |
| 239 | + transformed_ep.graph_module.meta.update(exported_program.graph_module.meta) |
| 240 | + transformed_ep.graph_module.meta.update(gm.meta) |
231 | 241 | return transformed_ep |
232 | 242 |
|
233 | 243 |
|
|
0 commit comments