from __future__ import annotations
import io
+ import os
import pickle
from functools import lru_cache
import numpy as np
+ from io import BufferedReader
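+ # 30 is the fixed size of a zip local file header; see the reference in pytorch:
+ # https://github.com/pytorch/pytorch/blob/master/caffe2/serialize/inline_container.cc#L186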
MZ_ZIP_LOCAL_DIR_HEADER_SIZE = 30
@@ -38,6 +40,12 @@ def __repr__(self):
return f"size: { self .size } key: { self .key } , nbytes: { self .nbytes } , dtype: { self .dtype } "
+ class SerializationError(Exception):
+     """Raised when parsing the serialized weight file fails."""
+
+
@lru_cache(maxsize=None)
def _storage_type_to_dtype_to_map():
"""convert storage type to numpy dtype"""
@@ -123,6 +131,52 @@ def dumpy(*args, **kwarsg):
    return None
+ def seek_by_string(file_handler: BufferedReader, string: str, file_size: int) -> int:
+     """advance the file handler until the target string has been read
+
+     Args:
+         file_handler (BufferedReader): file handler
+         string (str): the string to search for in the file
+         file_size (int): size of the file in bytes
+
+     Returns:
+         int: the offset just past the end of the target string
+     """
+     word_index = 0
+     word_bytes = string.encode("latin")
+
+     while word_index < len(word_bytes) and file_handler.tell() < file_size:
+         content = file_handler.read(1)
+         if content == b"":
+             break
+
+         if word_bytes[word_index] == content[0]:
+             word_index += 1
+         else:
+             # mismatch: rewind to one byte past where this attempt started and retry
+             file_handler.seek(-word_index, 1)
+             word_index = 0
+
+     if word_index < len(word_bytes):
+         raise SerializationError(f"can't find the target string <{string}> in the file")
+     return file_handler.tell()
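+
+
+ # A minimal usage sketch for `seek_by_string` (assuming `path` points at a
+ # zip-format checkpoint written by `torch.save`):
+ #
+ #     with open(path, "rb") as fh:
+ #         end = seek_by_string(fh, "data.pkl", os.stat(path).st_size)
+ #         # fh is now positioned just past the first occurrence of "data.pkl"
+
+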
+ def read_prefix_key(file_handler: BufferedReader, file_size: int) -> bytes:
+     """read the prefix key in the model weight file, eg: archive/pytorch_model
+
+     Args:
+         file_handler (BufferedReader): file handler
+         file_size (int): size of the file in bytes
+
+     Returns:
+         bytes: the prefix key, eg: b"archive"
+     """
+     # the first zip entry is `<prefix_key>/data.pkl`; its name starts right
+     # after the fixed-size local file header
+     end_index = seek_by_string(file_handler, "data.pkl", file_size)
+     file_handler.seek(MZ_ZIP_LOCAL_DIR_HEADER_SIZE)
+     prefix_key = file_handler.read(end_index - MZ_ZIP_LOCAL_DIR_HEADER_SIZE - len("/data.pkl"))
+     return prefix_key
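+
+
+ # The assumed archive layout can be inspected with the standard `zipfile`
+ # module (the entry names below are an example, not guaranteed):
+ #
+ #     import zipfile
+ #     print(zipfile.ZipFile(path).namelist()[:2])
+ #     # eg: ['archive/data.pkl', 'archive/data/0']
+
+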
def load_torch(path: str, **pickle_load_args):
"""
    load torch weight file with the following steps:
@@ -142,8 +196,6 @@ def load_torch(path: str, **pickle_load_args):
    # 1. load the structure of pytorch weight file
    def persistent_load_stage1(saved_id):
        assert isinstance(saved_id, tuple)
-         print(saved_id)
-
        data = saved_id[1:]
        storage_type, key, _, numel = data
        dtype = storage_type.dtype
@@ -173,21 +225,20 @@ def extract_maybe_dict(result):
    metadata = sorted(metadata, key=lambda x: x.key)
    # 3. parse the tensor of pytorch weight file
    stage1_key_to_tensor = {}
+     content_size = os.stat(path).st_size
    with open(path, "rb") as file_handler:
+         prefix_key = read_prefix_key(file_handler, content_size).decode("latin")
        file_handler.seek(pre_offset)
+
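+         # walk the entries in sorted key order; each tensor's payload is found by
+         # searching for its entry name instead of seeking by fixed offsets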
        for tensor_meta in metadata:
            key = tensor_meta.key
            # eg: archive/data/1FB
-             filename_with_fb = len(f"archive/data/{key}") + 2
-
-             # skip the fix position to read tensor data
-             # `MZ_ZIP_LOCAL_DIR_HEADER_SIZE` is from: https://github.com/pytorch/pytorch/blob/master/caffe2/serialize/inline_container.cc#L186
-             # `16` is the fixed characters size from binary file.
-             # `filename_with_fb` is the length of dynamic data key name
-             file_handler.seek(MZ_ZIP_LOCAL_DIR_HEADER_SIZE + 16 + filename_with_fb, 1)
+             filename = f"{prefix_key}/data/{key}"
+             seek_by_string(file_handler, filename, content_size)
+             # skip the 2-byte "FB" marker that follows the entry name (see the "1FB" example above)
+             file_handler.seek(2, 1)
            padding_offset = np.frombuffer(file_handler.read(2)[:1], dtype=np.uint8)[0]
-             file_handler.read(padding_offset)
+             file_handler.seek(padding_offset, 1)
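+             # the handler now points at the raw tensor bytes; assumed entry layout,
+             # inferred from the "FB" marker above:
+             #   [local header][filename]["FB"][2-byte pad length][padding][tensor data]
+             # note: `read(2)[:1]` keeps the low byte only, so padding is assumed < 256 bytes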
            # save the tensor info in result to re-use memory
            stage1_key_to_tensor[key] = np.frombuffer(