Skip to content

Commit 706971b

Browse files
wj-Mcatsijunhe
and authored
[Bug fixes] fix load-torch with different prefix key (#4141)
* fix load-torch scripts * fix tensor reading * add serialization error for seek_by_string Co-authored-by: Sijun He <[email protected]>
1 parent dd792d6 commit 706971b

File tree

2 files changed

+71
-14
lines changed

2 files changed

+71
-14
lines changed

paddlenlp/utils/serialization.py

Lines changed: 61 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@
1515
from __future__ import annotations
1616

1717
import io
18+
import os
1819
import pickle
1920
from functools import lru_cache
2021

2122
import numpy as np
23+
from _io import BufferedReader
2224

2325
MZ_ZIP_LOCAL_DIR_HEADER_SIZE = 30
2426

@@ -38,6 +40,12 @@ def __repr__(self):
3840
return f"size: {self.size} key: {self.key}, nbytes: {self.nbytes}, dtype: {self.dtype}"
3941

4042

43+
class SerializationError(Exception):
    """Raised when a torch weight file cannot be parsed as expected."""
47+
48+
4149
@lru_cache(maxsize=None)
4250
def _storage_type_to_dtype_to_map():
4351
"""convert storage type to numpy dtype"""
@@ -123,6 +131,52 @@ def dumpy(*args, **kwarsg):
123131
return None
124132

125133

134+
def seek_by_string(file_handler: BufferedReader, string: str, file_size: int) -> int:
    """Advance the file handler until ``string`` has been read, byte by byte.

    Args:
        file_handler (BufferedReader): open binary handle, positioned where the
            scan should start
        string (str): the target string to search for in the file
        file_size (int): total size of the file in bytes

    Returns:
        int: offset of the byte immediately after the end of the target string

    Raises:
        SerializationError: if the target string is not found before EOF
    """
    target = string.encode("latin")
    if not target:
        # an empty target trivially matches at the current position
        return file_handler.tell()

    # Keep a sliding window of the last len(target) bytes seen. A plain
    # "reset the match counter on mismatch" scan misses overlapping
    # matches (e.g. "aab" inside "aaab") and cannot tell a completed match
    # at EOF from a failed one.
    window = b""

    while file_handler.tell() < file_size:
        content = file_handler.read(1)
        if not content:
            break
        window = (window + content)[-len(target):]
        if window == target:
            return file_handler.tell()

    raise SerializationError(f"can't find the target string<{string}> in the file")
162+
163+
164+
def read_prefix_key(file_handler: BufferedReader, file_size: int):
    """Read the archive prefix key from a torch weight file, eg: ``archive``
    in the entry name ``archive/data.pkl``.

    Args:
        file_handler (BufferedReader): open binary handle on the weight file
        file_size (int): total size of the file in bytes

    Returns:
        bytes: the raw prefix key bytes read from the zip local-file header
    """
    pkl_end = seek_by_string(file_handler, "data.pkl", file_size)
    # The entry filename starts right after the fixed-size zip local header;
    # everything before the trailing "/data.pkl" is the prefix key.
    file_handler.seek(MZ_ZIP_LOCAL_DIR_HEADER_SIZE)
    prefix_length = pkl_end - MZ_ZIP_LOCAL_DIR_HEADER_SIZE - len("/data.pkl")
    return file_handler.read(prefix_length)
178+
179+
126180
def load_torch(path: str, **pickle_load_args):
127181
"""
128182
load torch weight file with the following steps:
@@ -142,8 +196,6 @@ def load_torch(path: str, **pickle_load_args):
142196
# 1. load the structure of pytorch weight file
143197
def persistent_load_stage1(saved_id):
144198
assert isinstance(saved_id, tuple)
145-
print(saved_id)
146-
147199
data = saved_id[1:]
148200
storage_type, key, _, numel = data
149201
dtype = storage_type.dtype
@@ -173,21 +225,20 @@ def extract_maybe_dict(result):
173225
metadata = sorted(metadata, key=lambda x: x.key)
174226
# 3. parse the tensor of pytorch weight file
175227
stage1_key_to_tensor = {}
228+
content_size = os.stat(path).st_size
176229
with open(path, "rb") as file_handler:
230+
prefix_key = read_prefix_key(file_handler, content_size).decode("latin")
177231
file_handler.seek(pre_offset)
232+
178233
for tensor_meta in metadata:
179234
key = tensor_meta.key
180235
# eg: archive/data/1FB
181-
filename_with_fb = len(f"archive/data/{key}") + 2
182-
183-
# skip the fix position to read tensor data
184-
# `MZ_ZIP_LOCAL_DIR_HEADER_SIZE` is from: https://github.com/pytorch/pytorch/blob/master/caffe2/serialize/inline_container.cc#L186
185-
# `16` is the fixed characters size from binary file.
186-
# `filename_with_fb` is the length of dynamic data key name
187-
file_handler.seek(MZ_ZIP_LOCAL_DIR_HEADER_SIZE + 16 + filename_with_fb, 1)
236+
filename = f"{prefix_key}/data/{key}"
237+
seek_by_string(file_handler, filename, content_size)
238+
file_handler.seek(2, 1)
188239

189240
padding_offset = np.frombuffer(file_handler.read(2)[:1], dtype=np.uint8)[0]
190-
file_handler.read(padding_offset)
241+
file_handler.seek(padding_offset, 1)
191242

192243
# save the tensor info in result to re-use memory
193244
stage1_key_to_tensor[key] = np.frombuffer(

tests/utils/test_serialization.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from parameterized import parameterized
2222

2323
from paddlenlp.utils import load_torch
24-
from tests.testing_utils import require_package, slow
24+
from tests.testing_utils import require_package
2525

2626

2727
class SerializationTest(TestCase):
@@ -54,14 +54,20 @@ def test_simple_load(self, dtype: str):
5454
torch_data[key].numpy(),
5555
)
5656

57-
@slow
5857
@require_package("torch")
59-
def test_load_bert_model(self):
58+
@parameterized.expand(
59+
[
60+
"hf-internal-testing/tiny-random-codegen",
61+
"hf-internal-testing/tiny-random-Data2VecTextModel",
62+
"hf-internal-testing/tiny-random-SwinModel",
63+
]
64+
)
65+
def test_load_bert_model(self, repo_id):
6066
import torch
6167

6268
with tempfile.TemporaryDirectory() as tempdir:
6369
weight_file = hf_hub_download(
64-
repo_id="hf-internal-testing/tiny-random-codegen",
70+
repo_id=repo_id,
6571
filename="pytorch_model.bin",
6672
cache_dir=tempdir,
6773
library_name="PaddleNLP",

0 commit comments

Comments
 (0)