Skip to content

Commit a5c74dd

Browse files
committed
reverting removal of DTypeLike changes
DTypeLike was removed because pyright was choking on it against numpy 2.2.1 Now numpy at 2.2.6 seems to have changes needed for pyright to not have issues with it. additionally, DTypeLike is accurate in a way that np.dtype[Any] was not.
1 parent d49b004 commit a5c74dd

File tree

3 files changed

+9
-6
lines changed

3 files changed

+9
-6
lines changed

gguf-py/gguf/gguf_reader.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def __init__(self, path: os.PathLike[str] | str, mode: Literal['r', 'r+', 'c'] =
181181
self.data_offset = offs
182182
self._build_tensors(offs, tensors_fields)
183183

184-
_DT = TypeVar('_DT', bound = np.dtype[Any])
184+
_DT = TypeVar('_DT', bound = npt.DTypeLike)
185185

186186
# Fetch a key/value metadata field by key.
187187
def get_field(self, key: str) -> Union[ReaderField, None]:
@@ -192,7 +192,7 @@ def get_tensor(self, idx: int) -> ReaderTensor:
192192
return self.tensors[idx]
193193

194194
def _get(
195-
self, offset: int, dtype: np.dtype[Any], count: int = 1, override_order: None | Literal['I', 'S', '<'] = None,
195+
self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I', 'S', '<'] = None,
196196
) -> np.ndarray:
197197
count = int(count)
198198
itemsize = int(np.empty([], dtype = dtype).itemsize)
@@ -328,7 +328,7 @@ def _build_tensors(self, start_offs: int, fields: list[ReaderField]) -> None:
328328
block_size, type_size = GGML_QUANT_SIZES[ggml_type]
329329
n_bytes = n_elems * type_size // block_size
330330
data_offs = int(start_offs + offset_tensor[0])
331-
item_type: np.dtype[Any]
331+
item_type: npt.DTypeLike
332332
if ggml_type == GGMLQuantizationType.F16:
333333
item_count = n_elems
334334
item_type = np.dtype(np.float16)

gguf-py/gguf/lazy.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Any, Callable
66

77
import numpy as np
8+
from numpy.typing import DTypeLike
89

910

1011
logger = logging.getLogger(__name__)
@@ -106,7 +107,7 @@ def _recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any:
106107
return o
107108

108109
@classmethod
109-
def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | np.dtype[Any] | tuple[np.dtype[Any], Callable[[tuple[int, ...]], tuple[int, ...]]] = False) -> Callable[[Any], Any]:
110+
def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike | tuple[DTypeLike, Callable[[tuple[int, ...]], tuple[int, ...]]] = False) -> Callable[[Any], Any]:
110111
def wrapped_fn(*args, **kwargs):
111112
if kwargs is None:
112113
kwargs = {}
@@ -203,7 +204,7 @@ class LazyNumpyTensor(LazyBase):
203204
shape: tuple[int, ...] # Makes the type checker happy in quants.py
204205

205206
@classmethod
206-
def meta_with_dtype_and_shape(cls, dtype: np.dtype[Any], shape: tuple[int, ...]) -> np.ndarray[Any, Any]:
207+
def meta_with_dtype_and_shape(cls, dtype: DTypeLike, shape: tuple[int, ...]) -> np.ndarray[Any, Any]:
207208
# The initial idea was to use np.nan as the fill value,
208209
# but non-float types like np.int16 can't use that.
209210
# So zero it is.

gguf-py/gguf/quants.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from typing import Any, Callable, Sequence
44
from math import log2, ceil
55

6+
from numpy.typing import DTypeLike
7+
68
from .constants import GGML_QUANT_SIZES, GGMLQuantizationType, QK_K
79
from .lazy import LazyNumpyTensor
810

@@ -24,7 +26,7 @@ def quant_shape_from_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizati
2426

2527

2628
# This is faster than np.vectorize and np.apply_along_axis because it works on more than one row at a time
27-
def _apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np.ndarray, otype: np.dtype[Any], oshape: tuple[int, ...]) -> np.ndarray:
29+
def _apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np.ndarray, otype: DTypeLike, oshape: tuple[int, ...]) -> np.ndarray:
2830
rows = arr.reshape((-1, arr.shape[-1]))
2931
osize = 1
3032
for dim in oshape:

0 commit comments

Comments
 (0)