diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index 52c862d1..b3e34d7b 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -3,7 +3,7 @@
 Changelog
 ==========
 
-Release 2.8.0a0 (WIP)
+Release 2.8.0a1 (WIP)
 --------------------------
 
 Breaking Changes:
@@ -19,7 +19,8 @@ New Features:
 
 Bug Fixes:
 ^^^^^^^^^^
+- Fix ``RecurrentPPO`` and ``MaskablePPO`` not reshaping actions before clipping them in ``forward()`` and ``predict()`` (@immortal-boy)
-- Do not call ``forward()`` method directly in ``RecurrentPPO``
+- Do not call ``forward()`` method directly in ``RecurrentPPO`` (@immortal-boy)
 
 Deprecations:
 ^^^^^^^^^^^^^
@@ -657,3 +658,4 @@ Contributors:
 @ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec
 @mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong @AlexPasqua @jonasreiher @icheered @Armandpl @danielpalen
 @corentinlger
+@immortal-boy
diff --git a/sb3_contrib/common/maskable/distributions.py b/sb3_contrib/common/maskable/distributions.py
index b1bf92e2..7ddf6fad 100644
--- a/sb3_contrib/common/maskable/distributions.py
+++ b/sb3_contrib/common/maskable/distributions.py
@@ -138,19 +138,15 @@ def proba_distribution(
         return self
 
     def log_prob(self, actions: th.Tensor) -> th.Tensor:
-        assert self.distribution is not None, "Must set distribution parameters"
         return self.distribution.log_prob(actions)
 
     def entropy(self) -> th.Tensor:
-        assert self.distribution is not None, "Must set distribution parameters"
         return self.distribution.entropy()
 
     def sample(self) -> th.Tensor:
-        assert self.distribution is not None, "Must set distribution parameters"
         return self.distribution.sample()
 
     def mode(self) -> th.Tensor:
-        assert self.distribution is not None, "Must set distribution parameters"
         return th.argmax(self.distribution.probs, dim=1)
 
     def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor:
@@ -164,7 +160,6 @@ def log_prob_from_params(self, action_logits: th.Tensor) -> tuple[th.Tensor, th.Tensor]:
         actions = self.actions_from_params(action_logits)
         log_prob = self.log_prob(actions)
         return actions, log_prob
 
     def apply_masking(self, masks: MaybeMasks) -> None:
-        assert self.distribution is not None, "Must set distribution parameters"
         self.distribution.apply_masking(masks)
diff --git a/sb3_contrib/common/maskable/policies.py b/sb3_contrib/common/maskable/policies.py
index 464c4fe6..bf44cfe9 100644
--- a/sb3_contrib/common/maskable/policies.py
+++ b/sb3_contrib/common/maskable/policies.py
@@ -139,6 +139,7 @@ def forward(
         distribution.apply_masking(action_masks)
         actions = distribution.get_actions(deterministic=deterministic)
         log_prob = distribution.log_prob(actions)
+        actions = actions.reshape((-1, *self.action_space.shape))  # type: ignore[misc]
         return actions, values, log_prob
 
     def extract_features(  # type: ignore[override]
@@ -304,7 +305,7 @@ def predict(
         with th.no_grad():
             actions = self._predict(obs_tensor, deterministic=deterministic, action_masks=action_masks)
             # Convert to numpy
-            actions = actions.cpu().numpy()  # type: ignore[assignment]
+            actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape))  # type: ignore[assignment, misc]
 
         if isinstance(self.action_space, spaces.Box):
             if self.squash_output:
diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py
index 813c6142..4a678d26 100644
--- a/sb3_contrib/common/recurrent/policies.py
+++ b/sb3_contrib/common/recurrent/policies.py
@@ -253,6 +253,7 @@ def forward(
         distribution = self._get_action_dist_from_latent(latent_pi)
         actions = distribution.get_actions(deterministic=deterministic)
         log_prob = distribution.log_prob(actions)
+        actions = actions.reshape((-1, *self.action_space.shape))  # type: ignore[misc]
         return actions, values, log_prob, RNNStates(lstm_states_pi, lstm_states_vf)
 
     def get_distribution(
@@ -412,7 +413,7 @@ def predict(
             states = (states[0].cpu().numpy(), states[1].cpu().numpy())
 
         # Convert to numpy
-        actions = actions.cpu().numpy()
+        actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape))  # type: ignore[assignment]
 
         if isinstance(self.action_space, spaces.Box):
             if self.squash_output:
diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt
index 11922a5c..8813e985 100644
--- a/sb3_contrib/version.txt
+++ b/sb3_contrib/version.txt
@@ -1 +1 @@
-2.8.0a0
+2.8.0a1
diff --git a/tests/test_distributions.py b/tests/test_distributions.py
index bb3cf269..f44ae323 100644
--- a/tests/test_distributions.py
+++ b/tests/test_distributions.py
@@ -82,19 +82,19 @@ def test_distribution_must_be_initialized(self):
         DIMS = 2
         dist = MaskableCategoricalDistribution(DIMS)
 
-        with pytest.raises(AssertionError):
+        with pytest.raises(AttributeError):
             dist.log_prob(th.randint(DIMS - 1, (1, 3)))
 
-        with pytest.raises(AssertionError):
+        with pytest.raises(AttributeError):
             dist.entropy()
 
-        with pytest.raises(AssertionError):
+        with pytest.raises(AttributeError):
             dist.sample()
 
-        with pytest.raises(AssertionError):
+        with pytest.raises(AttributeError):
             dist.mode()
 
-        with pytest.raises(AssertionError):
+        with pytest.raises(AttributeError):
             dist.apply_masking(None)
 
         # But now we can
diff --git a/tests/test_lstm.py b/tests/test_lstm.py
index 294998ab..3bcb508e 100644
--- a/tests/test_lstm.py
+++ b/tests/test_lstm.py
@@ -244,3 +244,22 @@ def make_env():
     # In CartPole-v1, a non-recurrent policy can easily get >= 450.
     # In CartPoleNoVelEnv, a non-recurrent policy doesn't get more than ~50.
     evaluate_policy(model, env, reward_threshold=450)
+
+
+class MultiDimensionalActionSpaceEnv(gym.Env):
+    def __init__(self):
+        self.observation_space = spaces.Box(low=-1, high=1, shape=(10,), dtype=np.float32)
+        self.action_space = spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32)
+
+    def reset(self, seed=None, options=None):
+        super().reset(seed=seed)
+        return self.observation_space.sample(), {}
+
+    def step(self, action):
+        return self.observation_space.sample(), 1, np.random.rand() > 0.8, False, {}
+
+
+def test_ppo_multi_dimensional_action_space():
+    env = MultiDimensionalActionSpaceEnv()
+    model = RecurrentPPO("MlpLstmPolicy", env, n_steps=64, n_epochs=2).learn(64)
+    evaluate_policy(model, model.get_env())