4 changes: 2 additions & 2 deletions mushroom_rl/algorithms/actor_critic/deep_actor_critic/a2c.py
@@ -57,10 +57,10 @@ def __init__(self, mdp_info, policy, actor_optimizer, critic_params,
         )

     def fit(self, dataset):
-        state, action, reward, next_state, absorbing, _ = dataset.parse(to='torch')
+        state, action, reward, next_state, absorbing, last = dataset.parse(to='torch')

         v, adv = compute_advantage_montecarlo(self._V, state, next_state,
-                                              reward, absorbing,
+                                              reward, absorbing, last,
                                               self.mdp_info.gamma)
         self._V.fit(state, v, **self._critic_fit_params)
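Note: the A2C change itself is just plumbing — `fit` now keeps the `last` episode-boundary flags from `dataset.parse` and forwards them to `compute_advantage_montecarlo`. A minimal sketch (toy tensors, no value bootstrapping) of why the flag matters once a batch spans several episodes:

```python
import torch

gamma = 0.99
reward = torch.tensor([1., 1., 1., 1., 1.])
last = torch.tensor([False, True, False, False, True])  # two episodes: lengths 2 and 3

# Boundary-blind backward pass: the return at step 1 wrongly accumulates
# rewards from the second episode.
q_blind = torch.zeros_like(reward)
acc = 0.
for k in reversed(range(len(reward))):
    acc = reward[k] + gamma * acc
    q_blind[k] = acc

# Boundary-aware backward pass: the accumulator restarts at every episode
# end, which is exactly what the `last` mask makes possible.
q_aware = torch.zeros_like(reward)
for k in reversed(range(len(reward))):
    acc = reward[k] if last[k] else reward[k] + gamma * acc
    q_aware[k] = acc

print(q_blind[:2])  # tensor([4.9010, 3.9404]) - contaminated by episode 2
print(q_aware[:2])  # tensor([1.9900, 1.0000]) - correct for a 2-step episode
```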
32 changes: 30 additions & 2 deletions mushroom_rl/core/array_backend.py
@@ -171,7 +171,14 @@ def shape(array):
     @staticmethod
     def full(shape, value):
         raise NotImplementedError
+
+    @staticmethod
+    def nonzero(array):
+        raise NotImplementedError
+
+    @staticmethod
+    def repeat(array, repeats):
+        raise NotImplementedError


 class NumpyBackend(ArrayBackend):
     @staticmethod
@@ -188,7 +195,12 @@ def to_numpy(array):

     @staticmethod
     def to_torch(array):
-        return None if array is None else torch.from_numpy(array).to(TorchUtils.get_device())
+        if array is None:
+            return None
+        else:
+            if array.dtype == np.float64:
+                array = array.astype(np.float32)
+            return torch.from_numpy(array).to(TorchUtils.get_device())

     @staticmethod
     def convert_to_backend(cls, array):
@@ -303,6 +315,14 @@ def shape(array):
     @staticmethod
     def full(shape, value):
         return np.full(shape, value)
+
+    @staticmethod
+    def nonzero(array):
+        return np.flatnonzero(array)
+
+    @staticmethod
+    def repeat(array, repeats):
+        return np.repeat(array, repeats)


 class TorchBackend(ArrayBackend):
@@ -443,6 +463,14 @@ def shape(array):
     @staticmethod
     def full(shape, value):
         return torch.full(shape, value)
+
+    @staticmethod
+    def nonzero(array):
+        return torch.nonzero(array)
+
+    @staticmethod
+    def repeat(array, repeats):
+        return torch.repeat_interleave(array, repeats)


 class ListBackend(ArrayBackend):
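Note: the two new primitives are not perfectly symmetric across backends: `np.flatnonzero` returns a flat index array, while `torch.nonzero` returns an `(n, 1)` column that callers (e.g. `_get_episode_idx` below) have to `squeeze()`. A quick illustration of the pair on both backends; the `last` mask here is just example data:

```python
import numpy as np
import torch

last_np = np.array([False, True, False, False, True])
last_t = torch.tensor([False, True, False, False, True])

print(np.flatnonzero(last_np))          # [1 4]             -> shape (2,)
print(torch.nonzero(last_t))            # tensor([[1],[4]]) -> shape (2, 1)
print(torch.nonzero(last_t).squeeze())  # tensor([1, 4])    -> matches numpy after squeeze

# `repeat` maps element-wise counts to a flat expansion on both backends:
print(np.repeat(np.arange(2), [2, 3]))                                  # [0 0 1 1 1]
print(torch.repeat_interleave(torch.arange(2), torch.tensor([2, 3])))   # tensor([0, 0, 1, 1, 1])
```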
30 changes: 14 additions & 16 deletions mushroom_rl/core/dataset.py
@@ -11,6 +11,7 @@

 from ._impl import *

+from mushroom_rl.utils.episodes import split_episodes, unsplit_episodes

 class DatasetInfo(Serializable):
     def __init__(self, backend, device, horizon, gamma, state_shape, state_dtype, action_shape, action_dtype,
@@ -473,22 +474,19 @@ def compute_J(self, gamma=1.):
             The cumulative discounted reward of each episode in the dataset.

         """
-        js = list()
-
-        j = 0.
-        episode_steps = 0
-        for i in range(len(self)):
-            j += gamma ** episode_steps * self.reward[i]
-            episode_steps += 1
-            if self.last[i] or i == len(self) - 1:
-                js.append(j)
-                j = 0.
-                episode_steps = 0
-
-        if len(js) == 0:
-            js = [0.]
-
-        return self._array_backend.from_list(js)
+        r_ep = split_episodes(self.last, self.reward)
+
+        if len(r_ep.shape) == 1:
+            r_ep = r_ep.unsqueeze(0)
+        if hasattr(r_ep, 'device'):
+            js = self._array_backend.zeros(r_ep.shape[0], dtype=r_ep.dtype, device=r_ep.device)
+        else:
+            js = self._array_backend.zeros(r_ep.shape[0], dtype=r_ep.dtype)
+
+        for k in range(r_ep.shape[1]):
+            js += gamma ** k * r_ep[..., k]
+
+        return js

     def compute_metrics(self, gamma=1.):
         """
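Note: the rewritten `compute_J` pads every episode's rewards to a common length (zeros contribute nothing to the discounted sum) and then accumulates one time-step column at a time. A small hand-check of that column-wise accumulation, assuming a toy padded reward matrix rather than a real `Dataset`:

```python
import torch

gamma = 0.9
# Two episodes, padded to max length 3: rewards [1, 1] and [1, 1, 1].
r_ep = torch.tensor([[1., 1., 0.],
                     [1., 1., 1.]])

# Accumulate gamma^k * r[:, k] over the columns, as compute_J now does.
js = torch.zeros(r_ep.shape[0])
for k in range(r_ep.shape[1]):
    js += gamma ** k * r_ep[..., k]

print(js)  # tensor([1.9000, 2.7100]) = [1 + 0.9, 1 + 0.9 + 0.81]
```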
40 changes: 23 additions & 17 deletions mushroom_rl/rl_utils/value_functions.py
@@ -1,7 +1,7 @@
 import torch
+from mushroom_rl.utils.episodes import split_episodes, unsplit_episodes


-def compute_advantage_montecarlo(V, s, ss, r, absorbing, gamma):
+def compute_advantage_montecarlo(V, s, ss, r, absorbing, last, gamma):
     """
     Function to estimate the advantage and new value function target
     over a dataset. The value function is estimated using rollouts
@@ -24,18 +24,21 @@ def compute_advantage_montecarlo(V, s, ss, r, absorbing, gamma):
"""
with torch.no_grad():
r = r.squeeze()
q = torch.zeros(len(r))
v = V(s).squeeze()

q_next = V(ss[-1]).squeeze().item()
for rev_k in range(len(r)):
k = len(r) - rev_k - 1
q_next = r[k] + gamma * q_next * (1 - absorbing[k].int())
q[k] = q_next
r_ep, absorbing_ep, ss_ep = split_episodes(last, r, absorbing, ss)
q_ep = torch.zeros_like(r_ep, dtype=torch.float32)
q_next_ep = V(ss_ep[..., -1, :]).squeeze()

for rev_k in range(r_ep.shape[-1]):
k = r_ep.shape[-1] - rev_k - 1
q_next_ep = r_ep[..., k] + gamma * q_next_ep * (1 - absorbing_ep[..., k].int())
q_ep[..., k] = q_next_ep

q = unsplit_episodes(last, q_ep)
adv = q - v
return q[:, None], adv[:, None]

return q[:, None], adv[:, None]

def compute_advantage(V, s, ss, r, absorbing, gamma):
"""
@@ -97,13 +100,16 @@ def compute_gae(V, s, ss, r, absorbing, last, gamma, lam):
     with torch.no_grad():
         v = V(s)
         v_next = V(ss)
-        gen_adv = torch.empty_like(v)
-        for rev_k in range(len(v)):
-            k = len(v) - rev_k - 1
-            if last[k] or rev_k == 0:
-                gen_adv[k] = r[k] - v[k]
-                if not absorbing[k]:
-                    gen_adv[k] += gamma * v_next[k]
-            else:
-                gen_adv[k] = r[k] + gamma * v_next[k] - v[k] + gamma * lam * gen_adv[k + 1]
+
+        v_ep, v_next_ep, r_ep, absorbing_ep = split_episodes(last, v.squeeze(), v_next.squeeze(), r, absorbing)
+        gen_adv_ep = torch.zeros_like(v_ep)
+        for rev_k in range(v_ep.shape[-1]):
+            k = v_ep.shape[-1] - rev_k - 1
+            if rev_k == 0:
+                gen_adv_ep[..., k] = r_ep[..., k] - v_ep[..., k] + (1 - absorbing_ep[..., k].int()) * gamma * v_next_ep[..., k]
+            else:
+                gen_adv_ep[..., k] = r_ep[..., k] - v_ep[..., k] + (1 - absorbing_ep[..., k].int()) * gamma * v_next_ep[..., k] + gamma * lam * gen_adv_ep[..., k + 1]
+
+        gen_adv = unsplit_episodes(last, gen_adv_ep).unsqueeze(-1)

         return gen_adv + v, gen_adv
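Note: both rewrites run the same backward recursions as before, just across the columns of a padded `(n_episodes, max_episode_steps)` matrix, so every episode bootstraps only from its own steps. For GAE the per-step recursion is `A[k] = delta[k] + gamma * lam * A[k+1]` with `delta[k] = r[k] - v[k] + gamma * (1 - absorbing[k]) * v_next[k]`, which telescopes to the familiar `sum_l (gamma * lam)^l * delta[k+l]`. A toy single-episode check of that equivalence (made-up numbers, no value network):

```python
import torch

gamma, lam = 0.99, 0.95
r = torch.tensor([1., 0., 2.])
v = torch.tensor([0.5, 0.4, 0.3])
v_next = torch.tensor([0.4, 0.3, 0.0])
absorbing = torch.tensor([0., 0., 1.])  # episode ends in a terminal state

# TD residuals; the absorbing flag zeroes out bootstrapping at the end.
delta = r - v + gamma * (1 - absorbing) * v_next

# Backward recursion, as in the patched compute_gae.
adv = torch.zeros_like(delta)
for k in reversed(range(len(delta))):
    adv[k] = delta[k] + (gamma * lam * adv[k + 1] if k + 1 < len(delta) else 0.)

# Closed form: sum_l (gamma * lam)^l * delta[k + l].
closed = torch.stack([sum((gamma * lam) ** l * delta[k + l] for l in range(len(delta) - k))
                      for k in range(len(delta))])

assert torch.allclose(adv, closed)
```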
61 changes: 61 additions & 0 deletions mushroom_rl/utils/episodes.py
@@ -0,0 +1,61 @@
+from mushroom_rl.core.array_backend import ArrayBackend
+
+def split_episodes(last, *arrays):
+    """
+    Split arrays of shape (n_steps,) into arrays of shape (n_episodes, max_episode_steps).
+    """
+    backend = ArrayBackend.get_array_backend_from(last)
+
+    if last.sum().item() <= 1:
+        return arrays if len(arrays) > 1 else arrays[0]
+
+    row_idx, colum_idx, n_episodes, max_episode_steps = _get_episode_idx(last, backend)
+    episodes_arrays = []
+
+    for array in arrays:
+        array_ep = backend.zeros(n_episodes, max_episode_steps, *array.shape[1:], dtype=array.dtype, device=array.device if hasattr(array, 'device') else None)
+        array_ep[row_idx, colum_idx] = array
+        episodes_arrays.append(array_ep)
+
+    return episodes_arrays if len(episodes_arrays) > 1 else episodes_arrays[0]
+
+def unsplit_episodes(last, *episodes_arrays):
+    """
+    Flatten arrays of shape (n_episodes, max_episode_steps) back to shape (n_steps,).
+    """
+    if last.sum().item() <= 1:
+        return episodes_arrays if len(episodes_arrays) > 1 else episodes_arrays[0]
+
+    row_idx, colum_idx, _, _ = _get_episode_idx(last)
+    arrays = []
+
+    for episode_array in episodes_arrays:
+        array = episode_array[row_idx, colum_idx]
+        arrays.append(array)
+
+    return arrays if len(arrays) > 1 else arrays[0]
+
+def _get_episode_idx(last, backend=None):
+    if backend is None:
+        backend = ArrayBackend.get_array_backend_from(last)
+
+    n_episodes = last.sum()
+    last_idx = backend.nonzero(last).squeeze()
+    first_steps = backend.from_list([last_idx[0] + 1])
+    if hasattr(last, 'device'):
+        first_steps = first_steps.to(last.device)
+    episode_steps = backend.concatenate([first_steps, last_idx[1:] - last_idx[:-1]])
+    max_episode_steps = episode_steps.max()
+
+    start_idx = backend.concatenate([backend.zeros(1, dtype=int, device=last.device if hasattr(last, 'device') else None), last_idx[:-1] + 1])
+    range_n_episodes = backend.arange(0, n_episodes, dtype=int)
+    range_len = backend.arange(0, last.shape[0], dtype=int)
+    if hasattr(last, 'device'):
+        range_n_episodes = range_n_episodes.to(last.device)
+        range_len = range_len.to(last.device)
+    row_idx = backend.repeat(range_n_episodes, episode_steps)
+    colum_idx = range_len - start_idx[row_idx]
+
+    return row_idx, colum_idx, n_episodes, max_episode_steps
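Note: a quick usage sketch of the new helpers on the torch backend, with a hand-built `last` mask (two episodes of lengths 2 and 3). The shorter episode is zero-padded on the right, and the padding never leaks back because `unsplit_episodes` gathers only the valid `(row, column)` positions:

```python
import torch
from mushroom_rl.utils.episodes import split_episodes, unsplit_episodes

last = torch.tensor([False, True, False, False, True])
reward = torch.tensor([1., 2., 3., 4., 5.])

r_ep = split_episodes(last, reward)
print(r_ep)
# tensor([[1., 2., 0.],
#         [3., 4., 5.]])   <- (n_episodes, max_episode_steps), zero-padded

flat = unsplit_episodes(last, r_ep)
print(flat)  # tensor([1., 2., 3., 4., 5.])
```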
4 changes: 3 additions & 1 deletion tests/algorithms/test_a2c.py
@@ -75,7 +75,7 @@ def test_a2c():
     agent = learn_a2c()

     w = agent.policy.get_weights()
-    w_test = np.array([0.9382279, -1.8847059, -0.13790752, -0.00786441])
+    w_test = np.array([0.9389272, -1.8838323, -0.13710725, -0.00668973])

     assert np.allclose(w, w_test)
@@ -95,3 +95,5 @@ def test_a2c_save(tmpdir):
         print(save_attr, load_attr)

         tu.assert_eq(save_attr, load_attr)
+
+test_a2c()
4 changes: 1 addition & 3 deletions tests/core/test_dataset.py
@@ -128,6 +128,4 @@ def test_dataset_loading(tmpdir):

assert len(dataset.info) == len(new_dataset.info)
for key in dataset.info:
assert np.array_equal(dataset.info[key], new_dataset.info[key])


assert np.array_equal(dataset.info[key], new_dataset.info[key])
92 changes: 92 additions & 0 deletions tests/rl_utils/test_value_functions.py
@@ -0,0 +1,92 @@
+import torch
+
+from mushroom_rl.policy import DeterministicPolicy
+from mushroom_rl.environments.segway import Segway
+from mushroom_rl.core import Core, Agent
+from mushroom_rl.approximators import Regressor
+from mushroom_rl.approximators.parametric import LinearApproximator, TorchApproximator
+from mushroom_rl.rl_utils.value_functions import compute_gae, compute_advantage_montecarlo
+from mushroom_rl.utils.episodes import split_episodes, unsplit_episodes
+
+def test_compute_advantage_montecarlo():
+    def advantage_montecarlo(V, s, ss, r, absorbing, last, gamma):
+        with torch.no_grad():
+            r = r.squeeze()
+            q = torch.zeros(len(r))
+            v = V(s).squeeze()
+
+            for rev_k in range(len(r)):
+                k = len(r) - rev_k - 1
+                if last[k] or rev_k == 0:
+                    q_next = V(ss[k]).squeeze().item()
+                q_next = r[k] + gamma * q_next * (1 - absorbing[k].int())
+                q[k] = q_next
+
+            adv = q - v
+            return q[:, None], adv[:, None]
+
+    torch.manual_seed(42)
+    _value_functions_tester(compute_advantage_montecarlo, advantage_montecarlo, 0.99)
+
+def test_compute_gae():
+    def gae(V, s, ss, r, absorbing, last, gamma, lam):
+        with torch.no_grad():
+            v = V(s)
+            v_next = V(ss)
+            gen_adv = torch.empty_like(v)
+            for rev_k in range(len(v)):
+                k = len(v) - rev_k - 1
+                if last[k] or rev_k == 0:
+                    gen_adv[k] = r[k] - v[k]
+                    if not absorbing[k]:
+                        gen_adv[k] += gamma * v_next[k]
+                else:
+                    gen_adv[k] = r[k] - v[k] + gamma * v_next[k] + gamma * lam * gen_adv[k + 1]
+            return gen_adv + v, gen_adv
+
+    torch.manual_seed(42)
+    _value_functions_tester(compute_gae, gae, 0.99, 0.95)
+
+def _value_functions_tester(test_fun, correct_fun, *args):
+    mdp = Segway()
+    V = Regressor(TorchApproximator, input_shape=mdp.info.observation_space.shape,
+                  output_shape=(1,), network=Net, loss=torch.nn.MSELoss(),
+                  optimizer={'class': torch.optim.Adam, 'params': {'lr': 0.001}})
+
+    state, action, reward, next_state, absorbing, last = _get_episodes(mdp, 10)
+
+    correct_v, correct_adv = correct_fun(V, state, next_state, reward, absorbing, last, *args)
+    v, adv = test_fun(V, state, next_state, reward, absorbing, last, *args)
+
+    assert torch.allclose(v, correct_v)
+    assert torch.allclose(adv, correct_adv)
+
+    V.fit(state, correct_v)
+
+    correct_v, correct_adv = correct_fun(V, state, next_state, reward, absorbing, last, *args)
+    v, adv = test_fun(V, state, next_state, reward, absorbing, last, *args)
+
+    assert torch.allclose(v, correct_v)
+    assert torch.allclose(adv, correct_adv)
+
+def _get_episodes(mdp, n_episodes=100):
+    mu = torch.tensor([6.31154476, 3.32346271, 0.49648221]).unsqueeze(0)
+
+    approximator = Regressor(LinearApproximator,
+                             input_shape=mdp.info.observation_space.shape,
+                             output_shape=mdp.info.action_space.shape,
+                             weights=mu)
+
+    policy = DeterministicPolicy(approximator)
+
+    agent = Agent(mdp.info, policy)
+    core = Core(agent, mdp)
+    dataset = core.evaluate(n_episodes=n_episodes)
+
+    return dataset.parse(to='torch')
+
+class Net(torch.nn.Module):
+    def __init__(self, input_shape, output_shape, **kwargs):
+        super().__init__()
+        self._q = torch.nn.Linear(input_shape[0], output_shape[0])
+
+    def forward(self, x):
+        return self._q(x.float())