99import copy
1010import io
1111import logging
12+ import os
1213from typing import Any , Dict , List , Optional , Sequence , Set , TextIO , Tuple , Union
1314
1415import torch
1516import torch ._export
16- from executorch .exir ._serialize import _serialize_pte_binary
1717from executorch .exir ._serialize ._cord import Cord
18+ from executorch .exir ._serialize ._serialize import serialize_for_executorch
19+ from executorch .exir ._serialize .data_serializer import DataSerializer
1820from executorch .exir ._warnings import experimental
1921from executorch .exir .backend .backend_api import to_backend
2022from executorch .exir .backend .partitioner import Partitioner
5961 EXIREdgeDialectVerifier ,
6062 get_aten_verifier ,
6163)
64+ from executorch .extension .flat_tensor .serialize .serialize import FlatTensorSerializer
6265from torch ._export .passes import ReplaceViewOpsWithViewCopyOpsPass
6366from torch .export import ExportedProgram
6467from torch .export ._remove_auto_functionalized_pass import (
@@ -497,23 +500,31 @@ def __init__(
497500 )
498501 self .exported_program = exir_exported_program .exported_program
499502 self ._pte_data : Optional [Cord ] = None
503+ self ._tensor_data : Optional [Dict [str , Cord ]] = None
500504 self ._buffer : Optional [bytes ] = None
501505 self ._emitter_output : Optional [EmitterOutput ] = None
502506 self ._emit_stacktrace : bool = emit_stacktrace
503507 self ._extract_delegate_segments : bool = extract_delegate_segments
504508 self ._segment_alignment : int = segment_alignment
505509 self ._constant_tensor_alignment : Optional [int ] = constant_tensor_alignment
506510 self ._delegate_alignment : Optional [int ] = delegate_alignment
511+ self ._data_serializer : DataSerializer = FlatTensorSerializer ()
512+
513+ def _get_emitter_output (self ) -> EmitterOutput :
514+ if self ._emitter_output is None :
515+ self ._emitter_output = emit_program (
516+ self .exported_program , self ._emit_stacktrace
517+ )
518+ return self ._emitter_output
507519
508520 def _get_pte_data (self ) -> Cord :
509521 if self ._pte_data is None :
510- self ._pte_data = _serialize_pte_binary (
511- program = self .program ,
512- extract_delegate_segments = self ._extract_delegate_segments ,
513- segment_alignment = self ._segment_alignment ,
514- constant_tensor_alignment = self ._constant_tensor_alignment ,
515- delegate_alignment = self ._delegate_alignment ,
522+ self ._pte_data , self ._tensor_data = serialize_for_executorch (
523+ self ._get_emitter_output (),
524+ ExecutorchBackendConfig (),
525+ self ._data_serializer ,
516526 )
527+ assert self ._pte_data is not None
517528 return self ._pte_data
518529
519530 @property
@@ -532,11 +543,7 @@ def buffer(self) -> bytes:
532543
533544 @property
534545 def program (self ) -> Program :
535- if self ._emitter_output is None :
536- self ._emitter_output = emit_program (
537- self .exported_program , self ._emit_stacktrace
538- )
539- return self ._emitter_output .program
546+ return self ._get_emitter_output ().program
540547
541548 @property
542549 def debug_handle_map (self ) -> Dict [int , Union [int , List [int ]]]:
@@ -571,6 +578,17 @@ def write_to_file(self, open_file: io.BufferedIOBase) -> None:
571578 """
572579 self ._get_pte_data ().write_to_file (open_file )
573580
581+ def write_tensor_data_to_file (self , outdir ) -> None :
582+ """
583+ Writes the serialized ExecuTorch data files to the directory at `outdir`.
584+ """
585+ assert self ._tensor_data is not None
586+ # pyre-ignore[16]: `Optional` has no attribute `items`.
587+ for filename , cord in self ._tensor_data .items ():
588+ with open (os .path .join (outdir , f"{ filename } .ptd" ), "wb" ) as f :
589+ logging .info (f"Writing data file to { filename } .ptd" )
590+ cord .write_to_file (f )
591+
574592
575593def _get_aten_to_edge_passes (config : EdgeCompileConfig ):
576594 # TODO: the last two passes for aten_to_edge need to be eliminated_dead_code -> debug_handle_generator. After enable
@@ -1453,13 +1471,9 @@ def __init__(
14531471 )
14541472
14551473 # Serialize emitter output, ready to be written to a file.
1456- self ._pte_data : Cord = _serialize_pte_binary (
1457- program = self ._emitter_output .program ,
1458- mutable_data = self ._emitter_output .mutable_data ,
1459- extract_delegate_segments = backend_config .extract_delegate_segments ,
1460- segment_alignment = backend_config .segment_alignment ,
1461- constant_tensor_alignment = backend_config .constant_tensor_alignment ,
1462- delegate_alignment = backend_config .delegate_alignment ,
1474+ self ._data_serializer = FlatTensorSerializer ()
1475+ self ._pte_data , self ._tensor_data = serialize_for_executorch (
1476+ self ._emitter_output , ExecutorchBackendConfig (), self ._data_serializer
14631477 )
14641478 self ._buffer : Optional [bytes ] = None
14651479
@@ -1542,6 +1556,16 @@ def write_to_file(self, open_file: io.BufferedIOBase) -> None:
15421556 """
15431557 self ._pte_data .write_to_file (open_file )
15441558
1559+ def write_tensor_data_to_file (self , outdir ) -> None :
1560+ """
1561+ Writes the serialized ExecuTorch data files to the directory at `outdir`.
1562+ """
1563+ assert self ._tensor_data is not None
1564+ for filename , cord in self ._tensor_data .items ():
1565+ with open (os .path .join (outdir , f"{ filename } .ptd" ), "wb" ) as f :
1566+ logging .info (f"Writing data file to { filename } " )
1567+ cord .write_to_file (f )
1568+
15451569 def save (self , path : str ) -> None :
15461570 """
15471571 Saves the serialized ExecuTorch binary to the file at `path`.
0 commit comments