| 
13 | 13 | 
 
  | 
14 | 14 | import torch  | 
15 | 15 | import torch._export  | 
16 |  | -from executorch.exir._serialize import _serialize_pte_binary  | 
17 | 16 | from executorch.exir._serialize._cord import Cord  | 
 | 17 | +from executorch.exir._serialize._serialize import serialize  | 
 | 18 | +from executorch.exir._serialize.data_serializer import DataSerializer  | 
18 | 19 | from executorch.exir._warnings import experimental  | 
19 | 20 | from executorch.exir.backend.backend_api import to_backend  | 
20 | 21 | from executorch.exir.backend.partitioner import Partitioner  | 
 | 
56 | 57 |     EXIREdgeDialectVerifier,  | 
57 | 58 |     get_aten_verifier,  | 
58 | 59 | )  | 
 | 60 | +from executorch.extension.flat_tensor.serialize.serialize import FlatTensorSerializer  | 
59 | 61 | from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass  | 
60 | 62 | from torch.export import ExportedProgram  | 
61 | 63 | from torch.export._remove_auto_functionalized_pass import (  | 
@@ -494,23 +496,23 @@ def __init__(  | 
494 | 496 |             )  | 
495 | 497 |         self.exported_program = exir_exported_program.exported_program  | 
496 | 498 |         self._pte_data: Optional[Cord] = None  | 
 | 499 | +        self._data_files: Optional[Dict[str, Cord]] = None  | 
497 | 500 |         self._buffer: Optional[bytes] = None  | 
498 | 501 |         self._emitter_output: Optional[EmitterOutput] = None  | 
499 | 502 |         self._emit_stacktrace: bool = emit_stacktrace  | 
500 | 503 |         self._extract_delegate_segments: bool = extract_delegate_segments  | 
501 | 504 |         self._segment_alignment: int = segment_alignment  | 
502 | 505 |         self._constant_tensor_alignment: Optional[int] = constant_tensor_alignment  | 
503 | 506 |         self._delegate_alignment: Optional[int] = delegate_alignment  | 
 | 507 | +        self._data_serializer: DataSerializer = FlatTensorSerializer()  | 
504 | 508 | 
 
  | 
505 | 509 |     def _get_pte_data(self) -> Cord:  | 
506 | 510 |         if self._pte_data is None:  | 
507 |  | -            self._pte_data = _serialize_pte_binary(  | 
508 |  | -                program=self.program,  | 
509 |  | -                extract_delegate_segments=self._extract_delegate_segments,  | 
510 |  | -                segment_alignment=self._segment_alignment,  | 
511 |  | -                constant_tensor_alignment=self._constant_tensor_alignment,  | 
512 |  | -                delegate_alignment=self._delegate_alignment,  | 
 | 511 | +            assert self._emitter_output is not None  | 
 | 512 | +            self._pte_data, self._data_files = serialize(  | 
 | 513 | +                self._emitter_output, ExecutorchBackendConfig(), self._data_serializer  | 
513 | 514 |             )  | 
 | 515 | +        assert self._pte_data is not None  | 
514 | 516 |         return self._pte_data  | 
515 | 517 | 
 
  | 
516 | 518 |     @property  | 
@@ -1443,14 +1445,11 @@ def __init__(  | 
1443 | 1445 |             self._config_methods,  | 
1444 | 1446 |         )  | 
1445 | 1447 | 
 
  | 
 | 1448 | +        self._data_serializer = FlatTensorSerializer()  | 
 | 1449 | + | 
1446 | 1450 |         # Serialize emitter output, ready to be written to a file.  | 
1447 |  | -        self._pte_data: Cord = _serialize_pte_binary(  | 
1448 |  | -            program=self._emitter_output.program,  | 
1449 |  | -            mutable_data=self._emitter_output.mutable_data,  | 
1450 |  | -            extract_delegate_segments=backend_config.extract_delegate_segments,  | 
1451 |  | -            segment_alignment=backend_config.segment_alignment,  | 
1452 |  | -            constant_tensor_alignment=backend_config.constant_tensor_alignment,  | 
1453 |  | -            delegate_alignment=backend_config.delegate_alignment,  | 
 | 1451 | +        self._pte_data, self._data_files = serialize(  | 
 | 1452 | +            self._emitter_output, ExecutorchBackendConfig(), self._data_serializer  | 
1454 | 1453 |         )  | 
1455 | 1454 |         self._buffer: Optional[bytes] = None  | 
1456 | 1455 | 
 
  | 
@@ -1532,3 +1531,8 @@ def write_to_file(self, open_file: io.BufferedIOBase) -> None:  | 
1532 | 1531 |         reducing the peak memory usage.  | 
1533 | 1532 |         """  | 
1534 | 1533 |         self._pte_data.write_to_file(open_file)  | 
 | 1534 | + | 
 | 1535 | +        for filename, cord in self._data_files.items():  | 
 | 1536 | +            filename = filename + ".ptd"  | 
 | 1537 | +            with open(filename, "wb") as file:  | 
 | 1538 | +                cord.write_to_file(file)  | 
0 commit comments