|
7 | 7 |
|
8 | 8 | import logging
|
9 | 9 |
|
10 |
| -import os |
11 | 10 | from collections import Counter
|
12 | 11 | from pprint import pformat
|
13 | 12 | from typing import (
|
|
42 | 41 | )
|
43 | 42 | from executorch.backends.arm.test.runner_utils import (
|
44 | 43 | dbg_tosa_fb_to_json,
|
45 |
| - get_elf_path, |
46 | 44 | get_output_quantization_params,
|
47 |
| - get_target_board, |
48 |
| - run_target, |
49 | 45 | TosaReferenceModelDispatch,
|
50 | 46 | )
|
51 | 47 |
|
52 | 48 | from executorch.backends.arm.test.tester.analyze_output_utils import (
|
53 | 49 | dump_error_output,
|
54 | 50 | print_error_diffs,
|
55 | 51 | )
|
| 52 | +from executorch.backends.arm.test.tester.serialize import Serialize |
56 | 53 | from executorch.backends.arm.tosa import TosaSpecification
|
57 | 54 | from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec
|
58 | 55 | from executorch.backends.arm.tosa.mapping import extract_tensor_meta
|
|
90 | 87 |
|
91 | 88 | from torch.export.graph_signature import ExportGraphSignature, InputSpec, OutputSpec
|
92 | 89 | from torch.fx import Graph
|
93 |
| -from torch.utils._pytree import tree_flatten |
94 | 90 |
|
95 | 91 |
|
96 | 92 | logger = logging.getLogger(__name__)
|
@@ -179,43 +175,6 @@ def run(
|
179 | 175 | )
|
180 | 176 |
|
181 | 177 |
|
182 |
| -class Serialize(tester.Serialize): |
183 |
| - def __init__(self, compile_spec: ArmCompileSpec, timeout): |
184 |
| - super().__init__() |
185 |
| - self.timeout = timeout |
186 |
| - self.executorch_program_manager: ExecutorchProgramManager | None |
187 |
| - self.compile_spec = compile_spec |
188 |
| - |
189 |
| - def run(self, artifact: ExecutorchProgramManager, inputs=None) -> None: |
190 |
| - super().run(artifact, inputs) |
191 |
| - # Keep the entire ExecutorchProgramManager for execution. |
192 |
| - self.executorch_program_manager = artifact |
193 |
| - |
194 |
| - def run_artifact(self, inputs): |
195 |
| - if self.executorch_program_manager is None: |
196 |
| - raise RuntimeError( |
197 |
| - "Tried running artifact from Serialize stage without running the stage." |
198 |
| - ) |
199 |
| - inputs_flattened, _ = tree_flatten(inputs) |
200 |
| - intermediate_path = self.compile_spec.get_intermediate_path() |
201 |
| - target_board = get_target_board(self.compile_spec) |
202 |
| - elf_path = get_elf_path(target_board) |
203 |
| - |
204 |
| - if not os.path.exists(elf_path): |
205 |
| - raise FileNotFoundError( |
206 |
| - f"Did not find build arm_executor_runner in path {elf_path}, run setup_testing.sh?" |
207 |
| - ) |
208 |
| - |
209 |
| - return run_target( |
210 |
| - self.executorch_program_manager, |
211 |
| - inputs_flattened, |
212 |
| - intermediate_path, |
213 |
| - target_board, |
214 |
| - elf_path, |
215 |
| - self.timeout, |
216 |
| - ) |
217 |
| - |
218 |
| - |
219 | 178 | class ToExecutorch(tester.ToExecutorch):
|
220 | 179 | def run_artifact(self, inputs):
|
221 | 180 | with TosaReferenceModelDispatch():
|
@@ -419,7 +378,11 @@ def serialize(
|
419 | 378 | self, serialize_stage: Optional[Serialize] = None, timeout: int = 480
|
420 | 379 | ):
|
421 | 380 | if serialize_stage is None:
|
422 |
| - serialize_stage = Serialize(self.compile_spec, timeout) |
| 381 | + serialize_stage = Serialize( |
| 382 | + compile_spec=self.compile_spec, |
| 383 | + module=self.original_module, |
| 384 | + timeout=timeout, |
| 385 | + ) |
423 | 386 | assert (
|
424 | 387 | self.compile_spec.get_intermediate_path() is not None
|
425 | 388 | ), "Can't dump serialized file when compile specs do not contain an artifact path."
|
|
0 commit comments