1+ # pyright: reportInvalidTypeForm=false
12#
23# GGUF file reading/modification support. For API usage information,
34# please see the files scripts/ for some fairly simple examples.
1516
1617from .quants import quant_shape_to_byte_shape
1718
19+
1820if __name__ == "__main__" :
1921 from pathlib import Path
2022
@@ -104,7 +106,7 @@ class ReaderTensor(NamedTuple):
104106 n_elements : int
105107 n_bytes : int
106108 data_offset : int
107- data : npt . NDArray [ Any ]
109+ data : np . ndarray
108110 field : ReaderField
109111
110112
@@ -181,7 +183,7 @@ def __init__(self, path: os.PathLike[str] | str, mode: Literal['r', 'r+', 'c'] =
181183 self .data_offset = offs
182184 self ._build_tensors (offs , tensors_fields )
183185
184- _DT = TypeVar ('_DT' , bound = npt . DTypeLike )
186+ _DT = TypeVar ('_DT' , bound = np . dtype [ Any ] )
185187
186188 # Fetch a key/value metadata field by key.
187189 def get_field (self , key : str ) -> Union [ReaderField , None ]:
@@ -192,8 +194,8 @@ def get_tensor(self, idx: int) -> ReaderTensor:
192194 return self .tensors [idx ]
193195
194196 def _get (
195- self , offset : int , dtype : npt . DTypeLike , count : int = 1 , override_order : None | Literal ['I' , 'S' , '<' ] = None ,
196- ) -> npt . NDArray [ Any ] :
197+ self , offset : int , dtype : np . dtype [ Any ] , count : int = 1 , override_order : None | Literal ['I' , 'S' , '<' ] = None ,
198+ ) -> np . ndarray :
197199 count = int (count )
198200 itemsize = int (np .empty ([], dtype = dtype ).itemsize )
199201 end_offs = offset + itemsize * count
@@ -213,7 +215,7 @@ def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int:
213215
214216 def _get_str (self , offset : int ) -> tuple [npt .NDArray [np .uint64 ], npt .NDArray [np .uint8 ]]:
215217 slen = self ._get (offset , np .uint64 )
216- return slen , self ._get (offset + 8 , np .uint8 , slen [0 ])
218+ return slen , self ._get (offset + 8 , np .uint8 , slen [0 ]. item () )
217219
218220 def _get_field_parts (
219221 self , orig_offs : int , raw_type : int ,
@@ -230,7 +232,7 @@ def _get_field_parts(
230232 # Check if it's a simple scalar type.
231233 nptype = self .gguf_scalar_to_np .get (gtype )
232234 if nptype is not None :
233- val = self ._get (offs , nptype )
235+ val = self ._get (offs , np . dtype ( nptype ) )
234236 return int (val .nbytes ), [val ], [0 ], types
235237 # Handle arrays.
236238 if gtype == GGUFValueType .ARRAY :
@@ -242,7 +244,7 @@ def _get_field_parts(
242244 data_idxs : list [int ] = []
243245 # FIXME: Handle multi-dimensional arrays properly instead of flattening
244246 for idx in range (alen [0 ]):
245- curr_size , curr_parts , curr_idxs , curr_types = self ._get_field_parts (offs , raw_itype [0 ])
247+ curr_size , curr_parts , curr_idxs , curr_types = self ._get_field_parts (offs , raw_itype [0 ]. item () )
246248 if idx == 0 :
247249 types += curr_types
248250 idxs_offs = len (aparts )
@@ -265,7 +267,7 @@ def _get_tensor_info_field(self, orig_offs: int) -> ReaderField:
265267 offs += int (n_dims .nbytes )
266268
267269 # Get Tensor Dimension Array
268- dims = self ._get (offs , np .uint64 , n_dims [0 ])
270+ dims = self ._get (offs , np .uint64 , n_dims [0 ]. item () )
269271 offs += int (dims .nbytes )
270272
271273 # Get Tensor Encoding Scheme Type
@@ -292,7 +294,7 @@ def _build_fields(self, offs: int, count: int) -> int:
292294 offs += int (raw_kv_type .nbytes )
293295 parts : list [npt .NDArray [Any ]] = [kv_klen , kv_kdata , raw_kv_type ]
294296 idxs_offs = len (parts )
295- field_size , field_parts , field_idxs , field_types = self ._get_field_parts (offs , raw_kv_type [0 ])
297+ field_size , field_parts , field_idxs , field_types = self ._get_field_parts (offs , raw_kv_type [0 ]. item () )
296298 parts += field_parts
297299 self ._push_field (ReaderField (
298300 orig_offs ,
@@ -328,28 +330,28 @@ def _build_tensors(self, start_offs: int, fields: list[ReaderField]) -> None:
328330 block_size , type_size = GGML_QUANT_SIZES [ggml_type ]
329331 n_bytes = n_elems * type_size // block_size
330332 data_offs = int (start_offs + offset_tensor [0 ])
331- item_type : npt . DTypeLike
333+ item_type : np . dtype [ Any ]
332334 if ggml_type == GGMLQuantizationType .F16 :
333335 item_count = n_elems
334- item_type = np .float16
336+ item_type = np .dtype ( np . float16 )
335337 elif ggml_type == GGMLQuantizationType .F32 :
336338 item_count = n_elems
337- item_type = np .float32
339+ item_type = np .dtype ( np . float32 )
338340 elif ggml_type == GGMLQuantizationType .F64 :
339341 item_count = n_elems
340- item_type = np .float64
342+ item_type = np .dtype ( np . float64 )
341343 elif ggml_type == GGMLQuantizationType .I8 :
342344 item_count = n_elems
343- item_type = np .int8
345+ item_type = np .dtype ( np . int8 )
344346 elif ggml_type == GGMLQuantizationType .I16 :
345347 item_count = n_elems
346- item_type = np .int16
348+ item_type = np .dtype ( np . int16 )
347349 elif ggml_type == GGMLQuantizationType .I32 :
348350 item_count = n_elems
349- item_type = np .int32
351+ item_type = np .dtype ( np . int32 )
350352 elif ggml_type == GGMLQuantizationType .I64 :
351353 item_count = n_elems
352- item_type = np .int64
354+ item_type = np .dtype ( np . int64 )
353355 else :
354356 item_count = n_bytes
355357 item_type = np .uint8
0 commit comments