@@ -648,6 +648,7 @@ def save_pretrained(
         variant: Optional[str] = None,
         max_shard_size: Union[int, str] = "10GB",
         push_to_hub: bool = False,
+        use_flashpack: bool = False,
         **kwargs,
     ):
         """
@@ -700,7 +701,12 @@ def save_pretrained(
700701 " the logger on the traceback to understand the reason why the quantized model is not serializable."
701702 )
702703
703- weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
704+ weights_name = WEIGHTS_NAME
705+ if use_flashpack :
706+ weights_name = FLASHPACK_WEIGHTS_NAME
707+ elif safe_serialization :
708+ weights_name = SAFETENSORS_WEIGHTS_NAME
709+
704710 weights_name = _add_variant (weights_name , variant )
705711 weights_name_pattern = weights_name .replace (".bin" , "{suffix}.bin" ).replace (
706712 ".safetensors" , "{suffix}.safetensors"
@@ -727,58 +733,70 @@ def save_pretrained(
         # Save the model
         state_dict = model_to_save.state_dict()
 
-        # Save the model
-        state_dict_split = split_torch_state_dict_into_shards(
-            state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
-        )
-
-        # Clean the folder from a previous save
-        if is_main_process:
-            for filename in os.listdir(save_directory):
-                if filename in state_dict_split.filename_to_tensors.keys():
-                    continue
-                full_filename = os.path.join(save_directory, filename)
-                if not os.path.isfile(full_filename):
-                    continue
-                weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
-                weights_without_ext = weights_without_ext.replace("{suffix}", "")
-                filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
-                # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
-                if (
-                    filename.startswith(weights_without_ext)
-                    and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
-                ):
-                    os.remove(full_filename)
-
-        for filename, tensors in state_dict_split.filename_to_tensors.items():
-            shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
-            filepath = os.path.join(save_directory, filename)
-            if safe_serialization:
-                # At some point we will need to deal better with save_function (used for TPU and other distributed
-                # joyfulness), but for now this enough.
-                safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
-            else:
-                torch.save(shard, filepath)
+        if use_flashpack:
+            if is_flashpack_available():
+                import flashpack
 
-        if state_dict_split.is_sharded:
-            index = {
-                "metadata": state_dict_split.metadata,
-                "weight_map": state_dict_split.tensor_to_filename,
-            }
-            save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
-            save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
-            # Save the index as well
-            with open(save_index_file, "w", encoding="utf-8") as f:
-                content = json.dumps(index, indent=2, sort_keys=True) + "\n"
-                f.write(content)
-            logger.info(
-                f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
-                f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
-                f"index located at {save_index_file}."
-            )
+                flashpack.serialization.pack_to_file(
+                    state_dict_or_model=state_dict,
+                    destination_path=save_directory,
+                    target_dtype=self.dtype,
+                )
+            else:
+                raise ImportError("`use_flashpack=True` requires the `flashpack` library to be installed.")
         else:
-            path_to_weights = os.path.join(save_directory, weights_name)
-            logger.info(f"Model weights saved in {path_to_weights}")
+            # Save the model
+            state_dict_split = split_torch_state_dict_into_shards(
+                state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
+            )
+
+            # Clean the folder from a previous save
+            if is_main_process:
+                for filename in os.listdir(save_directory):
+                    if filename in state_dict_split.filename_to_tensors.keys():
+                        continue
+                    full_filename = os.path.join(save_directory, filename)
+                    if not os.path.isfile(full_filename):
+                        continue
+                    weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
+                    weights_without_ext = weights_without_ext.replace("{suffix}", "")
+                    filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
+                    # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
+                    if (
+                        filename.startswith(weights_without_ext)
+                        and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
+                    ):
+                        os.remove(full_filename)
+
+            for filename, tensors in state_dict_split.filename_to_tensors.items():
+                shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
+                filepath = os.path.join(save_directory, filename)
+                if safe_serialization:
+                    # At some point we will need to deal better with save_function (used for TPU and other distributed
+                    # joyfulness), but for now this enough.
+                    safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
+                else:
+                    torch.save(shard, filepath)
+
+            if state_dict_split.is_sharded:
+                index = {
+                    "metadata": state_dict_split.metadata,
+                    "weight_map": state_dict_split.tensor_to_filename,
+                }
+                save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
+                save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
+                # Save the index as well
+                with open(save_index_file, "w", encoding="utf-8") as f:
+                    content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+                    f.write(content)
+                logger.info(
+                    f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
+                    f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
+                    f"index located at {save_index_file}."
+                )
+            else:
+                path_to_weights = os.path.join(save_directory, weights_name)
+                logger.info(f"Model weights saved in {path_to_weights}")
 
         if push_to_hub:
             # Create a new empty model card and eventually tag it
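
Taken together, saving is unchanged unless the caller opts in: the default route still produces sharded safetensors (or pickle) checkpoints, and flashpack is explicit. A usage sketch, assuming a diffusers `ModelMixin` subclass with the flashpack package installed; the model id and output directories are illustrative only:

from diffusers import UNet2DConditionModel

model = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2-1", subfolder="unet"
)

# Default path: safetensors shards plus an index file whenever the
# checkpoint exceeds max_shard_size.
model.save_pretrained("unet-safetensors", max_shard_size="5GB")

# Opt-in path from this commit: a single flashpack payload written through
# flashpack.serialization.pack_to_file. Requires the flashpack package.
model.save_pretrained("unet-flashpack", use_flashpack=True)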