|
22 | 22 | import os |
23 | 23 | import pickle |
24 | 24 | import queue |
| 25 | +import struct |
25 | 26 | import threading |
26 | 27 | import uuid |
27 | 28 | import warnings |
|
68 | 69 | from nemo_automodel.components.checkpoint._backports.hf_utils import ( |
69 | 70 | CUSTOM_METADATA_KEY, |
70 | 71 | DCP_VERSION_KEY, |
| 72 | + DTYPE_MAP, |
71 | 73 | HF_DCP_VERSION, |
72 | 74 | ) |
73 | 75 |
|
@@ -113,6 +115,21 @@ def _generate_uuid() -> str: |
113 | 115 | return str(uuid.uuid4()) |
114 | 116 |
|
115 | 117 |
|
# Inverse of DTYPE_MAP: torch.dtype -> safetensors dtype-name string.
_DTYPE_TO_SAFETENSORS_DTYPE: dict[torch.dtype, str] = {
    torch_dtype: name for name, torch_dtype in DTYPE_MAP.items()
}


def _to_safetensors_dtype_str(dtype: torch.dtype) -> str:
    """Return the safetensors dtype string for a torch.dtype.

    Raises:
        ValueError: If dtype is not supported by our safetensors serializer.
    """
    try:
        return _DTYPE_TO_SAFETENSORS_DTYPE[dtype]
    except KeyError as exc:
        # Surface a ValueError (keeping the KeyError as cause) so callers get a
        # clear "unsupported dtype" error rather than a bare mapping failure.
        raise ValueError(f"Unsupported dtype for safetensors serialization: {dtype}") from exc
