Skip to content

Commit 0b1e9f5

Browse files
committed
map_location default cpu
1 parent bcbd493 commit 0b1e9f5

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

src/diffusers/models/model_loading_utils.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def load_state_dict(
151151
checkpoint_file: Union[str, os.PathLike],
152152
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
153153
disable_mmap: bool = False,
154-
map_location: Optional[Union[str, torch.device]] = None,
154+
map_location: Union[str, torch.device] = "cpu",
155155
):
156156
"""
157157
Reads a checkpoint file, returning properly formatted errors if they arise.
@@ -174,8 +174,6 @@ def load_state_dict(
174174
elif file_extension == GGUF_FILE_EXTENSION:
175175
return load_gguf_checkpoint(checkpoint_file)
176176
else:
177-
if map_location is None:
178-
map_location = "cpu"
179177
extra_args = {}
180178
weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
181179
# mmap can only be used with files serialized with zipfile-based format.
@@ -187,7 +185,7 @@ def load_state_dict(
187185
and not disable_mmap
188186
):
189187
extra_args = {"mmap": True}
190-
return torch.load(checkpoint_file, map_location="cpu", **weights_only_kwarg, **extra_args)
188+
return torch.load(checkpoint_file, map_location=map_location, **weights_only_kwarg, **extra_args)
191189
except Exception as e:
192190
try:
193191
with open(checkpoint_file) as f:

0 commit comments

Comments
 (0)