|
5 | 5 | from typing import Any, Callable |
6 | 6 |
|
7 | 7 | import numpy as np |
| 8 | +import os |
8 | 9 | from numpy.typing import DTypeLike |
9 | 10 |
|
10 | 11 |
|
@@ -221,3 +222,78 @@ def tofile(self, *args, **kwargs): |
221 | 222 | return eager.tofile(*args, **kwargs) |
222 | 223 |
|
223 | 224 | # TODO: __array_function__ |
| 225 | + |
| 226 | +# --- begin low-memory streaming for dtype casts ------------------------------ |
| 227 | +# This block monkey-patches LazyNumpyTensor.astype and .tofile so that pure |
| 228 | +# dtype-cast nodes are streamed to disk in chunks, avoiding large RAM spikes. |
| 229 | +# Tunable via env: GGUF_CAST_CHUNK_MB (MB per chunk, default 256). |
| 230 | + |
try:
    # Capture the original implementations before patching so the wrappers
    # below can delegate to them.
    _LAZY_ORIG_ASTYPE = LazyNumpyTensor.astype
    _LAZY_ORIG_TOFILE = LazyNumpyTensor.tofile
except NameError:
    # Fail loudly (not silently) if the class above is ever renamed/moved.
    raise RuntimeError("Expected LazyNumpyTensor to be defined above this block")
| 237 | + |
def _gguf_streaming_astype(self, dtype, *args, **kwargs):
    """Replacement for LazyNumpyTensor.astype.

    Delegates to the original implementation, then tags the resulting lazy
    node as a pure dtype cast so the patched tofile() can detect it and
    stream the conversion to disk in chunks.
    """
    target = np.dtype(dtype)  # normalize early; raises for invalid dtypes
    result = _LAZY_ORIG_ASTYPE(self, dtype, *args, **kwargs)
    # Flags consumed by _gguf_streaming_tofile below.
    result._gguf_stream_cast = True
    result._gguf_stream_cast_dtype = target
    return result
| 246 | + |
| 247 | +def _gguf_stream_cast_write(fout, src, tgt_dtype, chunk_elems): |
| 248 | + """Write src.astype(tgt_dtype) to fout in chunks, capping peak RAM.""" |
| 249 | + flat = src.reshape(-1) |
| 250 | + n = flat.size |
| 251 | + start = 0 |
| 252 | + mv = memoryview # local for speed |
| 253 | + while start < n: |
| 254 | + end = min(start + chunk_elems, n) |
| 255 | + # copy=False prevents an extra temporary when NumPy can reuse buffers |
| 256 | + chunk = flat[start:end].astype(tgt_dtype, copy=False) |
| 257 | + fout.write(mv(chunk).tobytes()) |
| 258 | + start = end |
| 259 | + |
def _gguf_streaming_tofile(self, *args, **kwargs):
    """Replacement for LazyNumpyTensor.tofile.

    If this lazy node was tagged by the patched astype() as a pure dtype
    cast AND the call matches the streamable form -- exactly one positional
    argument that is a file-like object with ``write``, and no keyword
    arguments -- the cast is streamed to the file in bounded chunks.
    Any other call shape (filename string, ``sep``/``format`` options, ...)
    falls back to the original implementation unchanged.  The original
    delegate accepts ``*args, **kwargs``; the previous patch narrowed that
    to a single ``fout`` parameter and would have broken such callers.
    """
    fout = args[0] if len(args) == 1 and not kwargs else None
    if (
        getattr(self, "_gguf_stream_cast", False)
        and fout is not None
        and hasattr(fout, "write")
    ):
        # The patched astype stored the cast source as the first lazy arg.
        lazy_args = getattr(self, "_args", None)
        base = lazy_args[0] if lazy_args else None

        # Best effort: materialize only the *source* (pre-cast) array.
        # On any failure we silently fall through to the original path,
        # which will surface the real error in the normal way.
        src_arr = None
        try:
            src_arr = LazyNumpyTensor.to_eager(base)
        except Exception:
            pass

        if isinstance(src_arr, np.ndarray):
            # Chunk size in MB from GGUF_CAST_CHUNK_MB; default 256, min 1.
            raw = os.environ.get("GGUF_CAST_CHUNK_MB", "256") or "256"
            try:
                mb = int(raw)
            except ValueError:
                mb = 256
            mb = max(1, mb)

            tgt_dtype = getattr(self, "_gguf_stream_cast_dtype", src_arr.dtype)
            # Size chunks by the *larger* itemsize so both the source slice
            # and its converted copy stay near ~mb megabytes.
            itemsize = max(src_arr.dtype.itemsize, np.dtype(tgt_dtype).itemsize)
            chunk_elems = max(1, (mb * 1024 * 1024) // itemsize)

            _gguf_stream_cast_write(fout, src_arr, tgt_dtype, chunk_elems)
            return

    # Fallback: original behavior (materialize fully, then write).  Propagate
    # the return value -- the original delegate returns eager.tofile(...).
    return _LAZY_ORIG_TOFILE(self, *args, **kwargs)
| 295 | + |
# Activate the low-memory paths for every LazyNumpyTensor instance.
setattr(LazyNumpyTensor, "astype", _gguf_streaming_astype)
setattr(LazyNumpyTensor, "tofile", _gguf_streaming_tofile)
| 299 | +# --- end low-memory streaming for dtype casts -------------------------------- |
0 commit comments