| 131 | + |
| 132 | + |
116 | 133 | class _TensorLoader(ABC): |
117 | 134 | @abstractmethod |
118 | 135 | def add(self, size: int, obj: object) -> None: |
@@ -281,12 +298,13 @@ def close(self): |
281 | 298 |
|
282 | 299 |
|
283 | 300 | def _item_size(item: WriteItem) -> int: |
284 | | - size = 1 |
285 | 301 | assert item.tensor_data is not None |
| 302 | + # NOTE: WriteItems can represent *chunks* of a global tensor (e.g., SHARD writes). |
| 303 | + # The on-disk payload corresponds to the chunk sizes, not the global tensor size. |
| 304 | + size = 1 |
286 | 305 | # can't use math.prod as PT needs to support older python |
287 | | - for s in item.tensor_data.size: |
| 306 | + for s in item.tensor_data.chunk.sizes: |
288 | 307 | size *= s |
289 | | - |
290 | 308 | dtype = item.tensor_data.properties.dtype |
291 | 309 | return size * torch._utils._element_size(dtype) |
292 | 310 |
|
@@ -412,49 +430,140 @@ def _write_files_from_queue( |
412 | 430 | write_results = [] |
413 | 431 |
|
414 | 432 | with create_stream(file_name, "wb") as stream: |
415 | | - for write_item in bytes_w: |
416 | | - data = planner.resolve_data(write_item) |
417 | | - write_results.append( |
418 | | - _write_item( |
419 | | - transforms, |
420 | | - stream, |
421 | | - data, |
422 | | - write_item, |
423 | | - storage_key, |
424 | | - serialization_format, |
| 433 | + if serialization_format == SerializationFormat.SAFETENSORS: |
| 434 | + # SAFETENSORS expects the stream to start with the header. |
| 435 | + # DCP's BYTE_IO items would corrupt the file layout, so we explicitly disallow them. |
| 436 | + if bytes_w: |
| 437 | + raise RuntimeError( |
| 438 | + "Cannot serialize BYTE_IO items in safetensors format. " |
| 439 | + "This is a bug: safetensors files can only contain tensors." |
425 | 440 | ) |
426 | | - ) |
427 | 441 |
|
428 | | - tensor_dict = {} |
429 | | - metadata_dict = {} |
430 | | - for tensor, write_item in loader.values(): |
431 | | - assert tensor.is_cpu |
432 | | - write_results.append( |
433 | | - _write_item( |
434 | | - transforms, |
435 | | - stream, |
436 | | - tensor, |
437 | | - write_item, |
438 | | - storage_key, |
439 | | - serialization_format, |
| 442 | + # Determine the tensor write order. The overlapping loader sorts by size. |
| 443 | + ordered_tensor_w = tensor_w |
| 444 | + if isinstance(loader, _OverlappingCpuLoader): |
| 445 | + ordered_tensor_w = sorted(tensor_w, key=_item_size) |
| 446 | + |
| 447 | + # Build the custom DCP sharding metadata (per-key saved offsets). |
| 448 | + metadata_dict = { |
| 449 | + wi.index.fqn: {"saved_offsets": wi.tensor_data.chunk.offsets} for wi in ordered_tensor_w |
| 450 | + } |
| 451 | + |
| 452 | + # Build the safetensors header up-front so we can stream raw tensor bytes without |
| 453 | + # materializing the full file in memory (safetensors.torch.save returns a full bytes blob). |
| 454 | + header: dict[str, Any] = {} |
| 455 | + tensor_data_offsets: dict[str, tuple[int, int]] = {} |
| 456 | + data_offset = 0 |
| 457 | + for wi in ordered_tensor_w: |
| 458 | + assert wi.tensor_data is not None |
| 459 | + nbytes = _item_size(wi) |
| 460 | + header[wi.index.fqn] = { |
| 461 | + "dtype": _to_safetensors_dtype_str(wi.tensor_data.properties.dtype), |
| 462 | + # SAFETENSORS entries store the chunk payload; global shape is |
| 463 | + # reconstructed via CUSTOM_METADATA_KEY (saved_offsets) at load time. |
| 464 | + "shape": [int(s) for s in wi.tensor_data.chunk.sizes], |
| 465 | + "data_offsets": [data_offset, data_offset + nbytes], |
| 466 | + } |
| 467 | + tensor_data_offsets[wi.index.fqn] = (data_offset, nbytes) |
| 468 | + data_offset += nbytes |
| 469 | + |
| 470 | + header["__metadata__"] = { |
| 471 | + CUSTOM_METADATA_KEY: json.dumps(metadata_dict), |
| 472 | + DCP_VERSION_KEY: str(HF_DCP_VERSION), |
| 473 | + "format": "pt", |
| 474 | + } |
| 475 | + |
| 476 | + header_json = json.dumps(header, separators=(",", ":"), ensure_ascii=False).encode("utf-8") |
| 477 | + # Pad to 8-byte alignment (matches safetensors writer behavior). |
| 478 | + pad_len = (8 - (len(header_json) % 8)) % 8 |
| 479 | + if pad_len: |
| 480 | + header_json += b" " * pad_len |
| 481 | + |
| 482 | + stream.write(struct.pack("<Q", len(header_json))) |
| 483 | + stream.write(header_json) |
| 484 | + header_size = 8 + len(header_json) |
| 485 | + |
| 486 | + # Stream tensors in the same order as the header entries. |
| 487 | + expected_fqns = [wi.index.fqn for wi in ordered_tensor_w] |
| 488 | + expected_idx = 0 |
| 489 | + for tensor, write_item in loader.values(): |
| 490 | + assert tensor.is_cpu |
| 491 | + if expected_idx >= len(expected_fqns) or write_item.index.fqn != expected_fqns[expected_idx]: |
| 492 | + raise RuntimeError( |
| 493 | + "Internal error: safetensors write order mismatch. " |
| 494 | + f"Expected {expected_fqns[expected_idx] if expected_idx < len(expected_fqns) else '<end>'}, " |
| 495 | + f"got {write_item.index.fqn}." |
| 496 | + ) |
| 497 | + expected_idx += 1 |
| 498 | + |
| 499 | + # Ensure a compact, contiguous CPU buffer before writing. |
| 500 | + # Some tensor views can have larger backing storage than numel*element_size. |
| 501 | + if not tensor.is_contiguous(): |
| 502 | + tensor = tensor.contiguous() |
| 503 | + expected_nbytes = tensor.numel() * tensor.element_size() |
| 504 | + if tensor.untyped_storage().size() != expected_nbytes: |
| 505 | + tensor = tensor.clone() |
| 506 | + |
| 507 | + # Write raw bytes without creating an intermediate bytes blob. |
| 508 | + # NOTE: `Tensor.view(dtype)` does not allow 0-dim tensors when element sizes differ |
| 509 | + # (e.g. BF16 -> uint8). Safetensors stores raw bytes, so reshape scalars to 1D. |
| 510 | + if tensor.dim() == 0: |
| 511 | + tensor = tensor.reshape(1) |
| 512 | + byte_view = tensor.view(torch.uint8) |
| 513 | + np_view = byte_view.numpy() |
| 514 | + stream.write(memoryview(np_view)) |
| 515 | + |
| 516 | + # Record storage offsets/lengths (absolute offset within the safetensors file). |
| 517 | + data_off, planned_nbytes = tensor_data_offsets[write_item.index.fqn] |
| 518 | + if expected_nbytes != planned_nbytes: |
| 519 | + raise RuntimeError( |
| 520 | + "Internal error: safetensors size mismatch for " |
| 521 | + f"{write_item.index.fqn}: planned {planned_nbytes} bytes, got {expected_nbytes}." |
| 522 | + ) |
| 523 | + |
| 524 | + write_results.append( |
| 525 | + WriteResult( |
| 526 | + index=write_item.index, |
| 527 | + size_in_bytes=planned_nbytes, |
| 528 | + storage_data=_StorageInfo( |
| 529 | + storage_key, |
| 530 | + offset=header_size + data_off, |
| 531 | + length=planned_nbytes, |
| 532 | + ), |
| 533 | + ) |
440 | 534 | ) |
441 | | - ) |
442 | | - tensor_dict[write_item.index.fqn] = tensor |
443 | | - metadata_dict[write_item.index.fqn] = {"saved_offsets": write_item.tensor_data.chunk.offsets} |
444 | 535 |
|
445 | | - if serialization_format == SerializationFormat.SAFETENSORS: |
446 | | - from safetensors.torch import save # type: ignore[import-not-found] |
447 | | - |
448 | | - stream.write( |
449 | | - save( |
450 | | - tensor_dict, |
451 | | - metadata={ |
452 | | - CUSTOM_METADATA_KEY: json.dumps(metadata_dict), |
453 | | - DCP_VERSION_KEY: str(HF_DCP_VERSION), |
454 | | - "format": "pt", |
455 | | - }, |
| 536 | + if expected_idx != len(expected_fqns): |
| 537 | + raise RuntimeError( |
| 538 | + "Internal error: did not write all tensors to safetensors file. " |
| 539 | + f"Wrote {expected_idx}/{len(expected_fqns)} tensors." |
| 540 | + ) |
| 541 | + else: |
| 542 | + for write_item in bytes_w: |
| 543 | + data = planner.resolve_data(write_item) |
| 544 | + write_results.append( |
| 545 | + _write_item( |
| 546 | + transforms, |
| 547 | + stream, |
| 548 | + data, |
| 549 | + write_item, |
| 550 | + storage_key, |
| 551 | + serialization_format, |
| 552 | + ) |
| 553 | + ) |
| 554 | + |
| 555 | + for tensor, write_item in loader.values(): |
| 556 | + assert tensor.is_cpu |
| 557 | + write_results.append( |
| 558 | + _write_item( |
| 559 | + transforms, |
| 560 | + stream, |
| 561 | + tensor, |
| 562 | + write_item, |
| 563 | + storage_key, |
| 564 | + serialization_format, |
| 565 | + ) |
456 | 566 | ) |
457 | | - ) |
458 | 567 |
|
459 | 568 | if use_fsync: |
460 | 569 | try: |
|
0 commit comments