From 9175f267bd15f7a169de6f3bb1946c1385a40272 Mon Sep 17 00:00:00 2001 From: DanielAvdar <66269169+DanielAvdar@users.noreply.github.com> Date: Thu, 13 Mar 2025 18:08:30 +0200 Subject: [PATCH] Add unit tests for RecurrentDictRolloutBuffer Introduced tests to validate the functionality of the RecurrentDictRolloutBuffer, including buffer capacity, hidden state shapes, and batch sampling. Tests cover various buffer configurations and ensure consistency of LSTM state updates. Signed-off-by: DanielAvdar <66269169+DanielAvdar@users.noreply.github.com> --- tests/buffers/__init__.py | 0 tests/buffers/test_recurrent_buff.py | 89 ++++++++++++++++++++++++++++ 2 files changed, 89 insertions(+) create mode 100644 tests/buffers/__init__.py create mode 100644 tests/buffers/test_recurrent_buff.py diff --git a/tests/buffers/__init__.py b/tests/buffers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/buffers/test_recurrent_buff.py b/tests/buffers/test_recurrent_buff.py new file mode 100644 index 00000000..4e7a30fd --- /dev/null +++ b/tests/buffers/test_recurrent_buff.py @@ -0,0 +1,89 @@ +import pytest +import numpy as np +import torch +import torch as th +from gymnasium import spaces + +from sb3_contrib.common.recurrent.buffers import RecurrentDictRolloutBuffer +from sb3_contrib.common.recurrent.type_aliases import RNNStates + + +import pytest +import numpy as np +import torch as th +from sb3_contrib.common.recurrent.buffers import pad, pad_and_flatten, create_sequencers + + +@pytest.mark.parametrize("n_envs", [1, 2, 5]) +# @pytest.mark.parametrize("n_envs", [1,]) +@pytest.mark.parametrize("buffer_size", [100, 200]) +@pytest.mark.parametrize( + "observation_space", + [ + spaces.Dict({"obs": spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32)}), + spaces.Dict({"obs": spaces.MultiDiscrete([1, 1, 1])}), + ], +) +@pytest.mark.parametrize("action_space", [spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)]) +@pytest.mark.parametrize("hidden_state_size", [64, 128]) +@pytest.mark.parametrize("batch_size", [64, 100]) +@pytest.mark.parametrize("lstm_num_layers", [1, 3]) +def test_recurrent_dict_rollout_buffer_get( + n_envs, + buffer_size, + observation_space, + action_space, + hidden_state_size, + batch_size, + lstm_num_layers, +): + single_hs_shape = (lstm_num_layers, n_envs, hidden_state_size) + + hidden_state_shape = (buffer_size, *single_hs_shape) + rollout_buffer = RecurrentDictRolloutBuffer( + buffer_size=buffer_size, + observation_space=observation_space, + action_space=action_space, + hidden_state_shape=hidden_state_shape, + device="cpu", + gae_lambda=0.95, + gamma=0.99, + n_envs=n_envs, + ) + buffer_size = rollout_buffer.buffer_size + n_envs = rollout_buffer.n_envs + assert rollout_buffer.hidden_states_pi.shape == (buffer_size, *single_hs_shape) + assert rollout_buffer.observations["obs"].shape == (buffer_size, n_envs, 3) + episode_start = np.array([False], dtype=bool) + observation = {"obs": np.array([[0.1, 0.2, 0.3]] * n_envs, dtype=np.float32)} + action = np.array([[0.1, -0.1]] * n_envs, dtype=np.float32) + reward = np.array([0.5] * n_envs, dtype=np.float32) + value = th.tensor([0.1] * n_envs) + log_prob = th.tensor([-0.1] * n_envs) + lstm_states = RNNStates( + pi=(th.ones(single_hs_shape), th.ones(single_hs_shape)), + vf=(th.ones(single_hs_shape), th.ones(single_hs_shape)), + ) + lstm_states.pi[0][:, 0, :] *= 2 + lstm_states.pi[1][:, 0, :] *= 2 + lstm_states.vf[0][:, 0, :] *= 2 + lstm_states.vf[1][:, 0, :] *= 2 + for _ in range(buffer_size - 1): + rollout_buffer.add(observation, action, reward, episode_start, value, log_prob, lstm_states=lstm_states) + episode_start = np.array([False], dtype=bool) + assert not rollout_buffer.full + rollout_buffer.add(observation, action, reward, episode_start, value, log_prob, lstm_states=lstm_states) + assert rollout_buffer.full + assert rollout_buffer.hidden_states_pi.shape == (buffer_size, *single_hs_shape) + assert rollout_buffer.cell_states_pi.shape == (buffer_size, *single_hs_shape) + assert rollout_buffer.hidden_states_vf.shape == (buffer_size, *single_hs_shape) + assert rollout_buffer.cell_states_vf.shape == (buffer_size, *single_hs_shape) + + for i in range(10): + sample = next(rollout_buffer.get(batch_size=batch_size)) + assert sample.observations["obs"].shape[-1] == 3, f"obs epoch {i}" + assert sample.actions.shape[-1] == 2, f"actions epoch {i}" + assert len(set(torch.unique(sample.lstm_states.pi[0]))) == 1 + # todo: define behavior + +