Commit fda3d4d

qgallouedec and araffin authored
Fix returned type in predict (#964)
* `arr[0]` to `arr.squeeze(0)`
* `squeeze(axis=0)` to `squeeze(0)`
* Type testing
* Add type test for unvectorized observation
* `squeeze(0)` to `squeeze(axis=0)`
* Treatment of the laziness symptoms
* Update changelog
* Update changelog

Co-authored-by: Antonin RAFFIN <[email protected]>
1 parent a18b91e commit fda3d4d

4 files changed: 7 additions and 4 deletions

docs/misc/changelog.rst

Lines changed: 2 additions & 1 deletion
@@ -17,6 +17,7 @@ SB3-Contrib
 
 Bug Fixes:
 ^^^^^^^^^^
+- Fixed the issue that ``predict`` does not always return action as ``np.ndarray`` (@qgallouedec)
 
 Deprecations:
 ^^^^^^^^^^^^^
@@ -1011,4 +1012,4 @@ And all the contributors:
 @eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP
 @simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485
 @Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede
-@Melanol
+@Melanol @qgallouedec

stable_baselines3/common/distributions.py

Lines changed: 2 additions & 2 deletions
@@ -578,10 +578,10 @@ def get_noise(self, latent_sde: th.Tensor) -> th.Tensor:
             return th.mm(latent_sde, self.exploration_mat)
         # Use batch matrix multiplication for efficient computation
         # (batch_size, n_features) -> (batch_size, 1, n_features)
-        latent_sde = latent_sde.unsqueeze(1)
+        latent_sde = latent_sde.unsqueeze(dim=1)
         # (batch_size, 1, n_actions)
         noise = th.bmm(latent_sde, self.exploration_matrices)
-        return noise.squeeze(1)
+        return noise.squeeze(dim=1)
 
     def actions_from_params(
         self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor, deterministic: bool = False
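
The shape comments in the `get_noise` hunk above can be checked with a small sketch. It uses NumPy rather than PyTorch, so `np.expand_dims`, `np.matmul` and `ndarray.squeeze` stand in for `unsqueeze`, `th.bmm` and `squeeze`. The sizes and the `(batch_size, n_features, n_actions)` shape of `exploration_matrices` are assumptions made for illustration, inferred from those comments and not copied from the library.

```python
import numpy as np

# Illustrative sizes only (assumed, not taken from the commit).
batch_size, n_features, n_actions = 4, 3, 2
rng = np.random.default_rng(0)

latent_sde = rng.normal(size=(batch_size, n_features))
# Assumed layout: one exploration matrix per batch element.
exploration_matrices = rng.normal(size=(batch_size, n_features, n_actions))

# (batch_size, n_features) -> (batch_size, 1, n_features)
latent = np.expand_dims(latent_sde, axis=1)
# Batched product: (batch_size, 1, n_features) @ (batch_size, n_features, n_actions)
# -> (batch_size, 1, n_actions)
noise = np.matmul(latent, exploration_matrices)
# Drop the singleton axis: (batch_size, n_actions)
noise = noise.squeeze(axis=1)

# Reference computation: one vector-matrix product per sample.
expected = np.stack(
    [latent_sde[i] @ exploration_matrices[i] for i in range(batch_size)]
)
assert np.allclose(noise, expected)
```

Naming the axis (`dim=1` in the diff, `axis=1` here) does not change the result compared with passing it positionally; it only makes the intended axis explicit.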

stable_baselines3/common/policies.py

Lines changed: 1 addition & 1 deletion
@@ -350,7 +350,7 @@ def predict(
 
         # Remove batch dimension if needed
         if not vectorized_env:
-            actions = actions[0]
+            actions = actions.squeeze(axis=0)
 
         return actions, state
 
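
A minimal NumPy sketch of why replacing `actions[0]` with `actions.squeeze(axis=0)` changes the returned type. The arrays below are hypothetical stand-ins for batched policy outputs, not values produced by the library: one mimics a discrete action for a single environment, the other a continuous action of shape `(2,)`.

```python
import numpy as np

# Hypothetical batched outputs with a leading batch axis of size 1.
batched_discrete = np.array([3])                # shape (1,)
batched_continuous = np.array([[0.1, -0.4]])    # shape (1, 2)

# Integer indexing on a 1-D array returns a NumPy scalar, not an ndarray.
indexed = batched_discrete[0]
# squeeze(axis=0) removes the batch axis but keeps the ndarray type,
# yielding a 0-d array for the discrete case.
squeezed = batched_discrete.squeeze(axis=0)

print(isinstance(indexed, np.ndarray))    # False
print(isinstance(squeezed, np.ndarray))   # True
print(squeezed.shape)                     # ()

# For a continuous action both approaches give an ndarray of shape (2,).
print(batched_continuous[0].shape)                # (2,)
print(batched_continuous.squeeze(axis=0).shape)   # (2,)
```

Passing `axis=0` explicitly also means `squeeze` raises a `ValueError` if the leading axis is not of size 1, instead of silently removing some other singleton axis.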

tests/test_predict.py

Lines changed: 2 additions & 0 deletions
@@ -73,11 +73,13 @@ def test_predict(model_class, env_id, device):
 
     obs = env.reset()
     action, _ = model.predict(obs)
+    assert isinstance(action, np.ndarray)
     assert action.shape == env.action_space.shape
     assert env.action_space.contains(action)
 
     vec_env_obs = vec_env.reset()
     action, _ = model.predict(vec_env_obs)
+    assert isinstance(action, np.ndarray)
     assert action.shape[0] == vec_env_obs.shape[0]
 
     # Special case for DQN to check the epsilon greedy exploration
