
Commit cf20725

gguf-py: stream dtype casts (default 64 MiB), add writer-side path + debug logs
• New helper: gguf/stream_cast.py with write_cast(fp, src_arr, dst_dtype, chunk_mb), which writes src.astype(dst) in fixed-size chunks to cap peak RSS.
• lazy.py:
  • tag LazyNumpyTensor.astype() results (_gguf_stream_cast, _gguf_stream_cast_dtype)
  • tofile() streams via write_cast when the node is a pure dtype cast; otherwise it falls back
  • env vars: GGUF_CAST_CHUNK_MB (default 64) and GGUF_STREAM_LOG (opt-in diagnostics)
• gguf_writer.py: call write_cast directly when the tensor is a tagged pure cast. This keeps the benefit even if future changes bypass tofile() or use multi-threaded writes.
• Alignment: preserve data_alignment by padding before/after writes.
• Repro notes: Ubuntu 24.04 / Python 3.12 / NumPy 2.1; bloom-560m FP16→F32 conversion shows peak-RSS reductions as the chunk size shrinks (e.g., 256→64→32→16 MiB), with small runtime trade-offs at the smaller chunks. macOS run logs confirm [gguf-stream] activation as well.
• Scope & limitations: only pure dtype casts; MoE stacking and other complex transforms fall back.
• Future work (separate RFC/PR): "chunked lazy tensors" and file-range tracking compatible with multi-threaded writes.
1 parent 5eb22a4 commit cf20725
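
The two environment variables above are read at conversion time by lazy.py and gguf_writer.py. A minimal sketch of setting them from Python before any gguf-py code reads them (names and defaults as described in the commit message; placement is illustrative):

```python
# illustrative sketch: configure the streaming knobs described above before
# any gguf-py lazy/writer code reads them
import os

os.environ["GGUF_CAST_CHUNK_MB"] = "64"  # MiB per streamed chunk (commit default: 64)
os.environ["GGUF_STREAM_LOG"] = "1"      # opt-in [gguf-stream] / [gguf-writer] diagnostics
```

In shell-driven runs these would typically be exported before invoking the HF→GGUF conversion script (e.g. llama.cpp's convert_hf_to_gguf.py for the bloom-560m repro).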

File tree

3 files changed: +203 −56 lines


gguf-py/gguf/gguf_writer.py

Lines changed: 76 additions & 3 deletions
@@ -14,6 +14,7 @@
 from string import ascii_letters, digits

 import numpy as np
+from .stream_cast import write_cast

 from .constants import (
     GGUF_DEFAULT_ALIGNMENT,
@@ -33,6 +34,9 @@

 logger = logging.getLogger(__name__)

+def _stream_log(msg: str) -> None:
+    if os.environ.get("GGUF_STREAM_LOG"):
+        print(f"[gguf-writer] {msg}", flush=True)

 SHARD_NAME_FORMAT = "{:s}-{:05d}-of-{:05d}.gguf"

@@ -411,12 +415,43 @@ def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
         fout = self.fout[file_id]

         # pop the first tensor info
-        # TODO: cleaner way to get the first key
         first_tensor_name = [name for name, _ in zip(self.tensors[file_id].keys(), range(1))][0]
         ti = self.tensors[file_id].pop(first_tensor_name)
         assert ti.nbytes == tensor.nbytes

+        # align to data_alignment before writing tensor data
         self.write_padding(fout, fout.tell())
+
+        # --- writer-side streaming for pure dtype casts (survives when tofile() isn't used) ---
+        try:
+            if getattr(tensor, "_gguf_stream_cast", False):
+                # derive the pre-cast lazy source from the astype() node args
+                base = getattr(tensor, "_args", None)
+                base = base[0] if base else None
+
+                src_arr = None
+                try:
+                    src_arr = type(base).to_eager(base)
+                except Exception:
+                    src_arr = None
+
+                if isinstance(src_arr, np.ndarray):
+                    try:
+                        mb = int(os.environ.get("GGUF_CAST_CHUNK_MB", "64") or "64")
+                    except Exception:
+                        mb = 64
+                    tgt_dtype = getattr(tensor, "_gguf_stream_cast_dtype", src_arr.dtype)
+                    _stream_log(f"writer: streaming cast (chunk={mb} MiB) dst={tgt_dtype} shape={getattr(tensor, 'shape', '?')}")
+                    write_cast(fout, src_arr, tgt_dtype, mb)
+                    self.write_padding(fout, ti.nbytes)
+                    self.state = WriterState.WEIGHTS
+                    return
+        except Exception:
+            # fall back to normal path on any unexpected issue
+            pass
+        # ---------------------------------------------------------------------------------------
+
+        # Fallback: rely on the object's own tofile() (handles lazy or eager)
         tensor.tofile(fout)
         self.write_padding(fout, tensor.nbytes)

@@ -452,8 +487,46 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None:
             # relying on the fact that Python dicts preserve insertion order (since 3.7)
             for ti in tensors.values():
                 assert ti.tensor is not None  # can only iterate once over the tensors
-                assert ti.tensor.nbytes == ti.nbytes
-                ti.tensor.tofile(fout)
+                obj = ti.tensor
+                assert obj.nbytes == ti.nbytes
+
+                # Try writer-side streaming for pure dtype casts
+                streamed = False
+                try:
+                    if getattr(obj, "_gguf_stream_cast", False):
+                        # derive the pre-cast lazy source from the astype() node args
+                        base = getattr(obj, "_args", None)
+                        base = base[0] if base else None
+
+                        src_arr = None
+                        try:
+                            src_arr = type(base).to_eager(base)
+                        except Exception:
+                            src_arr = None
+
+                        if isinstance(src_arr, np.ndarray):
+                            try:
+                                mb = int(os.environ.get("GGUF_CAST_CHUNK_MB", "64") or "64")
+                            except Exception:
+                                mb = 64
+                            tgt_dtype = getattr(obj, "_gguf_stream_cast_dtype", src_arr.dtype)
+                            _stream_log(f"writer: streaming cast (chunk={mb} MiB) dst={tgt_dtype} shape={getattr(obj, 'shape', '?')}")
+                            write_cast(fout, src_arr, tgt_dtype, mb)
+                            streamed = True
+                except Exception:
+                    streamed = False  # fall back below on any issue
+
+                if streamed:
+                    if shard_bar is not None:
+                        shard_bar.update(ti.nbytes)
+                    if bar is not None:
+                        bar.update(ti.nbytes)
+                    self.write_padding(fout, ti.nbytes)
+                    ti.tensor = None
+                    continue
+
+                # Fallback: object's tofile()
+                obj.tofile(fout)
                 if shard_bar is not None:
                     shard_bar.update(ti.nbytes)
                 if bar is not None:
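
The commit message notes that data_alignment is preserved by padding before and after tensor writes. As a reference for the arithmetic involved, a minimal sketch (pad_to is an illustrative helper, not part of the diff; it mirrors the ggml_pad-style rounding that write_padding is built on):

```python
def pad_to(offset: int, align: int) -> int:
    # round `offset` up to the next multiple of `align` (ggml_pad-style rounding)
    return ((offset + align - 1) // align) * align

# with the default 32-byte GGUF_DEFAULT_ALIGNMENT, tensor data ending at
# offset 100 is followed by 28 zero bytes so the next tensor starts at 128
print(pad_to(100, 32) - 100)  # 28
```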

gguf-py/gguf/lazy.py

Lines changed: 47 additions & 53 deletions
@@ -224,76 +224,70 @@ def tofile(self, *args, **kwargs):
     # TODO: __array_function__

 # --- begin low-memory streaming for dtype casts ------------------------------
-# This block monkey-patches LazyNumpyTensor.astype and .tofile so that pure
-# dtype-cast nodes are streamed to disk in chunks, avoiding large RAM spikes.
-# Tunable via env: GGUF_CAST_CHUNK_MB (MB per chunk, default 256).
+# Tunable via env:
+#   GGUF_CAST_CHUNK_MB (MiB per chunk; default 64)
+#   GGUF_STREAM_LOG    (set to any non-empty value to print diagnostics)
+
+import sys
+from .stream_cast import write_cast  # sibling helper

 try:
     _LAZY_ORIG_ASTYPE = getattr(LazyNumpyTensor, "astype")
     _LAZY_ORIG_TOFILE = getattr(LazyNumpyTensor, "tofile")
 except NameError:
-    # If class names change in the future, fail noisily.
     raise RuntimeError("Expected LazyNumpyTensor to be defined above this block")

+def _slog(msg: str) -> None:
+    if os.environ.get("GGUF_STREAM_LOG"):
+        print(f"[gguf-stream] {msg}", file=sys.stdout, flush=True)
+
 def _gguf_streaming_astype(self, dtype, *args, **kwargs):
-    """Wrap the original .astype to tag the new lazy node as a streamable cast."""
+    """Tag astype results so writer/tofile can stream them later."""
     tgt = np.dtype(dtype)
     out = _LAZY_ORIG_ASTYPE(self, dtype, *args, **kwargs)
-    # mark the node so tofile() can detect and stream it
+    # mark as streamable and record target dtype
     setattr(out, "_gguf_stream_cast", True)
     setattr(out, "_gguf_stream_cast_dtype", tgt)
+    # NEW: record the *source* lazy tensor for writer-side streaming
+    setattr(out, "_gguf_stream_cast_src", self)
+    _slog(f"mark streamable astype: src={getattr(self._meta, 'dtype', '?')} -> dst={tgt}")
     return out

-def _gguf_stream_cast_write(fout, src, tgt_dtype, chunk_elems):
-    """Write src.astype(tgt_dtype) to fout in chunks, capping peak RAM."""
-    flat = src.reshape(-1)
-    n = flat.size
-    start = 0
-    mv = memoryview  # local for speed
-    while start < n:
-        end = min(start + chunk_elems, n)
-        # copy=False prevents an extra temporary when NumPy can reuse buffers
-        chunk = flat[start:end].astype(tgt_dtype, copy=False)
-        fout.write(mv(chunk).tobytes())
-        start = end
-
-def _gguf_streaming_tofile(self, fout):
-    """
-    If this lazy node represents a pure dtype cast, stream it in chunks.
-    Otherwise, fall back to the original behavior (materialize then write).
-    """
-    if getattr(self, "_gguf_stream_cast", False):
-        # The original astype stored the source object as the first arg
+def _gguf_streaming_tofile(self, fout, *args, **kwargs):
+    """If this lazy tensor is a pure dtype-cast, stream in chunks; else fallback."""
+    if not getattr(self, "_gguf_stream_cast", False):
+        return _LAZY_ORIG_TOFILE(self, fout, *args, **kwargs)
+
+    # default chunk size: 64 MiB (can override via GGUF_CAST_CHUNK_MB)
+    try:
+        mb = int(os.environ.get("GGUF_CAST_CHUNK_MB", "64") or "64")
+    except Exception:
+        mb = 64
+    mb = max(1, mb)
+
+    # Prefer the explicitly tagged source lazy tensor if present (step 2)
+    base = getattr(self, "_gguf_stream_cast_src", None)
+
+    # Fallback to first arg (older astype behavior) if not tagged
+    if base is None:
         base = getattr(self, "_args", None)
         base = base[0] if base else None

-        # Try to obtain an eager ndarray for the source
+    try:
+        src_arr = LazyNumpyTensor.to_eager(base)
+    except Exception:
         src_arr = None
-        try:
-            src_arr = LazyNumpyTensor.to_eager(base)
-        except Exception:
-            pass
-
-        if isinstance(src_arr, np.ndarray):
-            # chunk size in MB; default 256 if unset/invalid
-            try:
-                mb = int(os.environ.get("GGUF_CAST_CHUNK_MB", "256") or "256")
-            except Exception:
-                mb = 256
-            mb = max(1, mb)
-
-            tgt_dtype = getattr(self, "_gguf_stream_cast_dtype", src_arr.dtype)
-            # choose element count so that each chunk is ~mb megabytes of the *larger* itemsize
-            itemsize = max(src_arr.dtype.itemsize, np.dtype(tgt_dtype).itemsize)
-            chunk_elems = max(1, (mb * 1024 * 1024) // itemsize)
-
-            _gguf_stream_cast_write(fout, src_arr, tgt_dtype, chunk_elems)
-            return
-
-    # Fallback: original behavior
-    _LAZY_ORIG_TOFILE(self, fout)
-
-# Install the monkey patches
+
+    if not isinstance(src_arr, np.ndarray):
+        _slog("fallback to original tofile: cannot materialize source to ndarray")
+        return _LAZY_ORIG_TOFILE(self, fout, *args, **kwargs)
+
+    tgt = getattr(self, "_gguf_stream_cast_dtype", src_arr.dtype)
+    _slog(f"streaming cast write: chunk={mb} MiB; dst={tgt}; shape={getattr(self._meta, 'shape', '?')}")
+    write_cast(fout, src_arr, tgt, mb)
+    return
+
+# Install patches
 LazyNumpyTensor.astype = _gguf_streaming_astype
 LazyNumpyTensor.tofile = _gguf_streaming_tofile
-# --- end low-memory streaming for dtype casts --------------------------------
+# --- end low-memory streaming for dtype casts ------------------------------
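
A small sketch of how the tags installed by the patched astype() can be inspected from the caller's side (lazy_w stands for any LazyNumpyTensor produced during conversion; the name is hypothetical):

```python
import numpy as np

out = lazy_w.astype(np.float32)  # patched astype() tags the resulting lazy node

assert getattr(out, "_gguf_stream_cast", False)                        # marked as a streamable pure cast
assert getattr(out, "_gguf_stream_cast_dtype", None) == np.dtype(np.float32)
assert getattr(out, "_gguf_stream_cast_src", None) is lazy_w           # pre-cast source for the writer-side path
```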

gguf-py/gguf/stream_cast.py

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
+# gguf-py/gguf/stream_cast.py
+from __future__ import annotations
+from typing import Any
+import os
+import sys
+import numpy as np
+
+
+def _slog(msg: str) -> None:
+    """Conditional debug logging when GGUF_STREAM_LOG is set."""
+    if os.environ.get("GGUF_STREAM_LOG"):
+        print(f"[gguf-stream] {msg}", file=sys.stdout, flush=True)
+
+
+def _chunk_elems(src_dtype: np.dtype, dst_dtype: np.dtype, chunk_mb: int) -> int:
+    """
+    Compute how many elements to process per chunk so that each chunk is
+    approximately `chunk_mb` MiB of the *larger* of the source/destination itemsize.
+    """
+    try:
+        mb = int(chunk_mb)
+    except Exception:
+        mb = 64
+    mb = max(1, mb)
+    item = max(np.dtype(src_dtype).itemsize, np.dtype(dst_dtype).itemsize)
+    return max(1, (mb * 1024 * 1024) // item)
+
+
+def write_cast(fout, src: np.ndarray, dst_dtype: Any, chunk_mb: int) -> None:
+    """
+    Stream `src.astype(dst_dtype)` to `fout` in fixed-size chunks to cap peak RSS.
+
+    This matches the import site in lazy.py:
+        from .stream_cast import write_cast
+
+    Parameters
+    ----------
+    fout : file-like object
+        Open file handle to write bytes to (must support .write()).
+    src : np.ndarray
+        Source ndarray to be converted and streamed.
+    dst_dtype : Any
+        Target dtype (anything accepted by np.dtype).
+    chunk_mb : int
+        Desired chunk size in MiB (will be clamped to >= 1).
+    """
+    dst = np.dtype(dst_dtype)
+    flat = src.reshape(-1)
+    n = flat.size
+    ce = _chunk_elems(flat.dtype, dst, chunk_mb)
+
+    _slog(
+        f"write_cast: src={flat.dtype} -> dst={dst}; n={n}; "
+        f"chunk={max(1, int(chunk_mb))} MiB; elems/chunk={ce}"
+    )
+
+    start = 0
+    # local binding for tiny speed bump
+    mv = memoryview
+    while start < n:
+        end = min(start + ce, n)
+        # copy=False avoids an extra tmp when possible
+        chunk = flat[start:end].astype(dst, copy=False)
+        fout.write(mv(chunk).tobytes())
+        start = end
+
+
+# Optional: writer-side API that accepts chunk size in bytes (used by gguf_writer)
+def stream_write(fout, src_arr: np.ndarray, dst_dtype: Any, chunk_bytes: int) -> None:
+    """
+    Same as write_cast, but the chunk size is given in bytes.
+    Kept for compatibility with earlier helper drafts.
+    """
+    if not isinstance(chunk_bytes, int) or chunk_bytes <= 0:
+        chunk_mb = 64
+    else:
+        # round bytes to MiB for the element count helper
+        chunk_mb = max(1, chunk_bytes // (1024 * 1024))
+
+    write_cast(fout, src_arr, dst_dtype, chunk_mb)
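
A self-contained sanity check (a sketch, not part of the commit) that the chunked path produces byte-for-byte the same output as a single eager astype(); it assumes gguf-py is importable as the gguf package:

```python
import io

import numpy as np

from gguf.stream_cast import write_cast

# odd element count so the final chunk is a partial one
src = np.random.rand(1_000_003).astype(np.float16)

buf = io.BytesIO()
write_cast(buf, src, np.float32, 1)  # 1 MiB chunks -> 262,144 float32 elements per chunk

assert buf.getvalue() == src.astype(np.float32).tobytes()
```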
