@@ -93,9 +93,6 @@ class _InputFileData:
9393 metadata : Any = None
9494
9595
96- GLOBAL_OUTPUT_FILES_DATA : Optional [dict [str , _OutputFileData ]] = None
97-
98-
9996def _parse_input_metadata (
10097 input_files_data : dict [str , _InputFileData ],
10198 output_files_data : dict [str , _OutputFileData ],
@@ -535,6 +532,91 @@ def _write_overall_metadata_file(
535532 json .dump (metadata_to_write , metadata_file , indent = 2 )
536533
537534
def _write_overall_metadata_file_from_shards(
    input_dir: str,
    output_dir: str,
    fqn_to_index_mapping: dict[str, int],
) -> None:
    """
    Write the overall metadata file by reading metadata from input shard files.

    This creates a model.safetensors.index.json file that HuggingFace models use
    to locate tensors across multiple files. Unlike _write_overall_metadata_file,
    this function reads the necessary shape/dtype information directly from the
    input shard files, avoiding the need for distributed gather operations.

    Args:
        input_dir: Directory containing the input shard safetensors files.
        output_dir: Directory where the metadata file will be written.
        fqn_to_index_mapping: Mapping from tensor names to output file indices.
            NOTE(review): assumed to contain every tensor name present in the
            input shards — a missing name raises KeyError below.

    Raises:
        ValueError: If an input shard has no DCP custom metadata (i.e. it was
            not produced by a DCP save).
    """
    from safetensors.torch import _getdtype  # type: ignore[import]

    # Find all safetensors shards in the input directory.
    safetensors_files = glob.glob(os.path.join(input_dir, f"*{SUFFIX}"))

    # Read only the header metadata of each shard; tensor data is never loaded.
    input_files_data: dict[str, _InputFileData] = {}
    for input_file in safetensors_files:
        with open(input_file, "rb") as f:
            metadata, metadata_size = _get_safetensors_file_metadata(f)
            input_files_data[input_file] = _InputFileData(
                metadata_size=metadata_size,
                metadata=metadata,
            )

    # Reconstruct each full (unsharded) tensor shape from the per-shard shapes
    # and saved offsets (same logic as _parse_input_metadata): the full extent
    # along each dim is the max over shards of (shard size + shard offset).
    fqn_to_size_mapping: dict[str, tuple[list[int], str]] = {}
    for file_data in input_files_data.values():
        safetensors_metadata = file_data.metadata
        dcp_sharding_info = _get_dcp_custom_metadata(safetensors_metadata)
        if not dcp_sharding_info:
            raise ValueError(
                "No DCP custom metadata found in safetensors file. The file must be saved with DCP to be consolidated."
            )

        for key, val in safetensors_metadata.items():
            if key == DEFAULT_EXTRA_METADATA_KEY:
                continue

            sizes = val[SHAPE_KEY]
            offsets = dcp_sharding_info[key][SAVED_OFFSETS_KEY]

            if key not in fqn_to_size_mapping:
                cur_size = [size + offset for size, offset in zip(sizes, offsets)]
                fqn_to_size_mapping[key] = (cur_size, val[DTYPE_KEY])
            else:
                cur_size = fqn_to_size_mapping[key][0]
                for i in range(len(sizes)):
                    cur_size[i] = max(cur_size[i], sizes[i] + offsets[i])

    # Compute total_size (bytes across all tensors) and the weight map
    # (tensor name -> output shard file name).
    max_index = max(fqn_to_index_mapping.values())
    total_size = 0
    weight_map: dict[str, str] = {}

    # Cache the per-element byte size for each distinct dtype string so the
    # finfo / empty-tensor probe runs once per dtype, not once per tensor.
    dtype_size_cache: dict[str, int] = {}

    for fqn, (tensor_shape, dtype_str) in fqn_to_size_mapping.items():
        dtype_size = dtype_size_cache.get(dtype_str)
        if dtype_size is None:
            dtype = _getdtype(dtype_str)
            try:
                dtype_size = torch.finfo(dtype).bits // 8
            except TypeError:
                # torch.finfo only covers floating-point dtypes; probe an
                # empty tensor for integer/bool dtypes.
                dtype_size = torch.tensor([], dtype=dtype).element_size()
            dtype_size_cache[dtype_str] = dtype_size

        total_size += math.prod(tensor_shape) * dtype_size

        idx = fqn_to_index_mapping[fqn]
        weight_map[fqn] = _gen_file_name(idx, max_index)

    # Write the index file in the HuggingFace model.safetensors.index.json shape.
    metadata_to_write: dict[str, Any] = {}
    metadata_to_write["metadata"] = {"total_size": total_size}
    metadata_to_write["weight_map"] = weight_map

    metadata_path = os.path.join(output_dir, _metadata_fn)
    with open(metadata_path, "w") as metadata_file:
        json.dump(metadata_to_write, metadata_file, indent=2)
538620def _consolidate_safetensors_files (
539621 input_dir : str ,
540622 output_dir : str ,
@@ -747,33 +829,18 @@ def consolidate_safetensors_files_on_every_rank(
747829 filtered_filename_mapping [fqn ] = filename
748830
749831 # Call the existing consolidation function with the filtered mapping
750- output_files_data = _consolidate_safetensors_files (
832+ _consolidate_safetensors_files (
751833 input_dir = input_dir ,
752834 output_dir = output_dir ,
753835 fqn_to_file_mapping = filtered_filename_mapping ,
754836 num_threads = num_threads ,
755837 use_staging = use_staging ,
756838 staging_dir = staging_dir ,
757839 )
758- else :
759- output_files_data = {}
760-
761- global GLOBAL_OUTPUT_FILES_DATA
762- if GLOBAL_OUTPUT_FILES_DATA is None :
763- # cache after the first time we checkpoint
764- GLOBAL_OUTPUT_FILES_DATA = {}
765- global_output_files_data_list = [None ] * world_size
766- dist .all_gather_object (global_output_files_data_list , output_files_data )
767- for item in global_output_files_data_list :
768- if item :
769- GLOBAL_OUTPUT_FILES_DATA .update (item )
770-
771- # Write overall model.index.safetensors.json file with weight map
772- if GLOBAL_OUTPUT_FILES_DATA :
773- if rank == 0 :
774- _write_overall_metadata_file (output_dir , GLOBAL_OUTPUT_FILES_DATA )
775- if dist .is_available () and dist .is_initialized ():
776- dist .barrier ()
840+
841+ # Write overall model.index.safetensors.json file with weight map (rank 0 only)
842+ if rank == 0 :
843+ _write_overall_metadata_file_from_shards (input_dir , output_dir , fqn_to_index_mapping )
777844
778845 logger .info (
779846 "Rank %d: Done consolidating. Processed %d unique indices in %.2f secs." ,
0 commit comments