Skip to content

Commit 206672f

Browse files
committed
refactor
Signed-off-by: Isotr0py <[email protected]>
1 parent 814f795 commit 206672f

File tree

1 file changed

+44
-21
lines changed

1 file changed

+44
-21
lines changed

gguf-py/gguf/gguf_reader.py

Lines changed: 44 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import logging
88
import os
99
import sys
10+
import struct
1011
from collections import OrderedDict
1112
from typing import Any, Literal, NamedTuple, TypeVar, Union
1213

@@ -130,11 +131,15 @@ class GGUFReader:
130131
}
131132

132133
def __init__(self, path: os.PathLike[str] | str, mode: Literal['r', 'r+', 'c'] = 'r'):
133-
self.data = np.memmap(path, mode = mode)
134+
file_mode = "rb+" if mode == 'r+' else 'rb'
135+
self.mode = mode
136+
self.data = open(path, mode=file_mode)
137+
self.mmap = np.memmap(self.data, mode = mode)
134138
offs = 0
135139

136140
# Check for GGUF magic
137-
if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC:
141+
self.data.seek(offs)
142+
if struct.unpack("<I", self.data.read(4))[0] != GGUF_MAGIC:
138143
raise ValueError('GGUF magic invalid')
139144
offs += 4
140145

@@ -192,13 +197,22 @@ def get_tensor(self, idx: int) -> ReaderTensor:
192197
return self.tensors[idx]
193198

194199
def _get(
195-
self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I', 'S', '<'] = None,
200+
self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I', 'S', '<'] = None, use_mmap: bool = False
196201
) -> npt.NDArray[Any]:
197202
count = int(count)
198-
itemsize = int(np.empty([], dtype = dtype).itemsize)
203+
dtype = np.dtype(dtype).newbyteorder(override_order or self.byte_order)
204+
itemsize = dtype.itemsize
199205
end_offs = offset + itemsize * count
200-
arr = self.data[offset:end_offs].view(dtype=dtype)[:count]
201-
return arr.view(arr.dtype.newbyteorder(self.byte_order if override_order is None else override_order))
206+
if self.mode != "r" or use_mmap:
207+
data = (
208+
self.mmap[offset:end_offs]
209+
.view(dtype)[:count]
210+
)
211+
self.data.seek(end_offs)
212+
else:
213+
self.data.seek(offset)
214+
data = np.frombuffer(self.data.read(itemsize * count), dtype = dtype)
215+
return data
202216

203217
def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int:
204218
if field.name in self.fields:
@@ -212,8 +226,17 @@ def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int:
212226
return 0 if skip_sum else sum(int(part.nbytes) for part in field.parts)
213227

214228
def _get_str(self, offset: int) -> tuple[npt.NDArray[np.uint64], npt.NDArray[np.uint8]]:
215-
slen = self._get(offset, np.uint64)
216-
return slen, self._get(offset + 8, np.uint8, slen[0])
229+
if self.mode != "r":
230+
slen = self._get(offset, np.uint64)
231+
sdata = self._get(offset + 8, np.uint8, slen.item())
232+
else:
233+
# This is faster to return a read-only str structure with less seek calling.
234+
self.data.seek(offset)
235+
u64 = np.dtype(np.uint64).newbyteorder(self.byte_order)
236+
u8 = np.dtype(np.uint8).newbyteorder(self.byte_order)
237+
slen = np.frombuffer(self.data.read(8), dtype=u64)
238+
sdata = np.frombuffer(self.data.read(slen.item()), dtype=u8)
239+
return slen, sdata
217240

