From 27213cdc0520bb473aaab84af55c50f68a378166 Mon Sep 17 00:00:00 2001 From: Paolo Magliano Date: Fri, 24 Jan 2025 15:41:11 +0100 Subject: [PATCH 1/8] Optimize compute J, gae and montecarlo adv --- .../actor_critic/deep_actor_critic/a2c.py | 4 +- mushroom_rl/core/dataset.py | 30 +++--- mushroom_rl/rl_utils/value_functions.py | 40 ++++---- mushroom_rl/utils/episodes.py | 67 ++++++++++++++ tests/core/test_dataset.py | 31 +++++++ tests/rl_utils/test_value_functions.py | 92 +++++++++++++++++++ 6 files changed, 229 insertions(+), 35 deletions(-) create mode 100644 mushroom_rl/utils/episodes.py create mode 100644 tests/rl_utils/test_value_functions.py diff --git a/mushroom_rl/algorithms/actor_critic/deep_actor_critic/a2c.py b/mushroom_rl/algorithms/actor_critic/deep_actor_critic/a2c.py index 91d2dc6c..2739ba67 100644 --- a/mushroom_rl/algorithms/actor_critic/deep_actor_critic/a2c.py +++ b/mushroom_rl/algorithms/actor_critic/deep_actor_critic/a2c.py @@ -57,10 +57,10 @@ def __init__(self, mdp_info, policy, actor_optimizer, critic_params, ) def fit(self, dataset): - state, action, reward, next_state, absorbing, _ = dataset.parse(to='torch') + state, action, reward, next_state, absorbing, last = dataset.parse(to='torch') v, adv = compute_advantage_montecarlo(self._V, state, next_state, - reward, absorbing, + reward, absorbing, last, self.mdp_info.gamma) self._V.fit(state, v, **self._critic_fit_params) diff --git a/mushroom_rl/core/dataset.py b/mushroom_rl/core/dataset.py index 367f7383..88102acb 100644 --- a/mushroom_rl/core/dataset.py +++ b/mushroom_rl/core/dataset.py @@ -11,6 +11,7 @@ from ._impl import * +from mushroom_rl.utils.episodes import split_episodes, unsplit_episodes class DatasetInfo(Serializable): def __init__(self, backend, device, horizon, gamma, state_shape, state_dtype, action_shape, action_dtype, @@ -473,22 +474,19 @@ def compute_J(self, gamma=1.): The cumulative discounted reward of each episode in the dataset. """ - js = list() - - j = 0. 
- episode_steps = 0 - for i in range(len(self)): - j += gamma ** episode_steps * self.reward[i] - episode_steps += 1 - if self.last[i] or i == len(self) - 1: - js.append(j) - j = 0. - episode_steps = 0 - - if len(js) == 0: - js = [0.] - - return self._array_backend.from_list(js) + r_ep = split_episodes(self.last, self.reward) + + if len(r_ep.shape) == 1: + r_ep = r_ep.unsqueeze(0) + if hasattr(r_ep, 'device'): + js = self._array_backend.zeros(r_ep.shape[0], dtype=r_ep.dtype, device=r_ep.device) + else: + js = self._array_backend.zeros(r_ep.shape[0], dtype=r_ep.dtype) + + for k in range(r_ep.shape[1]): + js += gamma ** k * r_ep[..., k] + + return js def compute_metrics(self, gamma=1.): """ diff --git a/mushroom_rl/rl_utils/value_functions.py b/mushroom_rl/rl_utils/value_functions.py index 3b6b338d..b82611e1 100644 --- a/mushroom_rl/rl_utils/value_functions.py +++ b/mushroom_rl/rl_utils/value_functions.py @@ -1,7 +1,7 @@ import torch +from mushroom_rl.utils.episodes import split_episodes, unsplit_episodes - -def compute_advantage_montecarlo(V, s, ss, r, absorbing, gamma): +def compute_advantage_montecarlo(V, s, ss, r, absorbing, last, gamma): """ Function to estimate the advantage and new value function target over a dataset. 
The value function is estimated using rollouts @@ -24,18 +24,21 @@ def compute_advantage_montecarlo(V, s, ss, r, absorbing, gamma): """ with torch.no_grad(): r = r.squeeze() - q = torch.zeros(len(r)) v = V(s).squeeze() - q_next = V(ss[-1]).squeeze().item() - for rev_k in range(len(r)): - k = len(r) - rev_k - 1 - q_next = r[k] + gamma * q_next * (1 - absorbing[k].int()) - q[k] = q_next + r_ep, absorbing_ep, ss_ep = split_episodes(last, r, absorbing, ss) + q_ep = torch.zeros_like(r_ep, dtype=torch.float32) + q_next_ep = V(ss_ep[..., -1, :]).squeeze() + + for rev_k in range(r_ep.shape[-1]): + k = r_ep.shape[-1] - rev_k - 1 + q_next_ep = r_ep[..., k] + gamma * q_next_ep * (1 - absorbing_ep[..., k].int()) + q_ep[..., k] = q_next_ep + q = unsplit_episodes(last, q_ep) adv = q - v - return q[:, None], adv[:, None] + return q[:, None], adv[:, None] def compute_advantage(V, s, ss, r, absorbing, gamma): """ @@ -97,13 +100,16 @@ def compute_gae(V, s, ss, r, absorbing, last, gamma, lam): with torch.no_grad(): v = V(s) v_next = V(ss) - gen_adv = torch.empty_like(v) - for rev_k in range(len(v)): - k = len(v) - rev_k - 1 - if last[k] or rev_k == 0: - gen_adv[k] = r[k] - v[k] - if not absorbing[k]: - gen_adv[k] += gamma * v_next[k] + + v_ep, v_next_ep, r_ep, absorbing_ep = split_episodes(last, v.squeeze(), v_next.squeeze(), r, absorbing) + gen_adv_ep = torch.zeros_like(v_ep) + for rev_k in range(v_ep.shape[-1]): + k = v_ep.shape[-1] - rev_k - 1 + if rev_k == 0: + gen_adv_ep[..., k] = r_ep[..., k] - v_ep[..., k] + (1 - absorbing_ep[..., k].int()) * gamma * v_next_ep[..., k] else: - gen_adv[k] = r[k] + gamma * v_next[k] - v[k] + gamma * lam * gen_adv[k + 1] + gen_adv_ep[..., k] = r_ep[..., k] - v_ep[..., k] + (1 - absorbing_ep[..., k].int()) * gamma * v_next_ep[..., k] + gamma * lam * gen_adv_ep[..., k + 1] + + gen_adv = unsplit_episodes(last, gen_adv_ep).unsqueeze(-1) + return gen_adv + v, gen_adv \ No newline at end of file diff --git a/mushroom_rl/utils/episodes.py 
b/mushroom_rl/utils/episodes.py new file mode 100644 index 00000000..5552ce67 --- /dev/null +++ b/mushroom_rl/utils/episodes.py @@ -0,0 +1,67 @@ +import torch +import numpy + +def split_episodes(last, *arrays): + """ + Split a array from shape (n_steps) to (n_episodes, max_episode_steps). + """ + + if last.sum().item() <= 1: + return arrays if len(arrays) > 1 else arrays[0] + + row_idx, colum_idx, n_episodes, max_episode_steps = _torch_get_episode_idx(last) if type(last) == torch.Tensor else _numpy_get_episode_idx(last) + episodes_arrays = [] + + for array in arrays: + if type(last) == torch.Tensor: + array_ep = torch.zeros((n_episodes, max_episode_steps, *array.shape[1:]), dtype=array.dtype, device=array.device) + else: + array_ep = numpy.zeros((n_episodes, max_episode_steps, *array.shape[1:]), dtype=array.dtype) + array_ep[row_idx, colum_idx] = array + episodes_arrays.append(array_ep) + + return episodes_arrays if len(episodes_arrays) > 1 else episodes_arrays[0] + +def unsplit_episodes(last, *episodes_arrays): + """ + Unsplit a array from shape (n_episodes, max_episode_steps) to (n_steps). 
+ """ + + if last.sum().item() <= 1: + return episodes_arrays if len(episodes_arrays) > 1 else episodes_arrays[0] + + row_idx, colum_idx, _, _ = _torch_get_episode_idx(last) if type(last) == torch.Tensor else _numpy_get_episode_idx(last) + + arrays = [] + + for episode_array in episodes_arrays: + array = episode_array[row_idx, colum_idx] + arrays.append(array) + + return arrays if len(arrays) > 1 else arrays[0] + +def _torch_get_episode_idx(last): + + n_episodes = last.sum().item() + last_idx = torch.nonzero(last).squeeze() + episode_steps = torch.cat([torch.tensor([last_idx[0] + 1], device=last.device), last_idx[1:] - last_idx[:-1]]) + max_episode_steps = episode_steps.max().item() + + start_idx = torch.cat([torch.tensor([0], device=last.device), last_idx[:-1] + 1]) + row_idx = torch.arange(n_episodes, device=episode_steps.device).repeat_interleave(episode_steps) + colum_idx = torch.arange(last.shape[0], device=last.device) - start_idx[row_idx] + + return row_idx, colum_idx, n_episodes, max_episode_steps + +def _numpy_get_episode_idx(last): + + n_episodes = numpy.sum(last) + last_idx = numpy.flatnonzero(last) + episode_steps = numpy.concatenate(([last_idx[0] + 1], last_idx[1:] - last_idx[:-1])) + max_episode_steps = numpy.max(episode_steps) + + start_idx = numpy.concatenate(([0], last_idx[:-1] + 1)) + row_idx = numpy.repeat(numpy.arange(n_episodes), episode_steps) + column_idx = numpy.arange(last.shape[0]) - start_idx[row_idx] + + return row_idx, column_idx, n_episodes, max_episode_steps diff --git a/tests/core/test_dataset.py b/tests/core/test_dataset.py index c99c6bfa..191e934d 100644 --- a/tests/core/test_dataset.py +++ b/tests/core/test_dataset.py @@ -130,4 +130,35 @@ def test_dataset_loading(tmpdir): for key in dataset.info: assert np.array_equal(dataset.info[key], new_dataset.info[key]) +def test_compute_J(): + def compute_J(self, gamma=1.): + js = list() + + j = 0. 
+ episode_steps = 0 + for i in range(len(self)): + j += gamma ** episode_steps * self.reward[i] + episode_steps += 1 + if self.last[i] or i == len(self) - 1: + js.append(j) + j = 0. + episode_steps = 0 + + if len(js) == 0: + js = [0.] + + return self._array_backend.from_list(js) + mdp = GridWorld(3, 3, (2, 2)) + dataset = generate_dataset(mdp, 100) + + correct_R = compute_J(dataset) + R = dataset.compute_J() + + assert np.allclose(R, correct_R) + + correct_J = compute_J(dataset, 0.9) + J = dataset.compute_J(0.9) + + assert np.allclose(J, correct_J) +test_compute_J() \ No newline at end of file diff --git a/tests/rl_utils/test_value_functions.py b/tests/rl_utils/test_value_functions.py new file mode 100644 index 00000000..ef5581e1 --- /dev/null +++ b/tests/rl_utils/test_value_functions.py @@ -0,0 +1,92 @@ +import torch +from mushroom_rl.policy import DeterministicPolicy +from mushroom_rl.environments.segway import Segway +from mushroom_rl.core import Core, Agent +from mushroom_rl.approximators import Regressor +from mushroom_rl.approximators.parametric import LinearApproximator, TorchApproximator +from mushroom_rl.rl_utils.value_functions import compute_gae, compute_advantage_montecarlo + + +def test_compute_advantage_montecarlo(): + def advantage_montecarlo(V, s, ss, r, absorbing, last, gamma): + with torch.no_grad(): + r = r.squeeze() + q = torch.zeros(len(r)) + v = V(s).squeeze() + + for rev_k in range(len(r)): + k = len(r) - rev_k - 1 + if last[k] or rev_k == 0: + q_next = V(ss[k]).squeeze().item() + q_next = r[k] + gamma * q_next * (1 - absorbing[k].int()) + q[k] = q_next + + adv = q - v + return q[:, None], adv[:, None] + + test_value_functions(compute_advantage_montecarlo, advantage_montecarlo, 0.99) + +def test_compute_gae(): + def gae(V, s, ss, r, absorbing, last, gamma, lam): + with torch.no_grad(): + v = V(s) + v_next = V(ss) + gen_adv = torch.empty_like(v) + for rev_k in range(len(v)): + k = len(v) - rev_k - 1 + if last[k] or rev_k == 0: + gen_adv[k] = 
r[k] - v[k] + if not absorbing[k]: + gen_adv[k] += gamma * v_next[k] + else: + gen_adv[k] = r[k] + gamma * v_next[k] - v[k] + gamma * lam * gen_adv[k + 1] + return gen_adv + v, gen_adv + + test_value_functions(compute_gae, gae, 0.99, 0.95) + +def test_value_functions(test_fun, correct_fun, *args): + mdp = Segway() + V = Regressor(TorchApproximator, input_shape=mdp.info.observation_space.shape, output_shape=(1,), network=Net, loss=torch.nn.MSELoss(), optimizer={'class': torch.optim.Adam, 'params': {'lr': 0.001}}) + + state, action, reward, next_state, absorbing, last = get_episodes(mdp) + + correct_v, correct_adv = correct_fun(V, state, next_state, reward, absorbing, last, *args) + v, adv = test_fun(V, state, next_state, reward, absorbing, last, *args) + + assert torch.allclose(v, correct_v) + assert torch.allclose(adv, correct_adv) + + V.fit(state, correct_v) + + correct_v, correct_adv = correct_fun(V, state, next_state, reward, absorbing, last, *args) + v, adv = test_fun(V, state, next_state, reward, absorbing, last, *args) + + assert torch.allclose(v, correct_v) + assert torch.allclose(adv, correct_adv) + +def get_episodes(mdp, n_episodes=100): + mu = torch.tensor([6.31154476, 3.32346271, 0.49648221]).unsqueeze(0) + + approximator = Regressor(LinearApproximator, + input_shape=mdp.info.observation_space.shape, + output_shape=mdp.info.action_space.shape, + weights=mu) + + policy = DeterministicPolicy(approximator) + + agent = Agent(mdp.info, policy) + core = Core(agent, mdp) + dataset = core.evaluate(n_episodes=n_episodes) + + return dataset.parse(to='torch') + +class Net(torch.nn.Module): + def __init__(self, input_shape, output_shape, **kwargs): + super().__init__() + self._q = torch.nn.Linear(input_shape[0], output_shape[0]) + + def forward(self, x): + return self._q(x.float()) + +# test_compute_gae() +test_compute_advantage_montecarlo() \ No newline at end of file From fe792798f0d6fc2752587e4f73cf073d597a30ff Mon Sep 17 00:00:00 2001 From: Paolo Magliano Date: 
Mon, 27 Jan 2025 15:42:29 +0100 Subject: [PATCH 2/8] Update test episode and value function --- mushroom_rl/core/array_backend.py | 25 +++++++++- mushroom_rl/rl_utils/value_functions.py | 9 ++-- mushroom_rl/utils/episodes.py | 62 +++++++++++-------------- tests/core/test_dataset.py | 35 +------------- tests/rl_utils/test_value_functions.py | 18 ++++--- tests/utils/test_episodes.py | 60 ++++++++++++++++++++++++ 6 files changed, 130 insertions(+), 79 deletions(-) create mode 100644 tests/utils/test_episodes.py diff --git a/mushroom_rl/core/array_backend.py b/mushroom_rl/core/array_backend.py index bd0a8c48..c6f4f67a 100644 --- a/mushroom_rl/core/array_backend.py +++ b/mushroom_rl/core/array_backend.py @@ -171,7 +171,14 @@ def shape(array): @staticmethod def full(shape, value): raise NotImplementedError - + + @staticmethod + def nonzero(array): + raise NotImplementedError + + @staticmethod + def repeat(array, repeats): + raise NotImplementedError class NumpyBackend(ArrayBackend): @staticmethod @@ -303,6 +310,14 @@ def shape(array): @staticmethod def full(shape, value): return np.full(shape, value) + + @staticmethod + def nonzero(array): + return np.flatnonzero(array) + + @staticmethod + def repeat(array, repeats): + return np.repeat(array, repeats) class TorchBackend(ArrayBackend): @@ -443,6 +458,14 @@ def shape(array): @staticmethod def full(shape, value): return torch.full(shape, value) + + @staticmethod + def nonzero(array): + return torch.nonzero(array) + + @staticmethod + def repeat(array, repeats): + return torch.repeat_interleave(array, repeats) class ListBackend(ArrayBackend): diff --git a/mushroom_rl/rl_utils/value_functions.py b/mushroom_rl/rl_utils/value_functions.py index b82611e1..49bbf667 100644 --- a/mushroom_rl/rl_utils/value_functions.py +++ b/mushroom_rl/rl_utils/value_functions.py @@ -101,14 +101,17 @@ def compute_gae(V, s, ss, r, absorbing, last, gamma, lam): v = V(s) v_next = V(ss) - v_ep, v_next_ep, r_ep, absorbing_ep = split_episodes(last, 
v.squeeze(), v_next.squeeze(), r, absorbing) + v_ep, v_next_ep, r_ep, absorbing_ep = split_episodes(last, v.squeeze(), v_next.squeeze(), r.float(), absorbing) gen_adv_ep = torch.zeros_like(v_ep) + diff = r_ep - v_ep + v_next_discounted = (1 - absorbing_ep.int()) * gamma * v_next_ep for rev_k in range(v_ep.shape[-1]): k = v_ep.shape[-1] - rev_k - 1 if rev_k == 0: - gen_adv_ep[..., k] = r_ep[..., k] - v_ep[..., k] + (1 - absorbing_ep[..., k].int()) * gamma * v_next_ep[..., k] + gen_adv_ep[..., k] = diff[..., k] + v_next_discounted[..., k] else: - gen_adv_ep[..., k] = r_ep[..., k] - v_ep[..., k] + (1 - absorbing_ep[..., k].int()) * gamma * v_next_ep[..., k] + gamma * lam * gen_adv_ep[..., k + 1] + last_adv = gamma * lam * gen_adv_ep[..., k + 1] + gen_adv_ep[..., k] = diff[..., k] + v_next_discounted[..., k] + last_adv gen_adv = unsplit_episodes(last, gen_adv_ep).unsqueeze(-1) diff --git a/mushroom_rl/utils/episodes.py b/mushroom_rl/utils/episodes.py index 5552ce67..7bfd5021 100644 --- a/mushroom_rl/utils/episodes.py +++ b/mushroom_rl/utils/episodes.py @@ -1,22 +1,20 @@ -import torch -import numpy +from mushroom_rl.core.array_backend import ArrayBackend def split_episodes(last, *arrays): """ Split a array from shape (n_steps) to (n_episodes, max_episode_steps). 
""" + backend = ArrayBackend.get_array_backend_from(last) if last.sum().item() <= 1: return arrays if len(arrays) > 1 else arrays[0] - row_idx, colum_idx, n_episodes, max_episode_steps = _torch_get_episode_idx(last) if type(last) == torch.Tensor else _numpy_get_episode_idx(last) + row_idx, colum_idx, n_episodes, max_episode_steps = _get_episode_idx(last, backend) episodes_arrays = [] for array in arrays: - if type(last) == torch.Tensor: - array_ep = torch.zeros((n_episodes, max_episode_steps, *array.shape[1:]), dtype=array.dtype, device=array.device) - else: - array_ep = numpy.zeros((n_episodes, max_episode_steps, *array.shape[1:]), dtype=array.dtype) + array_ep = backend.zeros(n_episodes, max_episode_steps, *array.shape[1:], dtype=array.dtype, device=array.device if hasattr(array, 'device') else None) + array_ep[row_idx, colum_idx] = array episodes_arrays.append(array_ep) @@ -30,8 +28,7 @@ def unsplit_episodes(last, *episodes_arrays): if last.sum().item() <= 1: return episodes_arrays if len(episodes_arrays) > 1 else episodes_arrays[0] - row_idx, colum_idx, _, _ = _torch_get_episode_idx(last) if type(last) == torch.Tensor else _numpy_get_episode_idx(last) - + row_idx, colum_idx, _, _ = _get_episode_idx(last) arrays = [] for episode_array in episodes_arrays: @@ -40,28 +37,25 @@ def unsplit_episodes(last, *episodes_arrays): return arrays if len(arrays) > 1 else arrays[0] -def _torch_get_episode_idx(last): - - n_episodes = last.sum().item() - last_idx = torch.nonzero(last).squeeze() - episode_steps = torch.cat([torch.tensor([last_idx[0] + 1], device=last.device), last_idx[1:] - last_idx[:-1]]) - max_episode_steps = episode_steps.max().item() - - start_idx = torch.cat([torch.tensor([0], device=last.device), last_idx[:-1] + 1]) - row_idx = torch.arange(n_episodes, device=episode_steps.device).repeat_interleave(episode_steps) - colum_idx = torch.arange(last.shape[0], device=last.device) - start_idx[row_idx] - - return row_idx, colum_idx, n_episodes, max_episode_steps - 
-def _numpy_get_episode_idx(last): - - n_episodes = numpy.sum(last) - last_idx = numpy.flatnonzero(last) - episode_steps = numpy.concatenate(([last_idx[0] + 1], last_idx[1:] - last_idx[:-1])) - max_episode_steps = numpy.max(episode_steps) - - start_idx = numpy.concatenate(([0], last_idx[:-1] + 1)) - row_idx = numpy.repeat(numpy.arange(n_episodes), episode_steps) - column_idx = numpy.arange(last.shape[0]) - start_idx[row_idx] - - return row_idx, column_idx, n_episodes, max_episode_steps +def _get_episode_idx(last, backend=None): + if backend is None: + backend = ArrayBackend.get_array_backend_from(last) + + n_episodes = last.sum() + last_idx = backend.nonzero(last).squeeze() + first_steps = backend.from_list([last_idx[0] + 1]) + if hasattr(last, 'device'): + first_steps = first_steps.to(last.device) + episode_steps = backend.concatenate([first_steps, last_idx[1:] - last_idx[:-1]]) + max_episode_steps = episode_steps.max() + + start_idx = backend.concatenate([backend.zeros(1, dtype=int, device=last.device if hasattr(last, 'device') else None), last_idx[:-1] + 1]) + range_n_episodes = backend.arange(0, n_episodes, dtype=int) + range_len = backend.arange(0, last.shape[0], dtype=int) + if hasattr(last, 'device'): + range_n_episodes = range_n_episodes.to(last.device) + range_len = range_len.to(last.device) + row_idx = backend.repeat(range_n_episodes, episode_steps) + colum_idx = range_len - start_idx[row_idx] + + return row_idx, colum_idx, n_episodes, max_episode_steps \ No newline at end of file diff --git a/tests/core/test_dataset.py b/tests/core/test_dataset.py index 191e934d..847dedfe 100644 --- a/tests/core/test_dataset.py +++ b/tests/core/test_dataset.py @@ -128,37 +128,4 @@ def test_dataset_loading(tmpdir): assert len(dataset.info) == len(new_dataset.info) for key in dataset.info: - assert np.array_equal(dataset.info[key], new_dataset.info[key]) - -def test_compute_J(): - def compute_J(self, gamma=1.): - js = list() - - j = 0. 
- episode_steps = 0 - for i in range(len(self)): - j += gamma ** episode_steps * self.reward[i] - episode_steps += 1 - if self.last[i] or i == len(self) - 1: - js.append(j) - j = 0. - episode_steps = 0 - - if len(js) == 0: - js = [0.] - - return self._array_backend.from_list(js) - mdp = GridWorld(3, 3, (2, 2)) - dataset = generate_dataset(mdp, 100) - - correct_R = compute_J(dataset) - R = dataset.compute_J() - - assert np.allclose(R, correct_R) - - correct_J = compute_J(dataset, 0.9) - J = dataset.compute_J(0.9) - - assert np.allclose(J, correct_J) - -test_compute_J() \ No newline at end of file + assert np.array_equal(dataset.info[key], new_dataset.info[key]) \ No newline at end of file diff --git a/tests/rl_utils/test_value_functions.py b/tests/rl_utils/test_value_functions.py index ef5581e1..20e87890 100644 --- a/tests/rl_utils/test_value_functions.py +++ b/tests/rl_utils/test_value_functions.py @@ -6,6 +6,7 @@ from mushroom_rl.approximators.parametric import LinearApproximator, TorchApproximator from mushroom_rl.rl_utils.value_functions import compute_gae, compute_advantage_montecarlo +from mushroom_rl.utils.episodes import split_episodes, unsplit_episodes def test_compute_advantage_montecarlo(): def advantage_montecarlo(V, s, ss, r, absorbing, last, gamma): @@ -24,11 +25,13 @@ def advantage_montecarlo(V, s, ss, r, absorbing, last, gamma): adv = q - v return q[:, None], adv[:, None] + torch.manual_seed(42) test_value_functions(compute_advantage_montecarlo, advantage_montecarlo, 0.99) - + def test_compute_gae(): def gae(V, s, ss, r, absorbing, last, gamma, lam): with torch.no_grad(): + r = r.float() v = V(s) v_next = V(ss) gen_adv = torch.empty_like(v) @@ -39,16 +42,20 @@ def gae(V, s, ss, r, absorbing, last, gamma, lam): if not absorbing[k]: gen_adv[k] += gamma * v_next[k] else: - gen_adv[k] = r[k] + gamma * v_next[k] - v[k] + gamma * lam * gen_adv[k + 1] + diff = r[k] - v[k] + v_next_discounted = gamma * v_next[k] + last_adv = gamma * lam * gen_adv[k + 1] + 
gen_adv[k] = diff + v_next_discounted + last_adv return gen_adv + v, gen_adv - + + torch.manual_seed(42) test_value_functions(compute_gae, gae, 0.99, 0.95) def test_value_functions(test_fun, correct_fun, *args): mdp = Segway() V = Regressor(TorchApproximator, input_shape=mdp.info.observation_space.shape, output_shape=(1,), network=Net, loss=torch.nn.MSELoss(), optimizer={'class': torch.optim.Adam, 'params': {'lr': 0.001}}) - state, action, reward, next_state, absorbing, last = get_episodes(mdp) + state, action, reward, next_state, absorbing, last = get_episodes(mdp, 10) correct_v, correct_adv = correct_fun(V, state, next_state, reward, absorbing, last, *args) v, adv = test_fun(V, state, next_state, reward, absorbing, last, *args) @@ -87,6 +94,3 @@ def __init__(self, input_shape, output_shape, **kwargs): def forward(self, x): return self._q(x.float()) - -# test_compute_gae() -test_compute_advantage_montecarlo() \ No newline at end of file diff --git a/tests/utils/test_episodes.py b/tests/utils/test_episodes.py new file mode 100644 index 00000000..8f6f895d --- /dev/null +++ b/tests/utils/test_episodes.py @@ -0,0 +1,60 @@ +import torch +import numpy as np + +from mushroom_rl.core import Core, Agent +from mushroom_rl.approximators import Regressor +from mushroom_rl.policy import DeterministicPolicy +from mushroom_rl.approximators.parametric import LinearApproximator +from mushroom_rl.environments import Segway + +from mushroom_rl.utils.episodes import split_episodes, unsplit_episodes + +def test_torch_split(): + torch.manual_seed(42) + mdp = Segway() + state, action, reward, next_state, absorbing, last = get_episodes(mdp) + + ep_arrays = split_episodes(last, state, action, reward, next_state, absorbing, last) + un_state, un_action, un_reward, un_next_state, un_absorbing, un_last = unsplit_episodes(last, *ep_arrays) + + assert torch.allclose(state, un_state) + assert torch.allclose(action, un_action) + assert torch.allclose(reward, un_reward) + assert 
torch.allclose(next_state, un_next_state) + assert torch.allclose(absorbing, un_absorbing) + assert torch.allclose(last, un_last) + +def test_numpy_split(): + torch.manual_seed(42) + np.random.seed(42) + + mdp = Segway() + state, action, reward, next_state, absorbing, last = get_episodes(mdp) + + state, action, reward, next_state, absorbing, last = state.numpy(), action.numpy(), reward.numpy(), next_state.numpy(), absorbing.numpy(), last.numpy() + + ep_arrays = split_episodes(last, state, action, reward, next_state, absorbing, last) + un_state, un_action, un_reward, un_next_state, un_absorbing, un_last = unsplit_episodes(last, *ep_arrays) + + assert np.allclose(state, un_state) + assert np.allclose(action, un_action) + assert np.allclose(reward, un_reward) + assert np.allclose(next_state, un_next_state) + assert np.allclose(absorbing, un_absorbing) + assert np.allclose(last, un_last) + +def get_episodes(mdp, n_episodes=100): + mu = torch.tensor([6.31154476, 3.32346271, 0.49648221]).unsqueeze(0) + + approximator = Regressor(LinearApproximator, + input_shape=mdp.info.observation_space.shape, + output_shape=mdp.info.action_space.shape, + weights=mu) + + policy = DeterministicPolicy(approximator) + + agent = Agent(mdp.info, policy) + core = Core(agent, mdp) + dataset = core.evaluate(n_episodes=n_episodes) + + return dataset.parse(to='torch') From 8eec8652b64611f22538856f319ccbdb1a736364 Mon Sep 17 00:00:00 2001 From: Paolo Magliano Date: Tue, 28 Jan 2025 17:54:22 +0100 Subject: [PATCH 3/8] Fix dataset parse --- mushroom_rl/core/array_backend.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/mushroom_rl/core/array_backend.py b/mushroom_rl/core/array_backend.py index c6f4f67a..cf752730 100644 --- a/mushroom_rl/core/array_backend.py +++ b/mushroom_rl/core/array_backend.py @@ -195,7 +195,12 @@ def to_numpy(array): @staticmethod def to_torch(array): - return None if array is None else torch.from_numpy(array).to(TorchUtils.get_device()) + if array is 
None: + return None + else: + if array.dtype == np.float64: + array = array.astype(np.float32) + return torch.from_numpy(array).to(TorchUtils.get_device()) @staticmethod def convert_to_backend(cls, array): From be1d32bc949ad332134e3e9b8035168a318066c2 Mon Sep 17 00:00:00 2001 From: Paolo Magliano Date: Tue, 28 Jan 2025 17:55:05 +0100 Subject: [PATCH 4/8] Fix gae sign --- mushroom_rl/rl_utils/value_functions.py | 9 +++------ tests/rl_utils/test_value_functions.py | 16 ++++++---------- 2 files changed, 9 insertions(+), 16 deletions(-) diff --git a/mushroom_rl/rl_utils/value_functions.py b/mushroom_rl/rl_utils/value_functions.py index 49bbf667..b82611e1 100644 --- a/mushroom_rl/rl_utils/value_functions.py +++ b/mushroom_rl/rl_utils/value_functions.py @@ -101,17 +101,14 @@ def compute_gae(V, s, ss, r, absorbing, last, gamma, lam): v = V(s) v_next = V(ss) - v_ep, v_next_ep, r_ep, absorbing_ep = split_episodes(last, v.squeeze(), v_next.squeeze(), r.float(), absorbing) + v_ep, v_next_ep, r_ep, absorbing_ep = split_episodes(last, v.squeeze(), v_next.squeeze(), r, absorbing) gen_adv_ep = torch.zeros_like(v_ep) - diff = r_ep - v_ep - v_next_discounted = (1 - absorbing_ep.int()) * gamma * v_next_ep for rev_k in range(v_ep.shape[-1]): k = v_ep.shape[-1] - rev_k - 1 if rev_k == 0: - gen_adv_ep[..., k] = diff[..., k] + v_next_discounted[..., k] + gen_adv_ep[..., k] = r_ep[..., k] - v_ep[..., k] + (1 - absorbing_ep[..., k].int()) * gamma * v_next_ep[..., k] else: - last_adv = gamma * lam * gen_adv_ep[..., k + 1] - gen_adv_ep[..., k] = diff[..., k] + v_next_discounted[..., k] + last_adv + gen_adv_ep[..., k] = r_ep[..., k] - v_ep[..., k] + (1 - absorbing_ep[..., k].int()) * gamma * v_next_ep[..., k] + gamma * lam * gen_adv_ep[..., k + 1] gen_adv = unsplit_episodes(last, gen_adv_ep).unsqueeze(-1) diff --git a/tests/rl_utils/test_value_functions.py b/tests/rl_utils/test_value_functions.py index 20e87890..b213c452 100644 --- a/tests/rl_utils/test_value_functions.py +++ 
b/tests/rl_utils/test_value_functions.py @@ -26,12 +26,11 @@ def advantage_montecarlo(V, s, ss, r, absorbing, last, gamma): return q[:, None], adv[:, None] torch.manual_seed(42) - test_value_functions(compute_advantage_montecarlo, advantage_montecarlo, 0.99) + _value_functions_tester(compute_advantage_montecarlo, advantage_montecarlo, 0.99) def test_compute_gae(): def gae(V, s, ss, r, absorbing, last, gamma, lam): with torch.no_grad(): - r = r.float() v = V(s) v_next = V(ss) gen_adv = torch.empty_like(v) @@ -42,20 +41,17 @@ def gae(V, s, ss, r, absorbing, last, gamma, lam): if not absorbing[k]: gen_adv[k] += gamma * v_next[k] else: - diff = r[k] - v[k] - v_next_discounted = gamma * v_next[k] - last_adv = gamma * lam * gen_adv[k + 1] - gen_adv[k] = diff + v_next_discounted + last_adv + gen_adv[k] = r[k] - v[k] + gamma * v_next[k] + gamma * lam * gen_adv[k + 1] return gen_adv + v, gen_adv torch.manual_seed(42) - test_value_functions(compute_gae, gae, 0.99, 0.95) + _value_functions_tester(compute_gae, gae, 0.99, 0.95) -def test_value_functions(test_fun, correct_fun, *args): +def _value_functions_tester(test_fun, correct_fun, *args): mdp = Segway() V = Regressor(TorchApproximator, input_shape=mdp.info.observation_space.shape, output_shape=(1,), network=Net, loss=torch.nn.MSELoss(), optimizer={'class': torch.optim.Adam, 'params': {'lr': 0.001}}) - state, action, reward, next_state, absorbing, last = get_episodes(mdp, 10) + state, action, reward, next_state, absorbing, last = _get_episodes(mdp, 10) correct_v, correct_adv = correct_fun(V, state, next_state, reward, absorbing, last, *args) v, adv = test_fun(V, state, next_state, reward, absorbing, last, *args) @@ -71,7 +67,7 @@ def test_value_functions(test_fun, correct_fun, *args): assert torch.allclose(v, correct_v) assert torch.allclose(adv, correct_adv) -def get_episodes(mdp, n_episodes=100): +def _get_episodes(mdp, n_episodes=100): mu = torch.tensor([6.31154476, 3.32346271, 0.49648221]).unsqueeze(0) approximator = 
Regressor(LinearApproximator, From 4890af029ac7b3cc81d71ae0995d7851afa1cf6a Mon Sep 17 00:00:00 2001 From: Paolo Magliano Date: Wed, 29 Jan 2025 14:09:09 +0100 Subject: [PATCH 5/8] Fix a2c test --- tests/algorithms/test_a2c.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/algorithms/test_a2c.py b/tests/algorithms/test_a2c.py index 6b3babaa..d3e7b95b 100644 --- a/tests/algorithms/test_a2c.py +++ b/tests/algorithms/test_a2c.py @@ -75,7 +75,7 @@ def test_a2c(): agent = learn_a2c() w = agent.policy.get_weights() - w_test = np.array([0.9382279 , -1.8847059 , -0.13790752, -0.00786441]) + w_test = np.array([ 0.9389272 ,-1.8838323 ,-0.13710725,-0.00668973]) assert np.allclose(w, w_test) @@ -95,3 +95,5 @@ def test_a2c_save(tmpdir): print(save_attr, load_attr) tu.assert_eq(save_attr, load_attr) + +test_a2c() \ No newline at end of file From ef3667f7fcf37a27b90db9a6fbd353487ad61c6b Mon Sep 17 00:00:00 2001 From: Paolo Magliano Date: Wed, 5 Feb 2025 18:20:07 +0100 Subject: [PATCH 6/8] Fix last element of last --- mushroom_rl/utils/episodes.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/mushroom_rl/utils/episodes.py b/mushroom_rl/utils/episodes.py index 7bfd5021..96a877a9 100644 --- a/mushroom_rl/utils/episodes.py +++ b/mushroom_rl/utils/episodes.py @@ -41,18 +41,21 @@ def _get_episode_idx(last, backend=None): if backend is None: backend = ArrayBackend.get_array_backend_from(last) + last = backend.copy(last) + last[-1] = True + n_episodes = last.sum() last_idx = backend.nonzero(last).squeeze() first_steps = backend.from_list([last_idx[0] + 1]) - if hasattr(last, 'device'): + if backend == 'torch': first_steps = first_steps.to(last.device) episode_steps = backend.concatenate([first_steps, last_idx[1:] - last_idx[:-1]]) max_episode_steps = episode_steps.max() - start_idx = backend.concatenate([backend.zeros(1, dtype=int, device=last.device if hasattr(last, 'device') else None), last_idx[:-1] + 1]) + start_idx = 
backend.concatenate([backend.zeros(1, dtype=int, device=last.device if backend == 'torch' else None), last_idx[:-1] + 1]) range_n_episodes = backend.arange(0, n_episodes, dtype=int) range_len = backend.arange(0, last.shape[0], dtype=int) - if hasattr(last, 'device'): + if backend == 'torch': range_n_episodes = range_n_episodes.to(last.device) range_len = range_len.to(last.device) row_idx = backend.repeat(range_n_episodes, episode_steps) From a4d7fcc9b02968367aaa5e0d8020e475a89bc24b Mon Sep 17 00:00:00 2001 From: Paolo Magliano Date: Wed, 5 Feb 2025 18:28:24 +0100 Subject: [PATCH 7/8] Check backend --- mushroom_rl/utils/episodes.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mushroom_rl/utils/episodes.py b/mushroom_rl/utils/episodes.py index 96a877a9..ca70ba16 100644 --- a/mushroom_rl/utils/episodes.py +++ b/mushroom_rl/utils/episodes.py @@ -47,15 +47,15 @@ def _get_episode_idx(last, backend=None): n_episodes = last.sum() last_idx = backend.nonzero(last).squeeze() first_steps = backend.from_list([last_idx[0] + 1]) - if backend == 'torch': + if backend.get_backend_name == 'torch': first_steps = first_steps.to(last.device) episode_steps = backend.concatenate([first_steps, last_idx[1:] - last_idx[:-1]]) max_episode_steps = episode_steps.max() - start_idx = backend.concatenate([backend.zeros(1, dtype=int, device=last.device if backend == 'torch' else None), last_idx[:-1] + 1]) + start_idx = backend.concatenate([backend.zeros(1, dtype=int, device=last.device if backend.get_backend_name() == 'torch' else None), last_idx[:-1] + 1]) range_n_episodes = backend.arange(0, n_episodes, dtype=int) range_len = backend.arange(0, last.shape[0], dtype=int) - if backend == 'torch': + if backend.get_backend_name() == 'torch': range_n_episodes = range_n_episodes.to(last.device) range_len = range_len.to(last.device) row_idx = backend.repeat(range_n_episodes, episode_steps) From f7ea4f3e3c10420cb22ae177d5317ff56c66c282 Mon Sep 17 00:00:00 2001 From: Paolo 
Magliano Date: Wed, 5 Feb 2025 18:41:25 +0100 Subject: [PATCH 8/8] Add parentheses --- mushroom_rl/utils/episodes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mushroom_rl/utils/episodes.py b/mushroom_rl/utils/episodes.py index ca70ba16..f73417cf 100644 --- a/mushroom_rl/utils/episodes.py +++ b/mushroom_rl/utils/episodes.py @@ -47,7 +47,7 @@ def _get_episode_idx(last, backend=None): n_episodes = last.sum() last_idx = backend.nonzero(last).squeeze() first_steps = backend.from_list([last_idx[0] + 1]) - if backend.get_backend_name == 'torch': + if backend.get_backend_name() == 'torch': first_steps = first_steps.to(last.device) episode_steps = backend.concatenate([first_steps, last_idx[1:] - last_idx[:-1]]) max_episode_steps = episode_steps.max()