Skip to content

Commit 12b6c97

Browse files
committed
stage
Signed-off-by: adil-a <adil.asif2000@hotmail.com>
1 parent d86b960 commit 12b6c97

File tree

1 file changed

+90
-23
lines changed

1 file changed

+90
-23
lines changed

nemo_automodel/components/checkpoint/_backports/consolidate_hf_safetensors.py

Lines changed: 90 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,6 @@ class _InputFileData:
9393
metadata: Any = None
9494

9595

96-
GLOBAL_OUTPUT_FILES_DATA: Optional[dict[str, _OutputFileData]] = None
97-
98-
9996
def _parse_input_metadata(
10097
input_files_data: dict[str, _InputFileData],
10198
output_files_data: dict[str, _OutputFileData],
@@ -535,6 +532,91 @@ def _write_overall_metadata_file(
535532
json.dump(metadata_to_write, metadata_file, indent=2)
536533

537534

535+
def _write_overall_metadata_file_from_shards(
    input_dir: str,
    output_dir: str,
    fqn_to_index_mapping: dict[str, int],
) -> None:
    """Write the overall HuggingFace index file from input shard metadata.

    Produces the ``model.safetensors.index.json``-style file that HuggingFace
    models use to locate tensors across multiple output files (a weight map
    plus the total byte size). Unlike ``_write_overall_metadata_file``, the
    shape/dtype information is read directly from the input shard files on
    disk, so no distributed gather is needed.

    Args:
        input_dir: Directory containing the DCP-saved input shard safetensors files.
        output_dir: Directory where the index/metadata file will be written.
        fqn_to_index_mapping: Mapping from tensor names to output file indices.

    Raises:
        ValueError: If a shard is missing DCP custom metadata (i.e. it was not
            saved with DCP).
    """
    from safetensors.torch import _getdtype  # type: ignore[import]

    # Gather the safetensors header metadata of every shard in the input dir.
    shard_paths = glob.glob(os.path.join(input_dir, f"*{SUFFIX}"))
    shards: dict[str, _InputFileData] = {}
    for shard_path in shard_paths:
        with open(shard_path, "rb") as fh:
            shard_metadata, header_size = _get_safetensors_file_metadata(fh)
            shards[shard_path] = _InputFileData(
                metadata_size=header_size,
                metadata=shard_metadata,
            )

    # Reconstruct each tensor's full (unsharded) shape: per dimension, the
    # overall extent is the max of (local size + saved offset) over all
    # shards — the same logic _parse_input_metadata applies.
    fqn_to_size_mapping: dict[str, tuple[list[int], str]] = {}
    for shard in shards.values():
        shard_metadata = shard.metadata
        dcp_sharding_info = _get_dcp_custom_metadata(shard_metadata)
        if not dcp_sharding_info:
            raise ValueError(
                "No DCP custom metadata found in safetensors file. The file must be saved with DCP to be consolidated."
            )

        for fqn, tensor_info in shard_metadata.items():
            if fqn == DEFAULT_EXTRA_METADATA_KEY:
                continue

            local_sizes = tensor_info[SHAPE_KEY]
            saved_offsets = dcp_sharding_info[fqn][SAVED_OFFSETS_KEY]

            if fqn in fqn_to_size_mapping:
                known_size, known_dtype = fqn_to_size_mapping[fqn]
                grown = [
                    max(known, size + offset)
                    for known, size, offset in zip(known_size, local_sizes, saved_offsets)
                ]
                fqn_to_size_mapping[fqn] = (grown, known_dtype)
            else:
                extent = [size + offset for size, offset in zip(local_sizes, saved_offsets)]
                fqn_to_size_mapping[fqn] = (extent, tensor_info[DTYPE_KEY])

    # Accumulate the total byte size and the fqn -> output-file-name map.
    max_index = max(fqn_to_index_mapping.values())
    total_size = 0
    weight_map: dict[str, str] = {}
    for fqn, (full_shape, dtype_str) in fqn_to_size_mapping.items():
        dtype = _getdtype(dtype_str)
        try:
            # Floating-point dtypes expose their bit width via torch.finfo.
            bytes_per_elem = torch.finfo(dtype).bits // 8
        except TypeError:
            # Non-float dtypes: fall back to an empty tensor's element size.
            bytes_per_elem = torch.tensor([], dtype=dtype).element_size()

        total_size += math.prod(full_shape) * bytes_per_elem
        weight_map[fqn] = _gen_file_name(fqn_to_index_mapping[fqn], max_index)

    # Emit the HuggingFace-style index file.
    index_contents: dict[str, Any] = {
        "metadata": {"total_size": total_size},
        "weight_map": weight_map,
    }
    with open(os.path.join(output_dir, _metadata_fn), "w") as index_file:
        json.dump(index_contents, index_file, indent=2)
618+
619+
538620
def _consolidate_safetensors_files(
539621
input_dir: str,
540622
output_dir: str,
@@ -747,33 +829,18 @@ def consolidate_safetensors_files_on_every_rank(
747829
filtered_filename_mapping[fqn] = filename
748830

749831
# Call the existing consolidation function with the filtered mapping
750-
output_files_data = _consolidate_safetensors_files(
832+
_consolidate_safetensors_files(
751833
input_dir=input_dir,
752834
output_dir=output_dir,
753835
fqn_to_file_mapping=filtered_filename_mapping,
754836
num_threads=num_threads,
755837
use_staging=use_staging,
756838
staging_dir=staging_dir,
757839
)
758-
else:
759-
output_files_data = {}
760-
761-
global GLOBAL_OUTPUT_FILES_DATA
762-
if GLOBAL_OUTPUT_FILES_DATA is None:
763-
# cache after the first time we checkpoint
764-
GLOBAL_OUTPUT_FILES_DATA = {}
765-
global_output_files_data_list = [None] * world_size
766-
dist.all_gather_object(global_output_files_data_list, output_files_data)
767-
for item in global_output_files_data_list:
768-
if item:
769-
GLOBAL_OUTPUT_FILES_DATA.update(item)
770-
771-
# Write overall model.index.safetensors.json file with weight map
772-
if GLOBAL_OUTPUT_FILES_DATA:
773-
if rank == 0:
774-
_write_overall_metadata_file(output_dir, GLOBAL_OUTPUT_FILES_DATA)
775-
if dist.is_available() and dist.is_initialized():
776-
dist.barrier()
840+
841+
# Write overall model.index.safetensors.json file with weight map (rank 0 only)
842+
if rank == 0:
843+
_write_overall_metadata_file_from_shards(input_dir, output_dir, fqn_to_index_mapping)
777844

778845
logger.info(
779846
"Rank %d: Done consolidating. Processed %d unique indices in %.2f secs.",

0 commit comments

Comments
 (0)