gguf-py: reduce peak RAM during convert by streaming dtype casts #15648
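The change is driven by two environment variables introduced in the diff below; a quick sketch of how a conversion run might opt in (the conversion entry point named here is the usual llama.cpp script and an assumption on our part, not part of this diff):

import os

os.environ["GGUF_CAST_CHUNK_MB"] = "32"  # MiB per streamed chunk (default 64)
os.environ["GGUF_STREAM_LOG"] = "1"      # any non-empty value enables diagnostics
# ...then run the conversion in this process, e.g. convert_hf_to_gguf.py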
gguf-py/gguf/gguf_writer.py
@@ -14,6 +14,7 @@
from string import ascii_letters, digits

import numpy as np
from .stream_cast import write_cast

from .constants import (
    GGUF_DEFAULT_ALIGNMENT,
@@ -33,6 +34,9 @@

logger = logging.getLogger(__name__)


def _stream_log(msg: str) -> None:
    if os.environ.get("GGUF_STREAM_LOG"):
        print(f"[gguf-writer] {msg}", flush=True)


SHARD_NAME_FORMAT = "{:s}-{:05d}-of-{:05d}.gguf"
@@ -411,12 +415,43 @@ def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
        fout = self.fout[file_id]

        # pop the first tensor info
        # TODO: cleaner way to get the first key
        first_tensor_name = [name for name, _ in zip(self.tensors[file_id].keys(), range(1))][0]
        ti = self.tensors[file_id].pop(first_tensor_name)
        assert ti.nbytes == tensor.nbytes

        # align to data_alignment before writing tensor data
        self.write_padding(fout, fout.tell())

        # --- writer-side streaming for pure dtype casts (survives when tofile() isn't used) ---
Review comment:
Maybe it would be simpler to keep using `tofile()` here.
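A minimal sketch of that simplification (hypothetical, assuming the `tofile()` patch in lazy.py stays the single streaming entry point):

        # Since lazy.py already patches LazyNumpyTensor.tofile() to stream
        # tagged dtype casts, the writer could keep one code path and drop
        # the writer-side block below entirely:
        self.write_padding(fout, fout.tell())
        tensor.tofile(fout)  # streams in chunks when tensor came from astype()
        self.write_padding(fout, tensor.nbytes)
        self.state = WriterState.WEIGHTS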
        try:
            if getattr(tensor, "_gguf_stream_cast", False):
                # derive the pre-cast lazy source from the astype() node args
                base = getattr(tensor, "_args", None)
                base = base[0] if base else None

                src_arr = None
                try:
                    src_arr = type(base).to_eager(base)
                except Exception:
                    src_arr = None

                if isinstance(src_arr, np.ndarray):
                    try:
                        mb = int(os.environ.get("GGUF_CAST_CHUNK_MB", "64") or "64")
                    except Exception:
                        mb = 64
                    tgt_dtype = getattr(tensor, "_gguf_stream_cast_dtype", src_arr.dtype)
                    _stream_log(f"writer: streaming cast (chunk={mb} MiB) dst={tgt_dtype} shape={getattr(tensor, 'shape', '?')}")
Review comment:
Consider using `logger` here instead.
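If the elided suggestion was indeed the module-level `logger` (our assumption), the call might look like:

                    # hypothetical replacement for the print-based helper
                    logger.debug("writer: streaming cast (chunk=%d MiB) dst=%s shape=%s",
                                 mb, tgt_dtype, getattr(tensor, "shape", "?"))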
                    write_cast(fout, src_arr, tgt_dtype, mb)
                    self.write_padding(fout, ti.nbytes)
                    self.state = WriterState.WEIGHTS
                    return
        except Exception:
            # fall back to normal path on any unexpected issue
            pass
        # ---------------------------------------------------------------------------------------

        # Fallback: rely on the object's own tofile() (handles lazy or eager)
        tensor.tofile(fout)
        self.write_padding(fout, tensor.nbytes)
@@ -452,8 +487,46 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None:
            # relying on the fact that Python dicts preserve insertion order (since 3.7)
            for ti in tensors.values():
                assert ti.tensor is not None  # can only iterate once over the tensors
-               assert ti.tensor.nbytes == ti.nbytes
-               ti.tensor.tofile(fout)
                obj = ti.tensor
                assert obj.nbytes == ti.nbytes

                # Try writer-side streaming for pure dtype casts
                streamed = False
                try:
                    if getattr(obj, "_gguf_stream_cast", False):
                        # derive the pre-cast lazy source from the astype() node args
                        base = getattr(obj, "_args", None)
                        base = base[0] if base else None

                        src_arr = None
                        try:
                            src_arr = type(base).to_eager(base)
                        except Exception:
                            src_arr = None

                        if isinstance(src_arr, np.ndarray):
                            try:
                                mb = int(os.environ.get("GGUF_CAST_CHUNK_MB", "64") or "64")
                            except Exception:
                                mb = 64
                            tgt_dtype = getattr(obj, "_gguf_stream_cast_dtype", src_arr.dtype)
                            _stream_log(f"writer: streaming cast (chunk={mb} MiB) dst={tgt_dtype} shape={getattr(obj, 'shape', '?')}")
                            write_cast(fout, src_arr, tgt_dtype, mb)
                            streamed = True
                except Exception:
                    streamed = False  # fall back below on any issue

                if streamed:
                    if shard_bar is not None:
                        shard_bar.update(ti.nbytes)
                    if bar is not None:
                        bar.update(ti.nbytes)
                    self.write_padding(fout, ti.nbytes)
                    ti.tensor = None
                    continue

                # Fallback: object's tofile()
                obj.tofile(fout)
                if shard_bar is not None:
                    shard_bar.update(ti.nbytes)
                if bar is not None:
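The streaming logic above duplicates the block in write_tensor_data almost verbatim; a hedged refactor sketch (the helper name `_try_stream_cast` is ours, not the PR's) could keep it in one place:

def _try_stream_cast(self, fout, obj) -> bool:
    # Stream obj if it is a tagged pure dtype cast; return True on success.
    try:
        if not getattr(obj, "_gguf_stream_cast", False):
            return False
        base = getattr(obj, "_args", None)
        base = base[0] if base else None
        src_arr = type(base).to_eager(base)
        if not isinstance(src_arr, np.ndarray):
            return False
        tgt_dtype = getattr(obj, "_gguf_stream_cast_dtype", src_arr.dtype)
        write_cast(fout, src_arr, tgt_dtype, 64)  # fixed default chunk, MiB
        return True
    except Exception:
        return False  # caller falls back to obj.tofile(fout)

Both call sites would then reduce to `if not self._try_stream_cast(fout, obj): obj.tofile(fout)` plus their own padding and progress-bar bookkeeping.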
gguf-py/gguf/lazy.py
@@ -5,6 +5,7 @@
from typing import Any, Callable

import numpy as np
import os
from numpy.typing import DTypeLike
@@ -221,3 +222,72 @@ def tofile(self, *args, **kwargs):
        return eager.tofile(*args, **kwargs)

-    # TODO: __array_function__

# --- begin low-memory streaming for dtype casts ------------------------------
# Tunable via env:
#   GGUF_CAST_CHUNK_MB (MiB per chunk; default 64)
#   GGUF_STREAM_LOG    (set to any non-empty value to print diagnostics)

import sys
from .stream_cast import write_cast  # sibling helper

try:
    _LAZY_ORIG_ASTYPE = getattr(LazyNumpyTensor, "astype")
    _LAZY_ORIG_TOFILE = getattr(LazyNumpyTensor, "tofile")
except NameError:
    raise RuntimeError("Expected LazyNumpyTensor to be defined above this block")


def _slog(msg: str) -> None:
    if os.environ.get("GGUF_STREAM_LOG"):
        print(f"[gguf-stream] {msg}", file=sys.stdout, flush=True)
Review comment (on lines +240 to +242):
The uses of `_slog` …
def _gguf_streaming_astype(self, dtype, *args, **kwargs):
    """Tag astype results so writer/tofile can stream them later."""
    tgt = np.dtype(dtype)
    out = _LAZY_ORIG_ASTYPE(self, dtype, *args, **kwargs)
    # mark as streamable and record target dtype
    setattr(out, "_gguf_stream_cast", True)
    setattr(out, "_gguf_stream_cast_dtype", tgt)
    # NEW: record the *source* lazy tensor for writer-side streaming
    setattr(out, "_gguf_stream_cast_src", self)
Review comment (on lines +249 to +252):
Alternatively, a single attr could be used, containing e.g. a `(src, dtype)` tuple. Using three separate attrs seems excessive (especially since the existence one is redundant with the other ones existing). This should also simplify (and remove the need for) most of the edge-case handling for missing values (e.g. the missing base array).
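A sketch of that single-attribute variant (our reading of the suggestion, not code from the PR):

def _gguf_streaming_astype(self, dtype, *args, **kwargs):
    tgt = np.dtype(dtype)
    out = _LAZY_ORIG_ASTYPE(self, dtype, *args, **kwargs)
    # one attr carries everything; its absence means "not a streamable cast"
    out._gguf_stream_cast = (self, tgt)
    return out

# consumer side:
#     cast = getattr(obj, "_gguf_stream_cast", None)
#     if cast is not None:
#         src, tgt = cast
#         write_cast(fout, type(src).to_eager(src), tgt, chunk_mb)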
    _slog(f"mark streamable astype: src={getattr(self._meta, 'dtype', '?')} -> dst={tgt}")
    return out
def _gguf_streaming_tofile(self, fout, *args, **kwargs):
    """If this lazy tensor is a pure dtype cast, stream in chunks; else fall back."""
    if not getattr(self, "_gguf_stream_cast", False):
        return _LAZY_ORIG_TOFILE(self, fout, *args, **kwargs)

    # default chunk size: 64 MiB (can override via GGUF_CAST_CHUNK_MB)
    try:
        mb = int(os.environ.get("GGUF_CAST_CHUNK_MB", "64") or "64")
    except Exception:
        mb = 64
    mb = max(1, mb)
Review comment (on lines +261 to +266):
Does it make sense to make this configurable at runtime? A default value should be fine here, I think? Otherwise this is parsing environment variables at each written tensor. (Very minor overhead, though.)
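A sketch of reading the override once at import time instead (hypothetical, module level in lazy.py):

import os

try:
    _CAST_CHUNK_MB = max(1, int(os.environ.get("GGUF_CAST_CHUNK_MB", "64") or "64"))
except ValueError:
    _CAST_CHUNK_MB = 64

# _gguf_streaming_tofile() would then use _CAST_CHUNK_MB directly, avoiding
# the per-tensor environment lookup the reviewer points out.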
    # Prefer the explicitly tagged source lazy tensor if present (step 2)
    base = getattr(self, "_gguf_stream_cast_src", None)

    # Fall back to first arg (older astype behavior) if not tagged
    if base is None:
        base = getattr(self, "_args", None)
        base = base[0] if base else None

    try:
        src_arr = LazyNumpyTensor.to_eager(base)
    except Exception:
        src_arr = None

    if not isinstance(src_arr, np.ndarray):
        _slog("fallback to original tofile: cannot materialize source to ndarray")
        return _LAZY_ORIG_TOFILE(self, fout, *args, **kwargs)

    tgt = getattr(self, "_gguf_stream_cast_dtype", src_arr.dtype)
    _slog(f"streaming cast write: chunk={mb} MiB; dst={tgt}; shape={getattr(self._meta, 'shape', '?')}")
    write_cast(fout, src_arr, tgt, mb)
    return


# Install patches
LazyNumpyTensor.astype = _gguf_streaming_astype
LazyNumpyTensor.tofile = _gguf_streaming_tofile
Review comment (on lines +290 to +292):
Would it be cleaner to directly modify the source code of `LazyNumpyTensor` instead of patching it from outside? Unless you want this to be disable-able, in which case a subclass (although not sure what name to use) could also be appropriate, and then it could be used in … Are there cases where … Assuming it's implemented correctly, I think the tag (used to detect whether to stream) …
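A rough sketch of the subclass route (class name and details are ours; this is one reading of the comment, not the PR's code):

import numpy as np
from .stream_cast import write_cast

class StreamingLazyNumpyTensor(LazyNumpyTensor):
    def astype(self, dtype, *args, **kwargs):
        out = super().astype(dtype, *args, **kwargs)
        out.__class__ = type(self)  # keep the streaming tofile() on the result
        out._gguf_stream_cast = (self, np.dtype(dtype))
        return out

    def tofile(self, fout, *args, **kwargs):
        cast = getattr(self, "_gguf_stream_cast", None)
        if cast is None:
            return super().tofile(fout, *args, **kwargs)
        src, tgt = cast
        write_cast(fout, type(src).to_eager(src), tgt, 64)

Code that wants streaming would construct StreamingLazyNumpyTensor instead of LazyNumpyTensor; leaving the base class untouched keeps the behavior opt-in (the "disable-able" case the reviewer mentions).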
# --- end low-memory streaming for dtype casts ------------------------------
gguf-py/gguf/stream_cast.py (new file)

@@ -0,0 +1,80 @@
# gguf-py/gguf/stream_cast.py
from __future__ import annotations
from typing import Any
import os
import sys
import numpy as np


def _slog(msg: str) -> None:
    """Conditional debug logging when GGUF_STREAM_LOG is set."""
    if os.environ.get("GGUF_STREAM_LOG"):
        print(f"[gguf-stream] {msg}", file=sys.stdout, flush=True)


def _chunk_elems(src_dtype: np.dtype, dst_dtype: np.dtype, chunk_mb: int) -> int:
    """
    Compute how many elements to process per chunk so that each chunk is
    approximately `chunk_mb` MiB of the *larger* of the source/destination itemsize.
    """
    try:
        mb = int(chunk_mb)
    except Exception:
        mb = 64
    mb = max(1, mb)
    item = max(np.dtype(src_dtype).itemsize, np.dtype(dst_dtype).itemsize)
    return max(1, (mb * 1024 * 1024) // item)
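To make the sizing concrete, a worked example under the defaults (evaluated against the helper above):

# float32 -> float16: the larger itemsize is 4 bytes, so a 64 MiB budget
# yields 64 * 1024 * 1024 // 4 == 16_777_216 elements per chunk
# (~64 MiB read and ~32 MiB written at a time).
assert _chunk_elems(np.float32, np.float16, 64) == 16_777_216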
def write_cast(fout, src: np.ndarray, dst_dtype: Any, chunk_mb: int) -> None:
    """
    Stream `src.astype(dst_dtype)` to `fout` in fixed-size chunks to cap peak RSS.

    This matches the import site in lazy.py:
        from .stream_cast import write_cast

    Parameters
    ----------
    fout : file-like object
        Open file handle to write bytes to (must support .write()).
    src : np.ndarray
        Source ndarray to be converted and streamed.
    dst_dtype : Any
        Target dtype (anything accepted by np.dtype).
    chunk_mb : int
        Desired chunk size in MiB (will be clamped to >= 1).
    """
    dst = np.dtype(dst_dtype)
    flat = src.reshape(-1)
    n = flat.size
    ce = _chunk_elems(flat.dtype, dst, chunk_mb)

    _slog(
        f"write_cast: src={flat.dtype} -> dst={dst}; n={n}; "
        f"chunk={max(1, int(chunk_mb))} MiB; elems/chunk={ce}"
    )

    start = 0
    mv = memoryview  # local binding for a tiny speed bump
    while start < n:
        end = min(start + ce, n)
        # copy=False avoids an extra temporary when no conversion is needed
        chunk = flat[start:end].astype(dst, copy=False)
        fout.write(mv(chunk).tobytes())
Review comment:
Alternatively, directly use the … I don't know if it would handle non-contiguous strides correctly, though; in this case the previous `reshape(-1)` … There's also … No idea of the performance difference of these approaches, but I think …
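The elided alternatives plausibly avoid the extra `tobytes()` copy; two common NumPy idioms (our guesses, not necessarily what the reviewer wrote):

        # Option 1: pass the buffer straight to write(); memoryview() requires
        # a C-contiguous array, which the astype() result of a step-1 slice
        # of a flattened array normally is.
        fout.write(memoryview(chunk))

        # Option 2: let NumPy write directly; ndarray.tofile() needs a real
        # file object and bypasses Python-level buffering.
        chunk.tofile(fout)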
        start = end
|
||||||||||
# Optional: writer-side API that accepts chunk size in bytes (used by gguf_writer) | ||||||||||
def stream_write(fout, src_arr: np.ndarray, dst_dtype: Any, chunk_bytes: int) -> None: | ||||||||||
""" | ||||||||||
Same as write_cast, but the chunk size is given in bytes. | ||||||||||
Kept for compatibility with earlier helper drafts. | ||||||||||
""" | ||||||||||
if not isinstance(chunk_bytes, int) or chunk_bytes <= 0: | ||||||||||
chunk_mb = 64 | ||||||||||
else: | ||||||||||
# round bytes to MiB for the element count helper | ||||||||||
chunk_mb = max(1, chunk_bytes // (1024 * 1024)) | ||||||||||
|
||||||||||
write_cast(fout, src_arr, dst_dtype, chunk_mb) |
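A short usage sketch of the helper in isolation (file path and array are hypothetical):

import numpy as np
from gguf.stream_cast import write_cast

arr = np.random.rand(4096, 4096).astype(np.float32)  # ~64 MiB source
with open("tensor.bin", "wb") as fout:
    write_cast(fout, arr, np.float16, 16)  # stream as float16, ~16 MiB chunks
# Peak extra memory stays near one chunk rather than a full float16 copy.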
Review comment (on the removed `# TODO: __array_function__` line in lazy.py):
There's no reason to remove that comment. There's nothing which attempted to fix the stated TODO.