5 changes: 4 additions & 1 deletion alf/algorithms/actor_critic_loss.py
@@ -169,7 +169,10 @@ def _calc_returns_and_advantages(self, experience, value):
values=value,
step_types=experience.step_type,
discounts=experience.discount * self._gamma,
td_lambda=self._lambda)
target_value=value,
td_lambda=self._lambda,
importance_ratio=1.0,
use_retrace=False)
advantages = tensor_utils.tensor_extend_zero(advantages)
if self._use_td_lambda_return:
returns = advantages + value
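Note: the actor-critic loss trains on-policy, so the new arguments are passed with neutral values (importance_ratio=1.0, use_retrace=False) and the computation stays the plain TD-lambda/GAE recursion. The snippet below is a minimal self-contained sketch of that recursion; it ignores episode boundaries and is not the ALF implementation (see alf/utils/value_ops.py for the real one).

import torch

# Illustrative sketch only, not part of this PR.
def gae(rewards, values, discounts, td_lambda):
    # Time-major inputs of shape [T, B].
    # delta_t = r_{t+1} + gamma_{t+1} * V_{t+1} - V_t
    delta = rewards[1:] + discounts[1:] * values[1:] - values[:-1]
    weighted_discounts = discounts[1:] * td_lambda
    advs = torch.zeros_like(values)
    # A_t = delta_t + gamma_{t+1} * lambda * A_{t+1}
    for t in reversed(range(rewards.shape[0] - 1)):
        advs[t] = delta[t] + weighted_discounts[t] * advs[t + 1]
    return advs[:-1]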
3 changes: 2 additions & 1 deletion alf/algorithms/ddpg_algorithm.py
@@ -326,7 +326,8 @@ def calc_loss(self, experience, train_info: DdpgInfo):
critic_losses[i] = self._critic_losses[i](
experience=experience,
value=train_info.critic.q_values[:, :, i, ...],
target_value=train_info.critic.target_q_values).loss
target_value=train_info.critic.target_q_values,
train_info=train_info).loss

critic_loss = math_ops.add_n(critic_losses)

3 changes: 3 additions & 0 deletions alf/algorithms/ppo_algorithm.py
@@ -46,7 +46,10 @@ def preprocess_experience(self, exp: Experience):
values=exp.rollout_info.value,
step_types=exp.step_type,
discounts=exp.discount * self._loss._gamma,
target_value=exp.rollout_info.value,
td_lambda=self._loss._lambda,
importance_ratio=1.0,
use_retrace=False,
time_major=False)
advantages = torch.cat([
advantages,
3 changes: 2 additions & 1 deletion alf/algorithms/sac_algorithm.py
@@ -757,7 +757,8 @@ def _calc_critic_loss(self, experience, train_info: SacInfo):
critic_losses.append(
l(experience=experience,
value=critic_info.critics[:, :, i, ...],
target_value=critic_info.target_critic).loss)
target_value=critic_info.target_critic,
train_info=train_info).loss)

critic_loss = math_ops.add_n(critic_losses)

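The DDPG and SAC changes above thread the algorithm-specific train_info into the TD loss so that the loss can compare the current policy's action distribution (from train_info) with the behavior distribution stored in experience.rollout_info. Below is a hedged sketch of the capped importance ratio this enables; it is illustrative only, is not ALF's value_ops.action_importance_ratio, and assumes both distributions are torch.distributions objects. The clipping_mode='capping' argument in the diff presumably corresponds to truncating the ratio at 1, as in Retrace.

import torch

# Illustrative sketch only, not part of this PR.
def capped_importance_ratio(current_dist, rollout_dist, action):
    # pi(a|s) from the current policy, mu(a|s) from the rollout (behavior) policy.
    log_ratio = current_dist.log_prob(action) - rollout_dist.log_prob(action)
    ratio = torch.exp(log_ratio)
    # Truncate at 1; td_lambda is folded into the trace coefficient separately
    # (weighted_discounts = discounts * td_lambda * importance_ratio).
    return torch.clamp(ratio, max=1.0)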
2 changes: 1 addition & 1 deletion alf/algorithms/sarsa_algorithm.py
@@ -435,7 +435,7 @@ def calc_loss(self, experience, info: SarsaInfo):
target_critic = tensor_utils.tensor_prepend_zero(
info.target_critics)
loss_info = self._critic_losses[i](shifted_experience, critic,
target_critic)
target_critic, info)
critic_losses.append(nest_map(lambda l: l[:-1], loss_info.loss))

critic_loss = math_ops.add_n(critic_losses)
59 changes: 56 additions & 3 deletions alf/algorithms/td_loss.py
@@ -31,6 +31,7 @@ def __init__(self,
td_error_loss_fn=element_wise_squared_loss,
td_lambda=0.95,
normalize_target=False,
use_retrace=False,
debug_summaries=False,
name="TDLoss"):
r"""Create a TDLoss object.
@@ -46,7 +47,8 @@ def __init__(self,
:math:`G_t^\lambda = \hat{A}^{GAE}_t + V(s_t)`
where the generalized advantage estimation is defined as:
:math:`\hat{A}^{GAE}_t = \sum_{i=t}^{T-1}(\gamma\lambda)^{i-t}(R_{i+1} + \gamma V(s_{i+1}) - V(s_i))`

``use_retrace=False`` computes the one-step or multi-step (TD(:math:`\lambda`)) loss; ``use_retrace=True`` computes the Retrace loss:
:math:`\mathcal{R} Q(x, a):=Q(x, a)+\mathbb{E}_{\mu}\left[\sum_{t \geq 0} \gamma^{t}\left(\prod_{s=1}^{t} c_{s}\right)\left(r_{t}+\gamma \mathbb{E}_{\pi} Q\left(x_{t+1}, \cdot\right)-Q\left(x_{t}, a_{t}\right)\right)\right]`
References:

Schulman et al. `High-Dimensional Continuous Control Using Generalized Advantage Estimation
@@ -55,6 +57,9 @@ def __init__(self,
Sutton et al. `Reinforcement Learning: An Introduction
<http://incompleteideas.net/book/the-book.html>`_, Chapter 12, 2018

Remi Munos et al. `Safe and efficient off-policy reinforcement learning
<https://arxiv.org/pdf/1606.02647.pdf>`_, 2016

Args:
gamma (float): A discount factor for future rewards.
td_error_loss_fn (Callable): A function for computing the TD errors
@@ -76,8 +81,9 @@ def __init__(self,
self._debug_summaries = debug_summaries
self._normalize_target = normalize_target
self._target_normalizer = None
self._use_retrace = use_retrace

def forward(self, experience, value, target_value):
def forward(self, experience, value, target_value, train_info):
"""Cacluate the loss.

The first dimension of all the tensors is time dimension and the second
@@ -91,6 +97,11 @@ def forward(self, experience, value, target_value):
target_value (torch.Tensor): the time-major tensor for the value at
each time step. This is used to calculate return. ``target_value``
can be same as ``value``.
train_info (namedtuple): algorithm-specific training info containing the
action distribution, actor, critic and other fields; different algorithms
may carry different contents. For the Retrace method this is ``SacInfo``,
``SarsaInfo`` or ``DdpgInfo`` for the SAC, SARSA or DDPG algorithm,
respectively. It is used to compute ``importance_ratio`` and
``importance_ratio_clipped``.
Returns:
LossInfo: with the ``extra`` field same as ``loss``.
"""
@@ -106,15 +117,57 @@ def forward(self, experience, value, target_value):
values=target_value,
step_types=experience.step_type,
discounts=experience.discount * self._gamma)
else:
elif not self._use_retrace:
scope = alf.summary.scope(self.__class__.__name__)
importance_ratio, importance_ratio_clipped = value_ops. \
action_importance_ratio(
action_distribution=train_info.action_distribution,
collect_action_distribution=experience.rollout_info.
action_distribution,
action=experience.action,
clipping_mode='capping',
importance_ratio_clipping=0.0,
log_prob_clipping=0.0,
scope=scope,
check_numerics=False,
debug_summaries=self._debug_summaries)
advantages = value_ops.generalized_advantage_estimation(
rewards=experience.reward,
values=target_value,
step_types=experience.step_type,
target_value=target_value,
importance_ratio=importance_ratio,
use_retrace=False,
discounts=experience.discount * self._gamma,
td_lambda=self._lambda)
returns = advantages + target_value[:-1]
else:
scope = alf.summary.scope(self.__class__.__name__)
importance_ratio, importance_ratio_clipped = value_ops. \
action_importance_ratio(
action_distribution=train_info.action_distribution,
[Review comment (Contributor)] format, line is too long
[Review comment (Contributor)] Not fixed?
collect_action_distribution=experience.rollout_info.
action_distribution,
action=experience.action,
clipping_mode='capping',
importance_ratio_clipping=0.0,
log_prob_clipping=0.0,
scope=scope,
check_numerics=False,
debug_summaries=self._debug_summaries)
advantages = value_ops.generalized_advantage_estimation(
importance_ratio=importance_ratio_clipped,
rewards=experience.reward,
values=value,
target_value=target_value,
step_types=experience.step_type,
discounts=experience.discount * self._gamma,
use_retrace=True,
time_major=True,
td_lambda=self._lambda)

returns = advantages + value[:-1]
returns = returns.detach()
value = value[:-1]
if self._normalize_target:
if self._target_normalizer is None:
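Written out, the use_retrace=True branch above implements the following recursion. This is a restatement of the code using the Retrace paper's truncated trace coefficient; the exact time indexing of the ratio follows the code and the unit test, and Q stands for the critic values passed as value/target_value.

% Sketch of the implemented recursion, not copied verbatim from the paper.
\begin{aligned}
\delta_t &= r_{t+1} + \gamma_{t+1}\, Q_{\mathrm{target}}(x_{t+1}, a_{t+1}) - Q(x_t, a_t)\\
c &= \min\!\left(1, \frac{\pi(a \mid x)}{\mu(a \mid x)}\right)
\qquad \text{(the capped importance ratio)}\\
\hat{A}_t &= \delta_t + \gamma_{t+1}\, \lambda\, c\, \hat{A}_{t+1},
\qquad \text{returns}_t = \hat{A}_t + Q(x_t, a_t).
\end{aligned}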
4 changes: 2 additions & 2 deletions alf/examples/carla.gin
@@ -4,8 +4,8 @@ import alf
import alf.algorithms.merlin_algorithm
import alf.environments.suite_carla

CameraSensor.image_size_x=200
CameraSensor.image_size_y=100
CameraSensor.image_size_x=128
CameraSensor.image_size_y=64
CameraSensor.fov=135

create_environment.env_name='Town01'
69 changes: 66 additions & 3 deletions alf/utils/value_ops.py
@@ -195,6 +195,9 @@ def generalized_advantage_estimation(rewards,
values,
step_types,
discounts,
target_value,
importance_ratio,
use_retrace=False,
td_lambda=1.0,
time_major=True):
"""Computes generalized advantage estimation (GAE) for the first T-1 steps.
@@ -231,6 +234,8 @@
rewards = rewards.transpose(0, 1)
values = values.transpose(0, 1)
step_types = step_types.transpose(0, 1)
importance_ratio = importance_ratio.transpose(0, 1)
target_value = target_value.transpose(0, 1)

assert values.shape[0] >= 2, ("The sequence length needs to be "
"at least 2. Got {s}".format(
@@ -240,18 +245,76 @@
is_lasts = common.expand_dims_as(is_lasts, values)
discounts = common.expand_dims_as(discounts, values)

weighted_discounts = discounts[1:] * td_lambda
advs = torch.zeros_like(values)
if not use_retrace:
weighted_discounts = discounts[1:] * td_lambda
delta = rewards[1:] + discounts[1:] * values[1:] - values[:-1]
with torch.no_grad():
for t in reversed(range(rewards.shape[0] - 1)):
advs[t] = (1 - is_lasts[t]) * \
(delta[t] + weighted_discounts[t] * advs[t + 1])
advs = advs[:-1]
else:
delta = (rewards[1:] + discounts[1:] * target_value[1:] - values[:-1])
weighted_discounts = discounts[1:] * td_lambda * importance_ratio
with torch.no_grad():
for t in reversed(range(rewards.shape[0] - 1)):
advs[t] = (1 - is_lasts[t]) * \
(delta[t] + weighted_discounts[t] * advs[t + 1])
advs = advs[:-1]

if not time_major:
advs = advs.transpose(0, 1)

return advs.detach()


'''
# Added for the retrace method.
[Review comment (Contributor)] This function can be merged with the generalized_advantage_estimation function.
def generalized_advantage_estimation_retrace(importance_ratio, discounts,
rewards, td_lambda, time_major,
values, target_value, step_types):
"""
Compute generalized advantage estimation for the Retrace method. The main
change relative to ``generalized_advantage_estimation`` is the use of the
truncated importance ratio and the target values in the TD errors.

Args:
importance_ratio (Tensor): shape is [T-1, B] (or [T-1]), with values in
[0, 1], representing the truncated importance ratios.
target_value (Tensor): shape is [T, B] (or [T]) representing the target
values used to compute the TD errors.
rewards (Tensor): shape is [T, B] (or [T]) representing rewards.
values (Tensor): shape is [T,B] (or [T]) representing values.
step_types (Tensor): shape is [T,B] (or [T]) representing step types.
discounts (Tensor): shape is [T, B] (or [T]) representing discounts.
td_lambda (float): A scalar between [0, 1]. It's used for variance
reduction in temporal difference.
time_major (bool): Whether input tensors are time major.
False means input tensors have shape [B, T].
Returns:
A tensor with shape [T-1, B] representing advantages. Shape is [B, T-1]
when time_major is false.
"""
if not time_major:
discounts = discounts.transpose(0, 1)
rewards = rewards.transpose(0, 1)
values = values.transpose(0, 1)
step_types = step_types.transpose(0, 1)
importance_ratio = importance_ratio.transpose(0, 1)
target_value = target_value.transpose(0, 1)

assert values.shape[0] >= 2, ("The sequence length needs to be "
"at least 2. Got {s}".format(
s=values.shape[0]))
advs = torch.zeros_like(values)
is_lasts = (step_types == StepType.LAST).to(dtype=torch.float32)
delta = rewards[1:] + discounts[1:] * target_value[1:] - values[:-1]

weighted_discounts = discounts[1:] * td_lambda * importance_ratio
with torch.no_grad():
for t in reversed(range(rewards.shape[0] - 1)):
advs[t] = (1 - is_lasts[t]) * \
(delta[t] + weighted_discounts[t] * advs[t + 1])
advs = advs[:-1]

if not time_major:
advs = advs.transpose(0, 1)

return advs.detach()
'''
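For reference, a self-contained version of the use_retrace=True path is sketched below; it mirrors the modified generalized_advantage_estimation (including the is_lasts masking) but is plain torch and is not the ALF code itself. Returns for critic training would then be advantages + values[:-1], detached, as in td_loss.py.

import torch

# Illustrative, self-contained sketch of the use_retrace=True branch.
def retrace_advantages(rewards, values, target_value, importance_ratio,
                       discounts, is_lasts, td_lambda):
    # Time-major inputs: rewards/values/target_value/discounts/is_lasts are
    # [T, B]; importance_ratio is [T-1, B] and already capped at 1.
    delta = rewards[1:] + discounts[1:] * target_value[1:] - values[:-1]
    weighted_discounts = discounts[1:] * td_lambda * importance_ratio
    advs = torch.zeros_like(values)
    with torch.no_grad():
        for t in reversed(range(rewards.shape[0] - 1)):
            advs[t] = (1 - is_lasts[t]) * (
                delta[t] + weighted_discounts[t] * advs[t + 1])
    return advs[:-1]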
53 changes: 50 additions & 3 deletions alf/utils/value_ops_test.py
@@ -96,14 +96,17 @@ class GeneralizedAdvantageTest(unittest.TestCase):
"""Tests for alf.utils.value_ops.generalized_advantage_estimation
"""

def _check(self, rewards, values, step_types, discounts, td_lambda,
expected):
def _check(self, rewards, values, step_types, discounts, target_value,
importance_ratio, use_retrace, td_lambda, expected):
np.testing.assert_array_almost_equal(
value_ops.generalized_advantage_estimation(
rewards=rewards,
values=values,
step_types=step_types,
discounts=discounts,
target_value=target_value,
importance_ratio=importance_ratio,
use_retrace=use_retrace,
td_lambda=td_lambda,
time_major=False), expected)

@@ -113,6 +116,9 @@ def _check(self, rewards, values, step_types, discounts, td_lambda,
values=torch.stack([values, 2 * values], dim=2),
step_types=step_types,
discounts=discounts,
importance_ratio=importance_ratio,
target_value=target_value,
use_retrace=use_retrace,
td_lambda=td_lambda,
time_major=False),
torch.stack([expected, 2 * expected], dim=2),
@@ -124,7 +130,9 @@ def test_generalized_advantage_estimation(self):
rewards = torch.tensor([[3.] * 5], dtype=torch.float32)
discounts = torch.tensor([[0.9] * 5], dtype=torch.float32)
td_lambda = 0.6 / 0.9

target_value = torch.tensor([[3.] * 4], dtype=torch.float32)
importance_ratio = torch.tensor([[0.8] * 3], dtype=torch.float32)
use_retrace = False
d = 2 * 0.9 + 1
expected = torch.tensor([[((d * 0.6 + d) * 0.6 + d) * 0.6 + d,
(d * 0.6 + d) * 0.6 + d, d * 0.6 + d, d]],
Expand All @@ -134,7 +142,10 @@ def test_generalized_advantage_estimation(self):
values=values,
step_types=step_types,
discounts=discounts,
importance_ratio=importance_ratio,
target_value=target_value,
td_lambda=td_lambda,
use_retrace=use_retrace,
expected=expected)

# two episodes, ended by exceeding the time limit (discount=1)
@@ -150,7 +161,10 @@ def test_generalized_advantage_estimation(self):
values=values,
step_types=step_types,
discounts=discounts,
importance_ratio=importance_ratio,
target_value=target_value,
td_lambda=td_lambda,
use_retrace=use_retrace,
expected=expected)

# two episodes, and end normally (discount=0)
@@ -169,8 +183,41 @@ def test_generalized_advantage_estimation(self):
step_types=step_types,
discounts=discounts,
td_lambda=td_lambda,
importance_ratio=importance_ratio,
target_value=target_value,
use_retrace=use_retrace,
expected=expected)


'''
class GeneralizedAdvantage_retrace_Test(unittest.TestCase):
"""Tests for alf.utils.value_ops
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comments not correct

"""GeneralizedAdvantageTest.test_generalized_advantage_estimation()

def test_generalized_advantage_estimation_retrace(self):
values = torch.tensor([[2.] * 4], dtype=torch.float32)
step_types = torch.tensor([[StepType.MID] * 4], dtype=torch.int64)
rewards = torch.tensor([[3.] * 4], dtype=torch.float32)
discounts = torch.tensor([[0.9] * 4], dtype=torch.float32)
td_lambda = 0.6 / 0.9
target_value = torch.tensor([[3.] * 4], dtype=torch.float32)
importance_ratio = torch.tensor([[0.8] * 3], dtype=torch.float32)
d = 3 * 0.9 + 3 - 2
expected = torch.tensor(
[[(d * 0.6 * 0.8) * 0.6 * 0.8 + 0.6 * 0.8 * d + d,
d * 0.6 * 0.8 + d, d]],
dtype=torch.float32)
np.testing.assert_array_almost_equal(
value_ops.generalized_advantage_estimation_retrace(
rewards=rewards,
values=values,
target_value=target_value,
step_types=step_types,
discounts=discounts,
td_lambda=td_lambda,
importance_ratio=importance_ratio,
time_major=False), expected)
'''

if __name__ == '__main__':
unittest.main()
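The expected values in the commented-out retrace test can be verified by hand: with r = 3, gamma = 0.9, V = 2, V_target = 3, td_lambda = 0.6/0.9 and a capped ratio of 0.8, every TD error is delta = 3 + 0.9*3 - 2 = 3.7 and the trace coefficient is 0.9 * (0.6/0.9) * 0.8 = 0.48, so the advantages are delta, delta*(1 + 0.48) and delta*(1 + 0.48 + 0.48^2) from the last step backwards. A standalone numeric check (a sketch, independent of ALF):

import torch

# Standalone check of the expected values in the commented-out retrace test.
d = 3 + 0.9 * 3 - 2            # delta = r + gamma * V_target - V = 3.7
c = 0.9 * (0.6 / 0.9) * 0.8    # gamma * td_lambda * capped ratio = 0.48
expected = torch.tensor([d * (1 + c + c**2), d * (1 + c), d])
advs = torch.zeros(4)
delta = torch.full((3,), d)
for t in reversed(range(3)):
    advs[t] = delta[t] + c * advs[t + 1]
print(torch.allclose(advs[:3], expected))  # True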
2 changes: 1 addition & 1 deletion setup.py
@@ -23,7 +23,7 @@
install_requires=[
'atari_py == 0.1.7',
'cpplint',
'clang-format == 9.0',
#'clang-format == 9.0',
'fasteners',
'gin-config@git+https://github.com/HorizonRobotics/gin-config.git',
'gym == 0.12.5',