@@ -559,6 +559,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
559559 If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
560560 `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
561561 weights. If set to `False`, `safetensors` weights are not loaded.
562+ disable_mmap (`bool`, *optional*, defaults to `False`):
563+ Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
564+ is on a network mount or hard drive, which may not handle the random-access (seek-heavy) reads of mmap very well.
562565
563566 <Tip>
564567
@@ -604,6 +607,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
604607 variant = kwargs .pop ("variant" , None )
605608 use_safetensors = kwargs .pop ("use_safetensors" , None )
606609 quantization_config = kwargs .pop ("quantization_config" , None )
610+ disable_mmap = kwargs .pop ("disable_mmap" , False )
607611
608612 allow_pickle = False
609613 if use_safetensors is None :
@@ -883,7 +887,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
883887 # TODO (sayakpaul, SunMarc): remove this after model loading refactor
884888 else :
885889 param_device = torch .device (torch .cuda .current_device ())
886- state_dict = load_state_dict (model_file , variant = variant )
890+ state_dict = load_state_dict (model_file , variant = variant , disable_mmap = disable_mmap )
887891 model ._convert_deprecated_attention_blocks (state_dict )
888892
889893 # move the params from meta device to cpu
@@ -979,7 +983,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
979983 else :
980984 model = cls .from_config (config , ** unused_kwargs )
981985
982- state_dict = load_state_dict (model_file , variant = variant )
986+ state_dict = load_state_dict (model_file , variant = variant , disable_mmap = disable_mmap )
983987 model ._convert_deprecated_attention_blocks (state_dict )
984988
985989 model , missing_keys , unexpected_keys , mismatched_keys , error_msgs = cls ._load_pretrained_model (
0 commit comments