From 3adce62b2c083ea2e899336c98ae3126b3d6d3f0 Mon Sep 17 00:00:00 2001
From: taufeeque9 <9taufeeque9@gmail.com>
Date: Sun, 8 Jan 2023 22:12:53 +0530
Subject: [PATCH 1/3] Add next_obs and dones to RolloutBuffer

---
 stable_baselines3/common/buffers.py             | 11 +++++++++++
 stable_baselines3/common/on_policy_algorithm.py |  4 +++-
 stable_baselines3/common/type_aliases.py        |  2 ++
 3 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py
index 2dafd415b1..3ec54cd412 100644
--- a/stable_baselines3/common/buffers.py
+++ b/stable_baselines3/common/buffers.py
@@ -352,14 +352,17 @@ def __init__(
         self.gamma = gamma
         self.observations, self.actions, self.rewards, self.advantages = None, None, None, None
         self.returns, self.episode_starts, self.values, self.log_probs = None, None, None, None
+        self.next_observations, self.dones = None, None
         self.generator_ready = False
         self.reset()
 
     def reset(self) -> None:
 
         self.observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=np.float32)
+        self.next_observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=np.float32)
         self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
         self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+        self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
         self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
         self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
         self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
@@ -408,8 +411,10 @@ def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarra
     def add(
         self,
         obs: np.ndarray,
+        next_obs: np.ndarray,
         action: np.ndarray,
         reward: np.ndarray,
+        done: np.ndarray,
         episode_start: np.ndarray,
         value: th.Tensor,
         log_prob: th.Tensor,
@@ -437,8 +442,10 @@ def add(
             action = action.reshape((self.n_envs, self.action_dim))
 
         self.observations[self.pos] = np.array(obs).copy()
+        self.next_observations[self.pos] = np.array(next_obs).copy()
         self.actions[self.pos] = np.array(action).copy()
         self.rewards[self.pos] = np.array(reward).copy()
+        self.dones[self.pos] = np.array(done).copy()
         self.episode_starts[self.pos] = np.array(episode_start).copy()
         self.values[self.pos] = value.clone().cpu().numpy().flatten()
         self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
@@ -454,7 +461,9 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSample
 
             _tensor_names = [
                 "observations",
+                "next_observations",
                 "actions",
+                "dones",
                 "values",
                 "log_probs",
                 "advantages",
@@ -477,7 +486,9 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSample
     def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> RolloutBufferSamples:
         data = (
             self.observations[batch_inds],
+            self.next_observations[batch_inds],
             self.actions[batch_inds],
+            self.dones[batch_inds],
             self.values[batch_inds].flatten(),
             self.log_probs[batch_inds].flatten(),
             self.advantages[batch_inds].flatten(),
diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py
index bc0dda49fa..ca56abe6d7 100644
--- a/stable_baselines3/common/on_policy_algorithm.py
+++ b/stable_baselines3/common/on_policy_algorithm.py
@@ -201,7 +201,9 @@ def collect_rollouts(
                         terminal_value = self.policy.predict_values(terminal_obs)[0]
                     rewards[idx] += self.gamma * terminal_value
 
-            rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs)
+            rollout_buffer.add(
+                self._last_obs, new_obs, actions, rewards, dones, self._last_episode_starts, values, log_probs
+            )
             self._last_obs = new_obs
             self._last_episode_starts = dones
 
diff --git a/stable_baselines3/common/type_aliases.py b/stable_baselines3/common/type_aliases.py
index 7227667a1e..5f8a011fa6 100644
--- a/stable_baselines3/common/type_aliases.py
+++ b/stable_baselines3/common/type_aliases.py
@@ -29,7 +29,9 @@ class RolloutBufferSamples(NamedTuple):
     observations: th.Tensor
+    next_observations: th.Tensor
     actions: th.Tensor
+    dones: th.Tensor
     old_values: th.Tensor
     old_log_prob: th.Tensor
     advantages: th.Tensor

From 7886606722680ef2d6b8e4e7399ec19681dbf7d4 Mon Sep 17 00:00:00 2001
From: taufeeque9 <9taufeeque9@gmail.com>
Date: Mon, 9 Jan 2023 18:24:34 +0530
Subject: [PATCH 2/3] Add next_obs and dones to DictRolloutBuffer

---
 stable_baselines3/common/buffers.py      | 21 +++++++++++++++++++--
 stable_baselines3/common/type_aliases.py |  2 ++
 tests/test_buffers.py                    |  2 +-
 3 files changed, 22 insertions(+), 3 deletions(-)

diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py
index 3ec54cd412..18cf8c37dc 100644
--- a/stable_baselines3/common/buffers.py
+++ b/stable_baselines3/common/buffers.py
@@ -696,15 +696,18 @@ def __init__(
         self.gamma = gamma
         self.observations, self.actions, self.rewards, self.advantages = None, None, None, None
         self.returns, self.episode_starts, self.values, self.log_probs = None, None, None, None
+        self.next_observations, self.dones = None, None
         self.generator_ready = False
         self.reset()
 
     def reset(self) -> None:
         assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only"
-        self.observations = {}
+        self.observations, self.next_observations = {}, {}
         for key, obs_input_shape in self.obs_shape.items():
             self.observations[key] = np.zeros((self.buffer_size, self.n_envs) + obs_input_shape, dtype=np.float32)
+            self.next_observations[key] = np.zeros((self.buffer_size, self.n_envs) + obs_input_shape, dtype=np.float32)
         self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
+        self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
         self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
         self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
         self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
@@ -717,8 +720,10 @@ def reset(self) -> None:
     def add(
         self,
         obs: Dict[str, np.ndarray],
+        next_obs: Dict[str, np.ndarray],
         action: np.ndarray,
         reward: np.ndarray,
+        done: np.ndarray,
         episode_start: np.ndarray,
         value: th.Tensor,
         log_prob: th.Tensor,
@@ -745,8 +750,15 @@ def add(
                 obs_ = obs_.reshape((self.n_envs,) + self.obs_shape[key])
             self.observations[key][self.pos] = obs_
 
+        for key in self.next_observations.keys():
+            next_obs_ = np.array(next_obs[key]).copy()
+            if isinstance(self.observation_space.spaces[key], spaces.Discrete):
+                next_obs_ = next_obs_.reshape((self.n_envs,) + self.obs_shape[key])
+            self.next_observations[key][self.pos] = next_obs_
+
         self.actions[self.pos] = np.array(action).copy()
         self.rewards[self.pos] = np.array(reward).copy()
+        self.dones[self.pos] = np.array(done).copy()
         self.episode_starts[self.pos] = np.array(episode_start).copy()
         self.values[self.pos] = value.clone().cpu().numpy().flatten()
         self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
@@ -763,7 +775,10 @@ def get(self, batch_size: Optional[int] = None) -> Generator[DictRolloutBufferSa
             for key, obs in self.observations.items():
                 self.observations[key] = self.swap_and_flatten(obs)
 
-            _tensor_names = ["actions", "values", "log_probs", "advantages", "returns"]
+            for key, obs in self.next_observations.items():
+                self.next_observations[key] = self.swap_and_flatten(obs)
+
+            _tensor_names = ["actions", "dones", "values", "log_probs", "advantages", "returns"]
 
             for tensor in _tensor_names:
                 self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
@@ -782,7 +797,9 @@ def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = Non
 
         return DictRolloutBufferSamples(
             observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
+            next_observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.next_observations.items()},
             actions=self.to_torch(self.actions[batch_inds]),
+            dones=self.to_torch(self.dones[batch_inds]),
             old_values=self.to_torch(self.values[batch_inds].flatten()),
             old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()),
             advantages=self.to_torch(self.advantages[batch_inds].flatten()),
diff --git a/stable_baselines3/common/type_aliases.py b/stable_baselines3/common/type_aliases.py
index 5f8a011fa6..6a9f2b0518 100644
--- a/stable_baselines3/common/type_aliases.py
+++ b/stable_baselines3/common/type_aliases.py
@@ -40,7 +40,9 @@ class RolloutBufferSamples(NamedTuple):
 
 class DictRolloutBufferSamples(NamedTuple):
     observations: TensorDict
+    next_observations: TensorDict
     actions: th.Tensor
+    dones: th.Tensor
     old_values: th.Tensor
     old_log_prob: th.Tensor
     advantages: th.Tensor
diff --git a/tests/test_buffers.py b/tests/test_buffers.py
index 0e028e670d..474f8b1ed5 100644
--- a/tests/test_buffers.py
+++ b/tests/test_buffers.py
@@ -120,7 +120,7 @@ def test_device_buffer(replay_buffer_cls, device):
         next_obs, reward, done, info = env.step(action)
         if replay_buffer_cls in [RolloutBuffer, DictRolloutBuffer]:
             episode_start, values, log_prob = np.zeros(1), th.zeros(1), th.ones(1)
-            buffer.add(obs, action, reward, episode_start, values, log_prob)
+            buffer.add(obs, next_obs, action, reward, done, episode_start, values, log_prob)
         else:
             buffer.add(obs, next_obs, action, reward, done, info)
         obs = next_obs

From 9adc561132c593a92b4a57b2e2ca094824fbcbe3 Mon Sep 17 00:00:00 2001
From: taufeeque9 <9taufeeque9@gmail.com>
Date: Mon, 9 Jan 2023 18:39:20 +0530
Subject: [PATCH 3/3] Fix lint error

---
 stable_baselines3/common/on_policy_algorithm.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py
index ca56abe6d7..45ac236646 100644
--- a/stable_baselines3/common/on_policy_algorithm.py
+++ b/stable_baselines3/common/on_policy_algorithm.py
@@ -201,9 +201,7 @@ def collect_rollouts(
                         terminal_value = self.policy.predict_values(terminal_obs)[0]
                     rewards[idx] += self.gamma * terminal_value
 
-            rollout_buffer.add(
-                self._last_obs, new_obs, actions, rewards, dones, self._last_episode_starts, values, log_probs
-            )
+            rollout_buffer.add(self._last_obs, new_obs, actions, rewards, dones, self._last_episode_starts, values, log_probs)
             self._last_obs = new_obs
             self._last_episode_starts = dones
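
As a quick illustration of the API after this series (not part of the patches themselves), the sketch below fills and samples a RolloutBuffer through the extended add(obs, next_obs, action, reward, done, episode_start, value, log_prob) signature and reads the new next_observations/dones fields off RolloutBufferSamples. The CartPole environment, buffer size, and the all-zero value/log-prob tensors are illustrative placeholders, and the pre-gymnasium 4-tuple step API is assumed, matching tests/test_buffers.py above.

    # Illustrative sketch, assuming the patched RolloutBuffer API from PATCH 1/3.
    import gym
    import numpy as np
    import torch as th

    from stable_baselines3.common.buffers import RolloutBuffer

    env = gym.make("CartPole-v1")
    buffer = RolloutBuffer(
        buffer_size=16,
        observation_space=env.observation_space,
        action_space=env.action_space,
        device="cpu",
        n_envs=1,
    )

    obs = env.reset()
    episode_start = np.ones(1, dtype=bool)
    for _ in range(16):
        action = env.action_space.sample()
        next_obs, reward, done, info = env.step(action)
        # Placeholders for the policy's value and log-prob outputs.
        values, log_probs = th.zeros(1), th.zeros(1)
        buffer.add(
            obs[None, :],                          # obs, shape (n_envs,) + obs_shape
            next_obs[None, :],                     # next_obs (added by PATCH 1/3)
            np.array([action]),                    # action
            np.array([reward], dtype=np.float32),  # reward
            np.array([done]),                      # done (added by PATCH 1/3)
            episode_start,
            values,
            log_probs,
        )
        episode_start = np.array([done])
        obs = env.reset() if done else next_obs

    buffer.compute_returns_and_advantage(last_values=th.zeros(1), dones=np.array([done]))
    for rollout_data in buffer.get(batch_size=8):
        # next_observations and dones are the new fields on RolloutBufferSamples.
        print(rollout_data.next_observations.shape, rollout_data.dones.shape)

Because next_observations is stored and flattened exactly like observations, each sampled next_observations batch has the same shape as the observations batch, while dones comes back as a flat (batch_size,) tensor alongside it.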
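
The DictRolloutBuffer changes in PATCH 2/3 mirror this with dict-valued obs/next_obs keyed like the observation space. A minimal sketch, assuming a made-up Dict space and all-zero placeholder transitions:

    # Illustrative sketch, assuming the patched DictRolloutBuffer API from PATCH 2/3.
    import numpy as np
    import torch as th
    from gym import spaces

    from stable_baselines3.common.buffers import DictRolloutBuffer

    obs_space = spaces.Dict({"vec": spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)})
    act_space = spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)
    buffer = DictRolloutBuffer(buffer_size=4, observation_space=obs_space, action_space=act_space, n_envs=1)

    for _ in range(4):
        obs = {"vec": np.zeros((1, 3), dtype=np.float32)}
        next_obs = {"vec": np.ones((1, 3), dtype=np.float32)}  # next_obs dict (added by PATCH 2/3)
        buffer.add(
            obs,
            next_obs,
            np.zeros((1, 2), dtype=np.float32),  # action
            np.zeros(1, dtype=np.float32),       # reward
            np.zeros(1),                         # done (added by PATCH 2/3)
            np.zeros(1),                         # episode_start
            th.zeros(1),                         # value
            th.zeros(1),                         # log_prob
        )

    buffer.compute_returns_and_advantage(last_values=th.zeros(1), dones=np.zeros(1))
    for samples in buffer.get(batch_size=2):
        print(samples.next_observations["vec"].shape, samples.dones.shape)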