77import logging
88import os
99import sys
10+ import struct
1011from collections import OrderedDict
1112from typing import Any , Literal , NamedTuple , TypeVar , Union
1213
@@ -130,11 +131,15 @@ class GGUFReader:
130131 }
131132
132133 def __init__ (self , path : os .PathLike [str ] | str , mode : Literal ['r' , 'r+' , 'c' ] = 'r' ):
133- self .data = np .memmap (path , mode = mode )
134+ file_mode = "rb+" if mode == 'r+' else 'rb'
135+ self .mode = mode
136+ self .data = open (path , mode = file_mode )
137+ self .mmap = np .memmap (self .data , mode = mode )
134138 offs = 0
135139
136140 # Check for GGUF magic
137- if self ._get (offs , np .uint32 , override_order = '<' )[0 ] != GGUF_MAGIC :
141+ self .data .seek (offs )
142+ if struct .unpack ("<I" , self .data .read (4 ))[0 ] != GGUF_MAGIC :
138143 raise ValueError ('GGUF magic invalid' )
139144 offs += 4
140145
@@ -192,13 +197,22 @@ def get_tensor(self, idx: int) -> ReaderTensor:
192197 return self .tensors [idx ]
193198
194199 def _get (
195- self , offset : int , dtype : npt .DTypeLike , count : int = 1 , override_order : None | Literal ['I' , 'S' , '<' ] = None ,
200+ self , offset : int , dtype : npt .DTypeLike , count : int = 1 , override_order : None | Literal ['I' , 'S' , '<' ] = None , use_mmap : bool = False
196201 ) -> npt .NDArray [Any ]:
197202 count = int (count )
198- itemsize = int (np .empty ([], dtype = dtype ).itemsize )
203+ dtype = np .dtype (dtype ).newbyteorder (override_order or self .byte_order )
204+ itemsize = dtype .itemsize
199205 end_offs = offset + itemsize * count
200- arr = self .data [offset :end_offs ].view (dtype = dtype )[:count ]
201- return arr .view (arr .dtype .newbyteorder (self .byte_order if override_order is None else override_order ))
206+ if self .mode != "r" or use_mmap :
207+ data = (
208+ self .mmap [offset :end_offs ]
209+ .view (dtype )[:count ]
210+ )
211+ self .data .seek (end_offs )
212+ else :
213+ self .data .seek (offset )
214+ data = np .frombuffer (self .data .read (itemsize * count ), dtype = dtype )
215+ return data
202216
203217 def _push_field (self , field : ReaderField , skip_sum : bool = False ) -> int :
204218 if field .name in self .fields :
@@ -212,8 +226,17 @@ def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int:
212226 return 0 if skip_sum else sum (int (part .nbytes ) for part in field .parts )
213227
214228 def _get_str (self , offset : int ) -> tuple [npt .NDArray [np .uint64 ], npt .NDArray [np .uint8 ]]:
215- slen = self ._get (offset , np .uint64 )
216- return slen , self ._get (offset + 8 , np .uint8 , slen [0 ])
229+ if self .mode != "r" :
230+ slen = self ._get (offset , np .uint64 )
231+ sdata = self ._get (offset + 8 , np .uint8 , slen .item ())
232+ else :
233+ # This is faster to return a read-only str structure with less seek calling.
234+ self .data .seek (offset )
235+ u64 = np .dtype (np .uint64 ).newbyteorder (self .byte_order )
236+ u8 = np .dtype (np .uint8 ).newbyteorder (self .byte_order )
237+ slen = np .frombuffer (self .data .read (8 ), dtype = u64 )
238+ sdata = np .frombuffer (self .data .read (slen .item ()), dtype = u8 )
239+ return slen , sdata
217240
218241 def _get_field_parts (
219242 self , orig_offs : int , raw_type : int ,
@@ -224,8 +247,8 @@ def _get_field_parts(
224247 types .append (gtype )
225248 # Handle strings.
226249 if gtype == GGUFValueType .STRING :
227- sparts : list [npt .NDArray [Any ]] = list ( self ._get_str (offs ) )
228- size = sum ( int ( part . nbytes ) for part in sparts )
250+ sparts : list [npt .NDArray [Any ]] = self ._get_str (offs )
251+ size = 8 + sparts [ 0 ]. item ( )
229252 return size , sparts , [1 ], types
230253 # Check if it's a simple scalar type.
231254 nptype = self .gguf_scalar_to_np .get (gtype )
@@ -235,9 +258,9 @@ def _get_field_parts(
235258 # Handle arrays.
236259 if gtype == GGUFValueType .ARRAY :
237260 raw_itype = self ._get (offs , np .uint32 )
238- offs += int ( raw_itype . nbytes )
261+ offs = self . data . tell ( )
239262 alen = self ._get (offs , np .uint64 )
240- offs += int ( alen . nbytes )
263+ offs = self . data . tell ( )
241264 aparts : list [npt .NDArray [Any ]] = [raw_itype , alen ]
242265 data_idxs : list [int ] = []
243266 # FIXME: Handle multi-dimensional arrays properly instead of flattening
@@ -258,23 +281,23 @@ def _get_tensor_info_field(self, orig_offs: int) -> ReaderField:
258281
259282 # Get Tensor Name
260283 name_len , name_data = self ._get_str (offs )
261- offs += int ( name_len . nbytes + name_data . nbytes )
284+ offs = self . data . tell ( )
262285
263286 # Get Tensor Dimensions Count
264287 n_dims = self ._get (offs , np .uint32 )
265- offs += int ( n_dims . nbytes )
288+ offs = self . data . tell ( )
266289
267290 # Get Tensor Dimension Array
268291 dims = self ._get (offs , np .uint64 , n_dims [0 ])
269- offs += int ( dims . nbytes )
292+ offs = self . data . tell ( )
270293
271294 # Get Tensor Encoding Scheme Type
272295 raw_dtype = self ._get (offs , np .uint32 )
273- offs += int ( raw_dtype . nbytes )
296+ offs = self . data . tell ( )
274297
275298 # Get Tensor Offset
276299 offset_tensor = self ._get (offs , np .uint64 )
277- offs += int ( offset_tensor . nbytes )
300+ offs = self . data . tell ( )
278301
279302 return ReaderField (
280303 orig_offs ,
@@ -287,9 +310,9 @@ def _build_fields(self, offs: int, count: int) -> int:
287310 for _ in range (count ):
288311 orig_offs = offs
289312 kv_klen , kv_kdata = self ._get_str (offs )
290- offs += int ( kv_klen . nbytes + kv_kdata . nbytes )
313+ offs = self . data . tell ( )
291314 raw_kv_type = self ._get (offs , np .uint32 )
292- offs += int ( raw_kv_type . nbytes )
315+ offs = self . data . tell ( )
293316 parts : list [npt .NDArray [Any ]] = [kv_klen , kv_kdata , raw_kv_type ]
294317 idxs_offs = len (parts )
295318 field_size , field_parts , field_idxs , field_types = self ._get_field_parts (offs , raw_kv_type [0 ])
@@ -308,7 +331,7 @@ def _build_tensor_info(self, offs: int, count: int) -> tuple[int, list[ReaderFie
308331 tensor_fields = []
309332 for _ in range (count ):
310333 field = self ._get_tensor_info_field (offs )
311- offs += sum ( int ( part . nbytes ) for part in field . parts )
334+ offs = self . data . tell ( )
312335 tensor_fields .append (field )
313336 return offs , tensor_fields
314337
@@ -361,7 +384,7 @@ def _build_tensors(self, start_offs: int, fields: list[ReaderField]) -> None:
361384 n_elements = n_elems ,
362385 n_bytes = n_bytes ,
363386 data_offset = data_offs ,
364- data = self ._get (data_offs , item_type , item_count ).reshape (np_dims ),
387+ data = self ._get (data_offs , item_type , item_count , use_mmap = True ).reshape (np_dims ),
365388 field = field ,
366389 ))
367390 self .tensors = tensors
0 commit comments