
Commit 5eb22a4

gguf-py: stream dtype casts during GGUF write; add GGUF_CAST_CHUNK_MB env knob; remove duplicate imports
1 parent fbef0fa commit 5eb22a4


gguf-py/gguf/lazy.py

Lines changed: 76 additions & 0 deletions
@@ -5,6 +5,7 @@
 from typing import Any, Callable
 
 import numpy as np
+import os
 from numpy.typing import DTypeLike

@@ -221,3 +222,78 @@ def tofile(self, *args, **kwargs):
         return eager.tofile(*args, **kwargs)
 
     # TODO: __array_function__
+
+
+# --- begin low-memory streaming for dtype casts ------------------------------
+# This block monkey-patches LazyNumpyTensor.astype and .tofile so that pure
+# dtype-cast nodes are streamed to disk in chunks, avoiding large RAM spikes.
+# Tunable via env: GGUF_CAST_CHUNK_MB (MB per chunk, default 256).
+
+try:
+    _LAZY_ORIG_ASTYPE = getattr(LazyNumpyTensor, "astype")
+    _LAZY_ORIG_TOFILE = getattr(LazyNumpyTensor, "tofile")
+except NameError:
+    # If class names change in the future, fail noisily.
+    raise RuntimeError("Expected LazyNumpyTensor to be defined above this block")
+
+def _gguf_streaming_astype(self, dtype, *args, **kwargs):
+    """Wrap the original .astype to tag the new lazy node as a streamable cast."""
+    tgt = np.dtype(dtype)
+    out = _LAZY_ORIG_ASTYPE(self, dtype, *args, **kwargs)
+    # mark the node so tofile() can detect and stream it
+    setattr(out, "_gguf_stream_cast", True)
+    setattr(out, "_gguf_stream_cast_dtype", tgt)
+    return out
+
+def _gguf_stream_cast_write(fout, src, tgt_dtype, chunk_elems):
+    """Write src.astype(tgt_dtype) to fout in chunks, capping peak RAM."""
+    flat = src.reshape(-1)
+    n = flat.size
+    start = 0
+    mv = memoryview  # local for speed
+    while start < n:
+        end = min(start + chunk_elems, n)
+        # copy=False prevents an extra temporary when NumPy can reuse buffers
+        chunk = flat[start:end].astype(tgt_dtype, copy=False)
+        fout.write(mv(chunk).tobytes())
+        start = end
+
+def _gguf_streaming_tofile(self, fout):
+    """
+    If this lazy node represents a pure dtype cast, stream it in chunks.
+    Otherwise, fall back to the original behavior (materialize then write).
+    """
+    if getattr(self, "_gguf_stream_cast", False):
+        # The original astype stored the source object as the first arg
+        base = getattr(self, "_args", None)
+        base = base[0] if base else None
+
+        # Try to obtain an eager ndarray for the source
+        src_arr = None
+        try:
+            src_arr = LazyNumpyTensor.to_eager(base)
+        except Exception:
+            pass
+
+        if isinstance(src_arr, np.ndarray):
+            # chunk size in MB; default 256 if unset/invalid
+            try:
+                mb = int(os.environ.get("GGUF_CAST_CHUNK_MB", "256") or "256")
+            except Exception:
+                mb = 256
+            mb = max(1, mb)
+
+            tgt_dtype = getattr(self, "_gguf_stream_cast_dtype", src_arr.dtype)
+            # choose element count so that each chunk is ~mb MiB of the *larger* itemsize
+            itemsize = max(src_arr.dtype.itemsize, np.dtype(tgt_dtype).itemsize)
+            chunk_elems = max(1, (mb * 1024 * 1024) // itemsize)
+
+            _gguf_stream_cast_write(fout, src_arr, tgt_dtype, chunk_elems)
+            return
+
+    # Fallback: original behavior
+    _LAZY_ORIG_TOFILE(self, fout)
+
+# Install the monkey patches
+LazyNumpyTensor.astype = _gguf_streaming_astype
+LazyNumpyTensor.tofile = _gguf_streaming_tofile
+# --- end low-memory streaming for dtype casts --------------------------------
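For a sense of the sizing math in _gguf_streaming_tofile: the element count per chunk is the megabyte budget divided by the larger of the source and target itemsizes, so the bytes handled per iteration stay near the budget whichever way the cast goes. A quick check of the default budget (the float32-to-float16 pair below is an illustrative choice, not something the patch fixes):

import numpy as np

mb = 256  # GGUF_CAST_CHUNK_MB default
# size chunks by the larger itemsize so neither side of the cast exceeds the budget
itemsize = max(np.dtype(np.float32).itemsize, np.dtype(np.float16).itemsize)  # 4 bytes
chunk_elems = max(1, (mb * 1024 * 1024) // itemsize)
print(chunk_elems)  # 67108864 elements; each temporary float16 chunk is ~128 MiB

Because the knob is read from the environment at write time, the budget can be lowered on a memory-constrained machine by exporting GGUF_CAST_CHUNK_MB before running whatever script drives the GGUF write; no code change is needed.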

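A byte-for-byte comparison against a one-shot cast is an easy way to convince yourself the chunk loop is lossless. The sketch below is a self-contained rough equivalent of _gguf_stream_cast_write for illustration only (the helper name, the in-memory buffer, and the dtypes are assumptions, not part of the commit):

import io

import numpy as np

def stream_cast_write(fout, src, tgt_dtype, chunk_elems):
    # same loop shape as the patch: cast and write one bounded chunk at a time
    flat = src.reshape(-1)
    for start in range(0, flat.size, chunk_elems):
        fout.write(flat[start:start + chunk_elems].astype(tgt_dtype, copy=False).tobytes())

src = np.arange(1_000_003, dtype=np.float32)  # deliberately not a multiple of the chunk size
buf = io.BytesIO()
stream_cast_write(buf, src, np.float16, chunk_elems=4096)
assert buf.getvalue() == src.astype(np.float16).tobytes()

Both paths go through the same NumPy cast, so the streamed output is identical to the materialized one; only peak memory differs.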