218241
def _get_field_parts(
219242
self, orig_offs: int, raw_type: int,
@@ -224,8 +247,8 @@ def _get_field_parts(
224247
types.append(gtype)
225248
# Handle strings.
226249
if gtype == GGUFValueType.STRING:
227-
sparts: list[npt.NDArray[Any]] = list(self._get_str(offs))
228-
size = sum(int(part.nbytes) for part in sparts)
250+
sparts: list[npt.NDArray[Any]] = self._get_str(offs)
251+
size = 8 + sparts[0].item()
229252
return size, sparts, [1], types
230253
# Check if it's a simple scalar type.
231254
nptype = self.gguf_scalar_to_np.get(gtype)
@@ -235,9 +258,9 @@ def _get_field_parts(
235258
# Handle arrays.
236259
if gtype == GGUFValueType.ARRAY:
237260
raw_itype = self._get(offs, np.uint32)
238-
offs += int(raw_itype.nbytes)
261+
offs = self.data.tell()
239262
alen = self._get(offs, np.uint64)
240-
offs += int(alen.nbytes)
263+
offs = self.data.tell()
241264
aparts: list[npt.NDArray[Any]] = [raw_itype, alen]
242265
data_idxs: list[int] = []
243266
# FIXME: Handle multi-dimensional arrays properly instead of flattening
@@ -258,23 +281,23 @@ def _get_tensor_info_field(self, orig_offs: int) -> ReaderField:
258281

259282
# Get Tensor Name
260283
name_len, name_data = self._get_str(offs)
261-
offs += int(name_len.nbytes + name_data.nbytes)
284+
offs = self.data.tell()
262285

263286
# Get Tensor Dimensions Count
264287
n_dims = self._get(offs, np.uint32)
265-
offs += int(n_dims.nbytes)
288+
offs = self.data.tell()
266289

267290
# Get Tensor Dimension Array
268291
dims = self._get(offs, np.uint64, n_dims[0])
269-
offs += int(dims.nbytes)
292+
offs = self.data.tell()
270293

271294
# Get Tensor Encoding Scheme Type
272295
raw_dtype = self._get(offs, np.uint32)
273-
offs += int(raw_dtype.nbytes)
296+
offs = self.data.tell()
274297

275298
# Get Tensor Offset
276299
offset_tensor = self._get(offs, np.uint64)
277-
offs += int(offset_tensor.nbytes)
300+
offs = self.data.tell()
278301

279302
return ReaderField(
280303
orig_offs,
@@ -287,9 +310,9 @@ def _build_fields(self, offs: int, count: int) -> int:
287310
for _ in range(count):
288311
orig_offs = offs
289312
kv_klen, kv_kdata = self._get_str(offs)
290-
offs += int(kv_klen.nbytes + kv_kdata.nbytes)
313+
offs = self.data.tell()
291314
raw_kv_type = self._get(offs, np.uint32)
292-
offs += int(raw_kv_type.nbytes)
315+
offs = self.data.tell()
293316
parts: list[npt.NDArray[Any]] = [kv_klen, kv_kdata, raw_kv_type]
294317
idxs_offs = len(parts)
295318
field_size, field_parts, field_idxs, field_types = self._get_field_parts(offs, raw_kv_type[0])
@@ -308,7 +331,7 @@ def _build_tensor_info(self, offs: int, count: int) -> tuple[int, list[ReaderFie
308331
tensor_fields = []
309332
for _ in range(count):
310333
field = self._get_tensor_info_field(offs)
311-
offs += sum(int(part.nbytes) for part in field.parts)
334+
offs = self.data.tell()
312335
tensor_fields.append(field)
313336
return offs, tensor_fields
314337

@@ -361,7 +384,7 @@ def _build_tensors(self, start_offs: int, fields: list[ReaderField]) -> None:
361384
n_elements = n_elems,
362385
n_bytes = n_bytes,
363386
data_offset = data_offs,
364-
data = self._get(data_offs, item_type, item_count).reshape(np_dims),
387+
data = self._get(data_offs, item_type, item_count, use_mmap=True).reshape(np_dims),
365388
field = field,
366389
))
367390
self.tensors = tensors

0 commit comments

Comments
 (0)