@@ -151,7 +151,7 @@ def load_state_dict(
151151 checkpoint_file : Union [str , os .PathLike ],
152152 dduf_entries : Optional [Dict [str , DDUFEntry ]] = None ,
153153 disable_mmap : bool = False ,
154- map_location : Optional [ Union [str , torch .device ]] = None ,
154+ map_location : Union [str , torch .device ] = "cpu" ,
155155):
156156 """
157157 Reads a checkpoint file, returning properly formatted errors if they arise.
@@ -174,8 +174,6 @@ def load_state_dict(
174174 elif file_extension == GGUF_FILE_EXTENSION :
175175 return load_gguf_checkpoint (checkpoint_file )
176176 else :
177- if map_location is None :
178- map_location = "cpu"
179177 extra_args = {}
180178 weights_only_kwarg = {"weights_only" : True } if is_torch_version (">=" , "1.13" ) else {}
181179 # mmap can only be used with files serialized with zipfile-based format.
@@ -187,7 +185,7 @@ def load_state_dict(
187185 and not disable_mmap
188186 ):
189187 extra_args = {"mmap" : True }
190- return torch .load (checkpoint_file , map_location = "cpu" , ** weights_only_kwarg , ** extra_args )
188+ return torch .load (checkpoint_file , map_location = map_location , ** weights_only_kwarg , ** extra_args )
191189 except Exception as e :
192190 try :
193191 with open (checkpoint_file ) as f :
0 commit comments