@@ -559,6 +559,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
559559                If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the 
560560                `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors` 
561561                weights. If set to `False`, `safetensors` weights are not loaded. 
562+             disable_mmap ('bool', *optional*, defaults to 'False'): 
563+                 Whether to disable mmap when loading a Safetensors model. This option can perform better when the model 
564+                 is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well. 
562565
563566        <Tip> 
564567
@@ -604,6 +607,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
604607        variant  =  kwargs .pop ("variant" , None )
605608        use_safetensors  =  kwargs .pop ("use_safetensors" , None )
606609        quantization_config  =  kwargs .pop ("quantization_config" , None )
610+         disable_mmap  =  kwargs .pop ("disable_mmap" , False )
607611
608612        allow_pickle  =  False 
609613        if  use_safetensors  is  None :
@@ -883,7 +887,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
883887                    # TODO (sayakpaul,  SunMarc): remove this after model loading refactor 
884888                    else :
885889                        param_device  =  torch .device (torch .cuda .current_device ())
886-                     state_dict  =  load_state_dict (model_file , variant = variant )
890+                     state_dict  =  load_state_dict (model_file , variant = variant ,  disable_mmap = disable_mmap )
887891                    model ._convert_deprecated_attention_blocks (state_dict )
888892
889893                    # move the params from meta device to cpu 
@@ -979,7 +983,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
979983            else :
980984                model  =  cls .from_config (config , ** unused_kwargs )
981985
982-                 state_dict  =  load_state_dict (model_file , variant = variant )
986+                 state_dict  =  load_state_dict (model_file , variant = variant ,  disable_mmap = disable_mmap )
983987                model ._convert_deprecated_attention_blocks (state_dict )
984988
985989                model , missing_keys , unexpected_keys , mismatched_keys , error_msgs  =  cls ._load_pretrained_model (
0 commit comments