Skip to content

Commit 2da5f18

Browse files
authored
feat: streaming safetensors writer (#1164)
* stream writes — Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
* add test — Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
* fix — Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
* fix — Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

---------

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
1 parent f258802 commit 2da5f18

File tree

2 files changed

+508
-41
lines changed

2 files changed

+508
-41
lines changed

nemo_automodel/components/checkpoint/_backports/filesystem.py

Lines changed: 150 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import os
2323
import pickle
2424
import queue
25+
import struct
2526
import threading
2627
import uuid
2728
import warnings
@@ -68,6 +69,7 @@
6869
from nemo_automodel.components.checkpoint._backports.hf_utils import (
6970
CUSTOM_METADATA_KEY,
7071
DCP_VERSION_KEY,
72+
DTYPE_MAP,
7173
HF_DCP_VERSION,
7274
)
7375

@@ -113,6 +115,21 @@ def _generate_uuid() -> str:
113115
return str(uuid.uuid4())
114116

115117

118+
_DTYPE_TO_SAFETENSORS_DTYPE: dict[torch.dtype, str] = {v: k for k, v in DTYPE_MAP.items()}
119+
120+
121+
def _to_safetensors_dtype_str(dtype: torch.dtype) -> str:
122+
"""Return the safetensors dtype string for a torch.dtype.
123+
124+
Raises:
125+
ValueError: If dtype is not supported by our safetensors serializer.
126+
"""
127+
try:
128+
return _DTYPE_TO_SAFETENSORS_DTYPE[dtype]
129+
except KeyError as e:
130+
raise ValueError(f"Unsupported dtype for safetensors serialization: {dtype}") from e
131+
132+
116133
class _TensorLoader(ABC):
117134
@abstractmethod
118135
def add(self, size: int, obj: object) -> None:
@@ -281,12 +298,13 @@ def close(self):
281298

282299

283300
def _item_size(item: WriteItem) -> int:
284-
size = 1
285301
assert item.tensor_data is not None
302+
# NOTE: WriteItems can represent *chunks* of a global tensor (e.g., SHARD writes).
303+
# The on-disk payload corresponds to the chunk sizes, not the global tensor size.
304+
size = 1
286305
# can't use math.prod as PT needs to support older python
287-
for s in item.tensor_data.size:
306+
for s in item.tensor_data.chunk.sizes:
288307
size *= s
289-
290308
dtype = item.tensor_data.properties.dtype
291309
return size * torch._utils._element_size(dtype)
292310

@@ -412,49 +430,140 @@ def _write_files_from_queue(
412430
write_results = []
413431

414432
with create_stream(file_name, "wb") as stream:
415-
for write_item in bytes_w:
416-
data = planner.resolve_data(write_item)
417-
write_results.append(
418-
_write_item(
419-
transforms,
420-
stream,
421-
data,
422-
write_item,
423-
storage_key,
424-
serialization_format,
433+
if serialization_format == SerializationFormat.SAFETENSORS:
434+
# SAFETENSORS expects the stream to start with the header.
435+
# DCP's BYTE_IO items would corrupt the file layout, so we explicitly disallow them.
436+
if bytes_w:
437+
raise RuntimeError(
438+
"Cannot serialize BYTE_IO items in safetensors format. "
439+
"This is a bug: safetensors files can only contain tensors."
425440
)
426-
)
427441

428-
tensor_dict = {}
429-
metadata_dict = {}
430-
for tensor, write_item in loader.values():
431-
assert tensor.is_cpu
432-
write_results.append(
433-
_write_item(
434-
transforms,
435-
stream,
436-
tensor,
437-
write_item,
438-
storage_key,
439-
serialization_format,
442+
# Determine the tensor write order. The overlapping loader sorts by size.
443+
ordered_tensor_w = tensor_w
444+
if isinstance(loader, _OverlappingCpuLoader):
445+
ordered_tensor_w = sorted(tensor_w, key=_item_size)
446+
447+
# Build the custom DCP sharding metadata (per-key saved offsets).
448+
metadata_dict = {
449+
wi.index.fqn: {"saved_offsets": wi.tensor_data.chunk.offsets} for wi in ordered_tensor_w
450+
}
451+
452+
# Build the safetensors header up-front so we can stream raw tensor bytes without
453+
# materializing the full file in memory (safetensors.torch.save returns a full bytes blob).
454+
header: dict[str, Any] = {}
455+
tensor_data_offsets: dict[str, tuple[int, int]] = {}
456+
data_offset = 0
457+
for wi in ordered_tensor_w:
458+
assert wi.tensor_data is not None
459+
nbytes = _item_size(wi)
460+
header[wi.index.fqn] = {
461+
"dtype": _to_safetensors_dtype_str(wi.tensor_data.properties.dtype),
462+
# SAFETENSORS entries store the chunk payload; global shape is
463+
# reconstructed via CUSTOM_METADATA_KEY (saved_offsets) at load time.
464+
"shape": [int(s) for s in wi.tensor_data.chunk.sizes],
465+
"data_offsets": [data_offset, data_offset + nbytes],
466+
}
467+
tensor_data_offsets[wi.index.fqn] = (data_offset, nbytes)
468+
data_offset += nbytes
469+
470+
header["__metadata__"] = {
471+
CUSTOM_METADATA_KEY: json.dumps(metadata_dict),
472+
DCP_VERSION_KEY: str(HF_DCP_VERSION),
473+
"format": "pt",
474+
}
475+
476+
header_json = json.dumps(header, separators=(",", ":"), ensure_ascii=False).encode("utf-8")
477+
# Pad to 8-byte alignment (matches safetensors writer behavior).
478+
pad_len = (8 - (len(header_json) % 8)) % 8
479+
if pad_len:
480+
header_json += b" " * pad_len
481+
482+
stream.write(struct.pack("<Q", len(header_json)))
483+
stream.write(header_json)
484+
header_size = 8 + len(header_json)
485+
486+
# Stream tensors in the same order as the header entries.
487+
expected_fqns = [wi.index.fqn for wi in ordered_tensor_w]
488+
expected_idx = 0
489+
for tensor, write_item in loader.values():
490+
assert tensor.is_cpu
491+
if expected_idx >= len(expected_fqns) or write_item.index.fqn != expected_fqns[expected_idx]:
492+
raise RuntimeError(
493+
"Internal error: safetensors write order mismatch. "
494+
f"Expected {expected_fqns[expected_idx] if expected_idx < len(expected_fqns) else '<end>'}, "
495+
f"got {write_item.index.fqn}."
496+
)
497+
expected_idx += 1
498+
499+
# Ensure a compact, contiguous CPU buffer before writing.
500+
# Some tensor views can have larger backing storage than numel*element_size.
501+
if not tensor.is_contiguous():
502+
tensor = tensor.contiguous()
503+
expected_nbytes = tensor.numel() * tensor.element_size()
504+
if tensor.untyped_storage().size() != expected_nbytes:
505+
tensor = tensor.clone()
506+
507+
# Write raw bytes without creating an intermediate bytes blob.
508+
# NOTE: `Tensor.view(dtype)` does not allow 0-dim tensors when element sizes differ
509+
# (e.g. BF16 -> uint8). Safetensors stores raw bytes, so reshape scalars to 1D.
510+
if tensor.dim() == 0:
511+
tensor = tensor.reshape(1)
512+
byte_view = tensor.view(torch.uint8)
513+
np_view = byte_view.numpy()
514+
stream.write(memoryview(np_view))
515+
516+
# Record storage offsets/lengths (absolute offset within the safetensors file).
517+
data_off, planned_nbytes = tensor_data_offsets[write_item.index.fqn]
518+
if expected_nbytes != planned_nbytes:
519+
raise RuntimeError(
520+
"Internal error: safetensors size mismatch for "
521+
f"{write_item.index.fqn}: planned {planned_nbytes} bytes, got {expected_nbytes}."
522+
)
523+
524+
write_results.append(
525+
WriteResult(
526+
index=write_item.index,
527+
size_in_bytes=planned_nbytes,
528+
storage_data=_StorageInfo(
529+
storage_key,
530+
offset=header_size + data_off,
531+
length=planned_nbytes,
532+
),
533+
)
440534
)
441-
)
442-
tensor_dict[write_item.index.fqn] = tensor
443-
metadata_dict[write_item.index.fqn] = {"saved_offsets": write_item.tensor_data.chunk.offsets}
444535

445-
if serialization_format == SerializationFormat.SAFETENSORS:
446-
from safetensors.torch import save # type: ignore[import-not-found]
447-
448-
stream.write(
449-
save(
450-
tensor_dict,
451-
metadata={
452-
CUSTOM_METADATA_KEY: json.dumps(metadata_dict),
453-
DCP_VERSION_KEY: str(HF_DCP_VERSION),
454-
"format": "pt",
455-
},
536+
if expected_idx != len(expected_fqns):
537+
raise RuntimeError(
538+
"Internal error: did not write all tensors to safetensors file. "
539+
f"Wrote {expected_idx}/{len(expected_fqns)} tensors."
540+
)
541+
else:
542+
for write_item in bytes_w:
543+
data = planner.resolve_data(write_item)
544+
write_results.append(
545+
_write_item(
546+
transforms,
547+
stream,
548+
data,
549+
write_item,
550+
storage_key,
551+
serialization_format,
552+
)
553+
)
554+
555+
for tensor, write_item in loader.values():
556+
assert tensor.is_cpu
557+
write_results.append(
558+
_write_item(
559+
transforms,
560+
stream,
561+
tensor,
562+
write_item,
563+
storage_key,
564+
serialization_format,
565+
)
456566
)
457-
)
458567

459568
if use_fsync:
460569
try:

0 commit comments

Comments (0)