77# pyre-strict
88
99import hashlib
10- import math
1110from dataclasses import dataclass
1211
1312# from dataclasses import dataclass
14- from typing import Dict , List , Optional
13+ from typing import Dict , Optional , Sequence
1514
16-
17- @dataclass
18- class BufferEntry :
19- """A class to hold the buffer entries for serialization.
20-
21- Attributes:
22- buffer: The buffer bytes.
23- alignment: The alignment of the buffer.
24- """
25-
26- buffer : bytes
27- alignment : int
15+ from executorch .exir ._serialize .data_serializer import DataEntry
16+ from executorch .exir .tensor_layout import TensorLayout
2817
2918
3019@dataclass
3120class NamedDataStoreOutput :
3221 """
33- Holds named data for serialization.
22+ Holds named data for serialization. Note: a DataEntry contains the index into
23+ `buffers`, the alignment and a tensor layout, if applicable.
3424
3525 Attributes:
3626 buffers: A list of unique buffer entries.
3727 pte_data: Contains data that is stored inside the PTE file. A mapping from
38- {key: buffer_index }.
28+ {key: DataEntry }.
3929 external_data: Contains data that is stored external to the PTE. A mapping
40- from {filename: {key: buffer_index }}.
30+ from {filename: {key: DataEntry }}.
4131 """
4232
43- buffers : List [ BufferEntry ]
44- pte_data : Dict [str , int ]
45- external_data : Dict [str , Dict [str , int ]]
33+ buffers : Sequence [ bytes ]
34+ pte_data : Dict [str , DataEntry ]
35+ external_data : Dict [str , Dict [str , DataEntry ]]
4636
4737
4838class NamedDataStore :
@@ -61,12 +51,12 @@ class NamedDataStore:
6151 """
6252
6353 # List of unique blobs.
64- buffers : List [ BufferEntry ]
65- # Named data stored inside the PTE file. Map of {key: buffer_index }.
66- pte_data : Dict [str , int ]
54+ buffers : Sequence [ bytes ]
55+ # Named data stored inside the PTE file. Map of {key: DataEntry }.
56+ pte_data : Dict [str , DataEntry ]
6757 # Named data stored outside of the PTE file.
68- # Map of {filename: {key: buffer_index }}.
69- external_data : Dict [str , Dict [str , int ]]
58+ # Map of {filename: {key: DataEntry }}.
59+ external_data : Dict [str , Dict [str , DataEntry ]]
7060
7161 # Cache of the data hash for deduplication.
7262 # Use a hash instead of the data as a key because a sha256 collision is
@@ -93,7 +83,8 @@ def _add_named_data_to_map(
9383 key : str ,
9484 data : bytes ,
9585 alignment : int ,
96- local_key_to_buffer_idx : Dict [str , int ],
86+ local_key_to_buffer_idx : Dict [str , DataEntry ],
87+ tensor_layout : Optional [TensorLayout ] = None ,
9788 ) -> None :
9889 """
9990 Add data to a map and update the alignment. Ensure that the key-data
@@ -116,33 +107,31 @@ def _add_named_data_to_map(
116107
117108 # Check if the key exists.
118109 buffer_idx = self .key_to_buffer_idx .get (key , - 1 )
119- if buffer_idx != - 1 :
120- # If the key exists, the corresponding data must be identical.
121- if self .data_hash_to_buffer_idx .get (hashed , - 1 ) != buffer_idx :
122- raise ValueError (
123- f"Duplicate key { key } with different data. "
124- f"Existing data: { self .buffers [buffer_idx ].buffer } . "
125- f"New data: { data } ."
126- )
127- self .buffers [buffer_idx ].alignment = math .lcm (
128- self .buffers [buffer_idx ].alignment , alignment
110+ # If the key exists, the corresponding data must be identical.
111+ if (
112+ buffer_idx != - 1
113+ and self .data_hash_to_buffer_idx .get (hashed , - 1 ) != buffer_idx
114+ ):
115+ raise ValueError (
116+ f"Duplicate key { key } with different data. "
117+ f"Existing data: { self .buffers [buffer_idx ]} . "
118+ f"New data: { data } ."
129119 )
130120 else :
131121 # Key doesn't exist; check if the data exists.
132122 buffer_idx = self .data_hash_to_buffer_idx .get (hashed , - 1 )
133- if buffer_idx != - 1 :
134- # The data exists; update the alignment.
135- self .buffers [buffer_idx ].alignment = math .lcm (
136- self .buffers [buffer_idx ].alignment , alignment
137- )
138- else :
123+ if buffer_idx == - 1 :
139124 # The data doesn't exist; add it to the data store.
140125 buffer_idx = len (self .buffers )
141- self .buffers .append (BufferEntry ( data , alignment ) )
126+ self .buffers .append (data )
142127 self .data_hash_to_buffer_idx [hashed ] = buffer_idx
143128
144129 # Add key to the map and the key cache.
145- local_key_to_buffer_idx [key ] = buffer_idx
130+ local_key_to_buffer_idx [key ] = DataEntry (
131+ buffer_index = buffer_idx ,
132+ alignment = alignment ,
133+ tensor_layout = tensor_layout ,
134+ )
146135 self .key_to_buffer_idx [key ] = buffer_idx
147136
148137 def add_named_data (
@@ -151,6 +140,7 @@ def add_named_data(
151140 data : bytes ,
152141 alignment : Optional [int ] = 1 ,
153142 external_tag : Optional [str ] = None ,
143+ tensor_layout : Optional [TensorLayout ] = None ,
154144 ) -> None :
155145 """
156146 Adds a named blob to the NamedDataStore.
@@ -159,6 +149,7 @@ def add_named_data(
159149 data (bytes): Bytes being requested to be serialized.
160150 alignment (int): alignment for bytes to be serialized with.
161151 external (Optional[str]): the external filename that this data is saved to.
152+ tensor_layout (Optional[TensorLayout]): layout of the tensor, if applicable.
162153 Raises:
163154 ValueError: when the key exists in the store, and corresponding data
164155 is different.
@@ -171,10 +162,16 @@ def add_named_data(
171162 raise ValueError (f"Alignment must be greater than 0, received { alignment } ." )
172163
173164 if external_tag is None :
174- self ._add_named_data_to_map (key , data , alignment , self .pte_data )
165+ self ._add_named_data_to_map (
166+ key , data , alignment , self .pte_data , tensor_layout
167+ )
175168 else :
176169 self ._add_named_data_to_map (
177- key , data , alignment , self .external_data .setdefault (external_tag , {})
170+ key ,
171+ data ,
172+ alignment ,
173+ self .external_data .setdefault (external_tag , {}),
174+ tensor_layout ,
178175 )
179176
180177 def get_named_data_store_output (self ) -> NamedDataStoreOutput :
@@ -192,19 +189,22 @@ def merge_named_data_store(self, other: NamedDataStoreOutput) -> None:
192189 data is different between them.
193190 """
194191 # Merge the pte_data.
195- for key , buffer_idx in other .pte_data .items ():
192+ for key , data_entry in other .pte_data .items ():
196193 self .add_named_data (
197194 key ,
198- other .buffers [buffer_idx ].buffer ,
199- other .buffers [buffer_idx ].alignment ,
195+ other .buffers [data_entry .buffer_index ],
196+ data_entry .alignment ,
197+ external_tag = None ,
198+ tensor_layout = data_entry .tensor_layout ,
200199 )
201200
202201 # Merge the external_data.
203- for filename , key_to_buffer_idx in other .external_data .items ():
204- for key , buffer_idx in key_to_buffer_idx .items ():
202+ for filename , key_to_data_entry in other .external_data .items ():
203+ for key , data_entry in key_to_data_entry .items ():
205204 self .add_named_data (
206205 key ,
207- other .buffers [buffer_idx ]. buffer ,
208- other . buffers [ buffer_idx ] .alignment ,
206+ other .buffers [data_entry . buffer_index ] ,
207+ data_entry .alignment ,
209208 external_tag = filename ,
209+ tensor_layout = data_entry .tensor_layout ,
210210 )
0 commit comments