Skip to content

Commit 3c028f3

Browse files
authored
Fix load_from_tensor (#1231)
1 parent 5549b34 commit 3c028f3

File tree

3 files changed

+4
-3
lines changed

3 files changed

+4
-3
lines changed

docs/misc/changelog.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ Changelog
44
==========
55

66

7-
Release 1.7.0a8 (WIP)
7+
Release 1.7.0a9 (WIP)
88
--------------------------
99

1010
Breaking Changes:
@@ -39,6 +39,7 @@ Bug Fixes:
3939
- Fixed ``Self`` return type using ``TypeVar``
4040
- Fixed the env checker, the key was not passed when checking images from Dict observation space
4141
- Fixed ``normalize_images`` which was not passed to parent class in some cases
42+
- Fixed ``load_from_vector`` that was broken with newer PyTorch version when passing PyTorch tensor
4243

4344
Deprecations:
4445
^^^^^^^^^^^^^

stable_baselines3/common/policies.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def load_from_vector(self, vector: np.ndarray) -> None:
199199
200200
:param vector:
201201
"""
202-
th.nn.utils.vector_to_parameters(th.FloatTensor(vector, device=self.device), self.parameters())
202+
th.nn.utils.vector_to_parameters(th.as_tensor(vector, dtype=th.float, device=self.device), self.parameters())
203203

204204
def parameters_to_vector(self) -> np.ndarray:
205205
"""

stable_baselines3/version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
1.7.0a8
1+
1.7.0a9

0 commit comments

Comments
 (0)