Skip to content

Commit e78ba6f

Browse files
authored
Hotfix to load policies saved with SB3 <= v1.6 (#1234)
* Hotfix to load policies saved with SB3 <= v1.6 * Add warning and test * Update doc
1 parent 3c028f3 commit e78ba6f

File tree

4 files changed

+38
-4
lines changed

4 files changed

+38
-4
lines changed

docs/misc/changelog.rst

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,17 @@ Changelog
44
==========
55

66

7-
Release 1.7.0a9 (WIP)
7+
Release 1.7.0a10 (WIP)
88
--------------------------
99

10+
.. note::
11+
12+
A2C and PPO saved with SB3 < 1.7.0 will show a warning about
13+
missing keys in the state dict when loaded with SB3 >= 1.7.0.
14+
To suppress the warning, simply save the model again.
15+
You can find more info in `issue #1233 <https://github.com/DLR-RM/stable-baselines3/issues/1233>`_.
16+
17+
1018
Breaking Changes:
1119
^^^^^^^^^^^^^^^^^
1220
- Removed deprecated ``create_eval_env``, ``eval_env``, ``eval_log_path``, ``n_eval_episodes`` and ``eval_freq`` parameters,

stable_baselines3/common/base_class.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import io
44
import pathlib
55
import time
6+
import warnings
67
from abc import ABC, abstractmethod
78
from collections import deque
89
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union
@@ -705,8 +706,25 @@ def load(
705706
model.__dict__.update(kwargs)
706707
model._setup_model()
707708

708-
# put state_dicts back in place
709-
model.set_parameters(params, exact_match=True, device=device)
709+
try:
710+
# put state_dicts back in place
711+
model.set_parameters(params, exact_match=True, device=device)
712+
except RuntimeError as e:
713+
# Patch to load Policy saved using SB3 < 1.7.0
714+
# the error is probably due to an old policy being loaded
715+
# See https://github.com/DLR-RM/stable-baselines3/issues/1233
716+
if "pi_features_extractor" in str(e) and "Missing key(s) in state_dict" in str(e):
717+
model.set_parameters(params, exact_match=False, device=device)
718+
warnings.warn(
719+
"You are probably loading a model saved with SB3 < 1.7.0, "
720+
"we deactivated exact_match so you can save the model "
721+
"again to avoid issues in the future "
722+
"(see https://github.com/DLR-RM/stable-baselines3/issues/1233 for more info). "
723+
f"Original error: {e} \n"
724+
"Note: the model should still work fine, this only a warning."
725+
)
726+
else:
727+
raise e
710728

711729
# put other pytorch variables back in place
712730
if pytorch_variables is not None:

stable_baselines3/version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
1.7.0a9
1+
1.7.0a10

tests/test_save_load.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,14 @@ def test_save_load_env_cnn(tmp_path, model_class):
338338
# clear file from os
339339
os.remove(tmp_path / "test_save.zip")
340340

341+
# Check we can load models saved with SB3 < 1.7.0
342+
if model_class == A2C:
343+
del model.policy.pi_features_extractor
344+
model.save(tmp_path / "test_save")
345+
with pytest.warns(UserWarning):
346+
model_class.load(str(tmp_path / "test_save.zip"), env=env, **kwargs).learn(100)
347+
os.remove(tmp_path / "test_save.zip")
348+
341349

342350
@pytest.mark.parametrize("model_class", [SAC, TD3, DQN])
343351
def test_save_load_replay_buffer(tmp_path, model_class):

0 commit comments

Comments
 (0)