99import  copy 
1010import  io 
1111import  logging 
12+ import  os 
1213from  typing  import  Any , Dict , List , Optional , Sequence , Set , TextIO , Tuple , Union 
1314
1415import  torch 
1516import  torch ._export 
16- from  executorch .exir ._serialize  import  _serialize_pte_binary 
1717from  executorch .exir ._serialize ._cord  import  Cord 
18+ from  executorch .exir ._serialize ._serialize  import  serialize_for_executorch 
19+ from  executorch .exir ._serialize .data_serializer  import  DataSerializer 
1820from  executorch .exir ._warnings  import  experimental 
1921from  executorch .exir .backend .backend_api  import  to_backend 
2022from  executorch .exir .backend .partitioner  import  Partitioner 
5961    EXIREdgeDialectVerifier ,
6062    get_aten_verifier ,
6163)
64+ from  executorch .extension .flat_tensor .serialize .serialize  import  FlatTensorSerializer 
6265from  torch ._export .passes  import  ReplaceViewOpsWithViewCopyOpsPass 
6366from  torch .export  import  ExportedProgram 
6467from  torch .export ._remove_auto_functionalized_pass  import  (
@@ -497,23 +500,31 @@ def __init__(
497500            )
498501        self .exported_program  =  exir_exported_program .exported_program 
499502        self ._pte_data : Optional [Cord ] =  None 
503+         self ._tensor_data : Optional [Dict [str , Cord ]] =  None 
500504        self ._buffer : Optional [bytes ] =  None 
501505        self ._emitter_output : Optional [EmitterOutput ] =  None 
502506        self ._emit_stacktrace : bool  =  emit_stacktrace 
503507        self ._extract_delegate_segments : bool  =  extract_delegate_segments 
504508        self ._segment_alignment : int  =  segment_alignment 
505509        self ._constant_tensor_alignment : Optional [int ] =  constant_tensor_alignment 
506510        self ._delegate_alignment : Optional [int ] =  delegate_alignment 
511+         self ._data_serializer : DataSerializer  =  FlatTensorSerializer ()
512+ 
513+     def  _get_emitter_output (self ) ->  EmitterOutput :
514+         if  self ._emitter_output  is  None :
515+             self ._emitter_output  =  emit_program (
516+                 self .exported_program , self ._emit_stacktrace 
517+             )
518+         return  self ._emitter_output 
507519
508520    def  _get_pte_data (self ) ->  Cord :
509521        if  self ._pte_data  is  None :
510-             self ._pte_data  =  _serialize_pte_binary (
511-                 program = self .program ,
512-                 extract_delegate_segments = self ._extract_delegate_segments ,
513-                 segment_alignment = self ._segment_alignment ,
514-                 constant_tensor_alignment = self ._constant_tensor_alignment ,
515-                 delegate_alignment = self ._delegate_alignment ,
522+             self ._pte_data , self ._tensor_data  =  serialize_for_executorch (
523+                 self ._get_emitter_output (),
524+                 ExecutorchBackendConfig (),
525+                 self ._data_serializer ,
516526            )
527+         assert  self ._pte_data  is  not None 
517528        return  self ._pte_data 
518529
519530    @property  
@@ -532,11 +543,7 @@ def buffer(self) -> bytes:
532543
533544    @property  
534545    def  program (self ) ->  Program :
535-         if  self ._emitter_output  is  None :
536-             self ._emitter_output  =  emit_program (
537-                 self .exported_program , self ._emit_stacktrace 
538-             )
539-         return  self ._emitter_output .program 
546+         return  self ._get_emitter_output ().program 
540547
541548    @property  
542549    def  debug_handle_map (self ) ->  Dict [int , Union [int , List [int ]]]:
@@ -571,6 +578,16 @@ def write_to_file(self, open_file: io.BufferedIOBase) -> None:
571578        """ 
572579        self ._get_pte_data ().write_to_file (open_file )
573580
581+     def  write_tensor_data_to_file (self , outdir ) ->  None :
582+         """ 
583+         Writes the serialized ExecuTorch data files to the directory at `outdir`. 
584+         """ 
585+         assert  self ._tensor_data  is  not None 
586+         for  filename , cord  in  self ._tensor_data .items ():
587+             with  open (os .path .join (outdir , f"{ filename }  ), "wb" ) as  f :
588+                 logging .info (f"Writing data file to { filename }  )
589+                 cord .write_to_file (f )
590+ 
574591
575592def  _get_aten_to_edge_passes (config : EdgeCompileConfig ):
576593    # TODO: the last two passes for aten_to_edge need to be eliminated_dead_code -> debug_handle_generator. After enable 
@@ -1453,13 +1470,9 @@ def __init__(
14531470        )
14541471
14551472        # Serialize emitter output, ready to be written to a file. 
1456-         self ._pte_data : Cord  =  _serialize_pte_binary (
1457-             program = self ._emitter_output .program ,
1458-             mutable_data = self ._emitter_output .mutable_data ,
1459-             extract_delegate_segments = backend_config .extract_delegate_segments ,
1460-             segment_alignment = backend_config .segment_alignment ,
1461-             constant_tensor_alignment = backend_config .constant_tensor_alignment ,
1462-             delegate_alignment = backend_config .delegate_alignment ,
1473+         self ._data_serializer  =  FlatTensorSerializer ()
1474+         self ._pte_data , self ._tensor_data  =  serialize_for_executorch (
1475+             self ._emitter_output , ExecutorchBackendConfig (), self ._data_serializer 
14631476        )
14641477        self ._buffer : Optional [bytes ] =  None 
14651478
@@ -1542,6 +1555,16 @@ def write_to_file(self, open_file: io.BufferedIOBase) -> None:
15421555        """ 
15431556        self ._pte_data .write_to_file (open_file )
15441557
1558+     def  write_tensor_data_to_file (self , outdir ) ->  None :
1559+         """ 
1560+         Writes the serialized ExecuTorch data files to the directory at `outdir`. 
1561+         """ 
1562+         assert  self ._tensor_data  is  not None 
1563+         for  filename , cord  in  self ._tensor_data .items ():
1564+             with  open (os .path .join (outdir , f"{ filename }  ), "wb" ) as  f :
1565+                 logging .info (f"Writing data file to { filename }  )
1566+                 cord .write_to_file (f )
1567+ 
15451568    def  save (self , path : str ) ->  None :
15461569        """ 
15471570        Saves the serialized ExecuTorch binary to the file at `path`. 
0 commit comments