draft_retrace #695
base: pytorch
Changes to `TDLoss`:
```python
@@ -31,6 +31,7 @@ def __init__(self,
                 td_error_loss_fn=element_wise_squared_loss,
                 td_lambda=0.95,
                 normalize_target=False,
                 use_retrace=False,
                 debug_summaries=False,
                 name="TDLoss"):
        r"""Create a TDLoss object.
@@ -46,7 +47,8 @@ def __init__(self,
        :math:`G_t^\lambda = \hat{A}^{GAE}_t + V(s_t)`
        where the generalized advantage estimation is defined as:
        :math:`\hat{A}^{GAE}_t = \sum_{i=t}^{T-1}(\gamma\lambda)^{i-t}(R_{i+1} + \gamma V(s_{i+1}) - V(s_i))`

        ``use_retrace = 0`` means the one-step or multi-step loss; ``use_retrace = 1`` means the retrace loss:
        :math:`\mathcal{R} Q(x, a) := Q(x, a) + \mathbb{E}_{\mu}\left[\sum_{t \geq 0} \gamma^{t}\left(\prod_{s=1}^{t} c_{s}\right)\left(r_{t} + \gamma \mathbb{E}_{\pi} Q\left(x_{t+1}, \cdot\right) - Q\left(x_{t}, a_{t}\right)\right)\right]`

        References:

        Schulman et al. `High-Dimensional Continuous Control Using Generalized Advantage Estimation
@@ -55,6 +57,9 @@ def __init__(self,
        Sutton et al. `Reinforcement Learning: An Introduction
        <http://incompleteideas.net/book/the-book.html>`_, Chapter 12, 2018

        Remi Munos et al. `Safe and efficient off-policy reinforcement learning
        <https://arxiv.org/pdf/1606.02647.pdf>`_

        Args:
            gamma (float): A discount factor for future rewards.
            td_errors_loss_fn (Callable): A function for computing the TD errors
```
```python
@@ -76,8 +81,9 @@ def __init__(self,
        self._debug_summaries = debug_summaries
        self._normalize_target = normalize_target
        self._target_normalizer = None
        self._use_retrace = use_retrace

    def forward(self, experience, value, target_value):              # removed by this PR
    def forward(self, experience, value, target_value, train_info):  # added by this PR
        """Calculate the loss.

        The first dimension of all the tensors is time dimension and the second
@@ -91,6 +97,8 @@ def forward(self, experience, value, target_value):
            target_value (torch.Tensor): the time-major tensor for the value at
                each time step. This is used to calculate return. ``target_value``
                can be same as ``value``.
            train_info (sarsa info, sac info): information used to calculate
                ``importance_ratio`` or ``importance_ratio_clipped``.
        Returns:
            LossInfo: with the ``extra`` field same as ``loss``.
        """
@@ -106,15 +114,38 @@ def forward(self, experience, value, target_value):
                values=target_value,
                step_types=experience.step_type,
                discounts=experience.discount * self._gamma)
        else:                                # removed by this PR
        elif self._use_retrace == False:     # added by this PR
            advantages = value_ops.generalized_advantage_estimation(
                rewards=experience.reward,
                values=target_value,
                step_types=experience.step_type,
                discounts=experience.discount * self._gamma,
                td_lambda=self._lambda)
            returns = advantages + target_value[:-1]
        else:
            scope = alf.summary.scope(self.__class__.__name__)
            importance_ratio, importance_ratio_clipped = value_ops.action_importance_ratio(
                action_distribution=train_info.action_distribution,
```
> **Contributor:** format, line is too long

> **Contributor:** Not fixed?
```python
                collect_action_distribution=experience.rollout_info.
                action_distribution,
                action=experience.action,
                clipping_mode='capping',
                importance_ratio_clipping=0.0,
                log_prob_clipping=0.0,
                scope=scope,
                check_numerics=False,
                debug_summaries=True)
            advantages = value_ops.generalized_advantage_estimation_retrace(
                importance_ratio=importance_ratio_clipped,
                rewards=experience.reward,
                values=value,
                target_value=target_value,
                step_types=experience.step_type,
                discounts=experience.discount * self._gamma,
                time_major=True,
                td_lambda=self._lambda)
            returns = advantages + value[:-1]
        returns = returns.detach()
        value = value[:-1]
        if self._normalize_target:
            if self._target_normalizer is None:
```
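For illustration, a minimal sketch of enabling the new flag; the import path is an assumption (not taken from this PR), and only the keyword arguments mirror the `__init__` signature shown in the diff above:

```python
# Hypothetical usage sketch: the import path below is an assumption, not part
# of this PR; the keyword arguments follow the __init__ signature in the diff.
from alf.algorithms.td_loss import TDLoss

td_loss = TDLoss(
    gamma=0.99,            # discount factor for future rewards
    td_lambda=0.95,
    normalize_target=False,
    use_retrace=True,      # switch forward() from the GAE branch to retrace
    debug_summaries=False,
    name="TDLoss")

# With use_retrace=True, forward() also needs train_info so the importance
# ratio between the current policy and the rollout policy can be computed:
#   loss_info = td_loss.forward(experience, value, target_value, train_info)
```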
Changes to `alf.utils.value_ops`:
```python
@@ -255,3 +255,52 @@ def generalized_advantage_estimation(rewards,
        advs = advs.transpose(0, 1)

    return advs.detach()


# Added for the retrace method.
def generalized_advantage_estimation_retrace(importance_ratio, discounts,
                                             rewards, td_lambda, time_major,
                                             values, target_value, step_types):
    """Compute the generalized advantage estimation for the retrace method.
    The main change is adding the importance ratio.

    Args:
        importance_ratio (Tensor): shape is [T-1, B] (or [T-1]), with values
            in [0, 1], representing the importance ratio.
        rewards (Tensor): shape is [T, B] (or [T]) representing rewards.
        values (Tensor): shape is [T, B] (or [T]) representing values.
        target_value (Tensor): shape is [T, B] (or [T]) representing the values
            used to bootstrap the TD residual.
        step_types (Tensor): shape is [T, B] (or [T]) representing step types.
        discounts (Tensor): shape is [T, B] (or [T]) representing discounts.
        td_lambda (float): A scalar between [0, 1]. It's used for variance
            reduction in temporal difference.
        time_major (bool): Whether input tensors are time major.
            False means input tensors have shape [B, T].
    Returns:
        A tensor with shape [T-1, B] representing advantages. Shape is [B, T-1]
        when time_major is False.
    """
    if not time_major:
        discounts = discounts.transpose(0, 1)
        rewards = rewards.transpose(0, 1)
        values = values.transpose(0, 1)
        step_types = step_types.transpose(0, 1)
        importance_ratio = importance_ratio.transpose(0, 1)
        target_value = target_value.transpose(0, 1)

    assert values.shape[0] >= 2, ("The sequence length needs to be "
                                  "at least 2. Got {s}".format(
                                      s=values.shape[0]))
    advs = torch.zeros_like(values)
    is_lasts = (step_types == StepType.LAST).to(dtype=torch.float32)
    delta = (rewards[1:] + discounts[1:] * target_value[1:] - values[:-1])

    weighted_discounts = discounts[1:] * td_lambda * importance_ratio
    with torch.no_grad():
        for t in reversed(range(rewards.shape[0] - 1)):
            advs[t] = (1 - is_lasts[t]) * \
                (delta[t] + weighted_discounts[t] * advs[t + 1])
    advs = advs[:-1]
    if not time_major:
        advs = advs.transpose(0, 1)

    return advs.detach()
```

> **Contributor:** This function can be merged with …
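For reference, the backward loop above computes the recursion below, with `c_t` the (clipped) importance ratio, `V` the values, and `V'` the target values; when every `c_t = 1` it reduces to the ordinary GAE recursion bootstrapped with the target values:

```latex
\delta_t = r_{t+1} + \gamma_{t+1} V'(s_{t+1}) - V(s_t), \qquad
A_t = \big(1 - \mathbb{1}[s_t \text{ is LAST}]\big)
      \left(\delta_t + \gamma_{t+1}\, \lambda\, c_t\, A_{t+1}\right),
\qquad A_{T-1} = 0 .
```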
Changes to the `alf.utils.value_ops` tests:

```python
@@ -172,5 +172,34 @@ def test_generalized_advantage_estimation(self):
            expected=expected)


class GeneralizedAdvantage_retrace_Test(unittest.TestCase):
    """Tests for alf.utils.value_ops
    """
```

> **Contributor:** comments not correct

```python
    def test_generalized_advantage_estimation_retrace(self):
        values = torch.tensor([[2.] * 4], dtype=torch.float32)
        step_types = torch.tensor([[StepType.MID] * 4], dtype=torch.int64)
        rewards = torch.tensor([[3.] * 4], dtype=torch.float32)
        discounts = torch.tensor([[0.9] * 4], dtype=torch.float32)
        td_lambda = 0.6 / 0.9
        target_value = torch.tensor([[3.] * 4], dtype=torch.float32)
        importance_ratio = torch.tensor([[0.8] * 3], dtype=torch.float32)
        d = 3 * 0.9 + 3 - 2
        expected = torch.tensor(
            [[(d * 0.6 * 0.8) * 0.6 * 0.8 + 0.6 * 0.8 * d + d,
              d * 0.6 * 0.8 + d, d]],
            dtype=torch.float32)
        np.testing.assert_array_almost_equal(
            value_ops.generalized_advantage_estimation_retrace(
                rewards=rewards,
                values=values,
                target_value=target_value,
                step_types=step_types,
                discounts=discounts,
                td_lambda=td_lambda,
                importance_ratio=importance_ratio,
                time_major=False), expected)


if __name__ == '__main__':
    unittest.main()
```
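The expected values in the test above can be reproduced by hand with a short standalone check (plain Python, no alf or torch needed):

```python
# Reproduce the expected advantages of the retrace test above.
# All inputs are constant over time, so the TD residual is the same at every
# step, and the per-step weighted discount is discount * td_lambda * ratio.
d = 3 + 0.9 * 3 - 2          # reward + discount * target_value - value = 3.7
w = 0.9 * (0.6 / 0.9) * 0.8  # discount * td_lambda * importance_ratio = 0.48

# Backward recursion A_t = d + w * A_{t+1}, starting from A_3 = 0:
a2 = d               # 3.7
a1 = d + w * a2      # 5.476
a0 = d + w * a1      # 6.32848

print([a0, a1, a2])  # matches `expected`: d*w*w + d*w + d, d*w + d, d
```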
> Can change `use_retrace` to use a bool value.

> Need to update the comment.