|  | 
| 13 | 13 | 
 | 
| 14 | 14 | import torch | 
| 15 | 15 | import torch._export | 
| 16 |  | -from executorch.exir._serialize import _serialize_pte_binary | 
| 17 | 16 | from executorch.exir._serialize._cord import Cord | 
|  | 17 | +from executorch.exir._serialize._serialize import serialize | 
|  | 18 | +from executorch.exir._serialize.data_serializer import DataSerializer | 
| 18 | 19 | from executorch.exir._warnings import experimental | 
| 19 | 20 | from executorch.exir.backend.backend_api import to_backend | 
| 20 | 21 | from executorch.exir.backend.partitioner import Partitioner | 
|  | 
| 56 | 57 |     EXIREdgeDialectVerifier, | 
| 57 | 58 |     get_aten_verifier, | 
| 58 | 59 | ) | 
|  | 60 | +from executorch.extension.flat_tensor.serialize.serialize import FlatTensorSerializer | 
| 59 | 61 | from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass | 
| 60 | 62 | from torch.export import ExportedProgram | 
| 61 | 63 | from torch.export._remove_auto_functionalized_pass import ( | 
| @@ -494,23 +496,23 @@ def __init__( | 
| 494 | 496 |             ) | 
| 495 | 497 |         self.exported_program = exir_exported_program.exported_program | 
| 496 | 498 |         self._pte_data: Optional[Cord] = None | 
|  | 499 | +        self._data_files: Optional[Dict[str, Cord]] = None | 
| 497 | 500 |         self._buffer: Optional[bytes] = None | 
| 498 | 501 |         self._emitter_output: Optional[EmitterOutput] = None | 
| 499 | 502 |         self._emit_stacktrace: bool = emit_stacktrace | 
| 500 | 503 |         self._extract_delegate_segments: bool = extract_delegate_segments | 
| 501 | 504 |         self._segment_alignment: int = segment_alignment | 
| 502 | 505 |         self._constant_tensor_alignment: Optional[int] = constant_tensor_alignment | 
| 503 | 506 |         self._delegate_alignment: Optional[int] = delegate_alignment | 
|  | 507 | +        self._data_serializer: DataSerializer = FlatTensorSerializer() | 
| 504 | 508 | 
 | 
| 505 | 509 |     def _get_pte_data(self) -> Cord: | 
| 506 | 510 |         if self._pte_data is None: | 
| 507 |  | -            self._pte_data = _serialize_pte_binary( | 
| 508 |  | -                program=self.program, | 
| 509 |  | -                extract_delegate_segments=self._extract_delegate_segments, | 
| 510 |  | -                segment_alignment=self._segment_alignment, | 
| 511 |  | -                constant_tensor_alignment=self._constant_tensor_alignment, | 
| 512 |  | -                delegate_alignment=self._delegate_alignment, | 
|  | 511 | +            assert self._emitter_output is not None | 
|  | 512 | +            self._pte_data, self._data_files = serialize( | 
|  | 513 | +                self._emitter_output, ExecutorchBackendConfig(), self._data_serializer | 
| 513 | 514 |             ) | 
|  | 515 | +        assert self._pte_data is not None | 
| 514 | 516 |         return self._pte_data | 
| 515 | 517 | 
 | 
| 516 | 518 |     @property | 
| @@ -1443,14 +1445,11 @@ def __init__( | 
| 1443 | 1445 |             self._config_methods, | 
| 1444 | 1446 |         ) | 
| 1445 | 1447 | 
 | 
|  | 1448 | +        self._data_serializer = FlatTensorSerializer() | 
|  | 1449 | + | 
| 1446 | 1450 |         # Serialize emitter output, ready to be written to a file. | 
| 1447 |  | -        self._pte_data: Cord = _serialize_pte_binary( | 
| 1448 |  | -            program=self._emitter_output.program, | 
| 1449 |  | -            mutable_data=self._emitter_output.mutable_data, | 
| 1450 |  | -            extract_delegate_segments=backend_config.extract_delegate_segments, | 
| 1451 |  | -            segment_alignment=backend_config.segment_alignment, | 
| 1452 |  | -            constant_tensor_alignment=backend_config.constant_tensor_alignment, | 
| 1453 |  | -            delegate_alignment=backend_config.delegate_alignment, | 
|  | 1451 | +        self._pte_data, self._data_files = serialize( | 
|  | 1452 | +            self._emitter_output, ExecutorchBackendConfig(), self._data_serializer | 
| 1454 | 1453 |         ) | 
| 1455 | 1454 |         self._buffer: Optional[bytes] = None | 
| 1456 | 1455 | 
 | 
| @@ -1532,3 +1531,8 @@ def write_to_file(self, open_file: io.BufferedIOBase) -> None: | 
| 1532 | 1531 |         reducing the peak memory usage. | 
| 1533 | 1532 |         """ | 
| 1534 | 1533 |         self._pte_data.write_to_file(open_file) | 
|  | 1534 | + | 
|  | 1535 | +        for filename, cord in self._data_files.items(): | 
|  | 1536 | +            filename = filename + ".ptd" | 
|  | 1537 | +            with open(filename, "wb") as file: | 
|  | 1538 | +                cord.write_to_file(file) | 
0 commit comments