Skip to content

Commit b05568c

Browse files
committed
Fixes.
1 parent f678739 commit b05568c

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

mlspm/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -258,9 +258,9 @@ def load_checkpoint(
258258
import torch
259259

260260
if rank is None:
261-
state = torch.load(file_name)
261+
state = torch.load(file_name, weights_only=False)
262262
else:
263-
state = torch.load(file_name, map_location={"cuda:0": f"cuda:{rank}"})
263+
state = torch.load(file_name, map_location={"cuda:0": f"cuda:{rank}"}, weights_only=False)
264264
model.load_state_dict(state["model_params"])
265265

266266
if optimizer:

tests/test_data_generation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def test_tar_writer():
3838

3939
with pytest.raises(RuntimeError):
4040
# Cannot overwrite an existing file
41-
with TarWriter(base_path, base_name, max_count=10) as tar_writer:
41+
with TarWriter(base_path, base_name, max_count=10, async_write=False) as tar_writer:
4242
pass
4343

4444
rmtree(base_path)

0 commit comments

Comments
 (0)