draft_retrace #695
```diff
@@ -70,13 +70,16 @@ def __init__(self,
         self._lambda = td_lambda
         self._debug_summaries = debug_summaries
 
-    def forward(self, experience, value, target_value):
+    def forward(self, experience, value, target_value, train_info=None):
         """Calculate the loss.
 
         The first dimension of all the tensors is the time dimension and the
         second dimension is the batch dimension.
 
         Args:
+            train_info (sac_info or sarsa_info): used to calculate the
+                importance ratio from ``info.action_distribution``. If not
+                provided and ``td_lambda`` is neither 0 nor 1, the multi-step
+                method is used instead of retrace.
             experience (Experience): experience collected from ``unroll()`` or
                 a replay buffer. All tensors are time-major.
             value (torch.Tensor): the time-major tensor for the value at each time
```
```diff
@@ -99,15 +102,37 @@ def forward(self, experience, value, target_value):
                 values=target_value,
                 step_types=experience.step_type,
                 discounts=experience.discount * self._gamma)
-        else:
+        elif train_info is None:
             advantages = value_ops.generalized_advantage_estimation(
                 rewards=experience.reward,
                 values=target_value,
                 step_types=experience.step_type,
                 discounts=experience.discount * self._gamma,
                 td_lambda=self._lambda)
             returns = advantages + target_value[:-1]
+        else:
+            scope = alf.summary.scope(self.__class__.__name__)
+            importance_ratio, importance_ratio_clipped = value_ops.action_importance_ratio(
+                action_distribution=train_info.action_distribution,
+                collect_action_distribution=experience.rollout_info.action_distribution,
+                action=experience.action,
+                clipping_mode='capping',
+                importance_ratio_clipping=0.0,
+                log_prob_clipping=0.0,
+                scope=scope,
+                check_numerics=False,
+                debug_summaries=True)
+
+            advantages = value_ops.generalized_advantage_estimation_retrace(
+                importance_ratio=importance_ratio_clipped,
+                rewards=experience.reward,
+                values=value,
+                target_value=target_value,
+                step_types=experience.step_type,
+                discounts=experience.discount * self._gamma,
+                time_major=True,
+                td_lambda=self._lambda)
+            returns = advantages + value[:-1]
+            returns = returns.detach()
+            value = value[:-1]
 
         if self._debug_summaries and alf.summary.should_record_summaries():
```

**Contributor:** format, line is too long

**Contributor:** Not fixed?
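For context on what `action_importance_ratio` feeds into the retrace estimator: the ratio compares the current policy's probability of the replayed action (from `train_info.action_distribution`) with the behavior policy's probability recorded at rollout time (`experience.rollout_info.action_distribution`). Below is a standalone sketch of that quantity, not ALF's implementation; treating `clipping_mode='capping'` as the Retrace-style cap c = min(rho, 1) is an assumption here.

```python
import torch
import torch.distributions as td

# Standalone illustration only; this is not value_ops.action_importance_ratio.
# `current` stands in for train_info.action_distribution and `behavior`
# for experience.rollout_info.action_distribution.
current = td.Normal(torch.tensor([0.1]), torch.tensor([1.0]))
behavior = td.Normal(torch.tensor([0.0]), torch.tensor([1.0]))
action = torch.tensor([0.3])

# rho = pi(a|s) / b(a|s), computed in log space for numerical stability.
log_ratio = current.log_prob(action) - behavior.log_prob(action)
importance_ratio = log_ratio.exp()

# Assumed Retrace-style capping of the trace coefficient: c = min(rho, 1).
importance_ratio_clipped = torch.clamp(importance_ratio, max=1.0)
print(importance_ratio, importance_ratio_clipped)
```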
```diff
@@ -255,3 +255,36 @@ def generalized_advantage_estimation(rewards,
         advs = advs.transpose(0, 1)
 
     return advs.detach()
+
+
+# Added for the retrace method.
+def generalized_advantage_estimation_retrace(importance_ratio, discounts,
+                                             rewards, td_lambda, time_major,
+                                             values, target_value, step_types):
+    # Compare the importance ratio with 1 (Retrace caps the trace coefficient);
+    # currently left to the caller:
+    # importance_ratio = torch.min(importance_ratio, torch.tensor(1.))
+    # TODO: is the time_major handling needed here, or are inputs always
+    # time-major?
+    if not time_major:
+        discounts = discounts.transpose(0, 1)
+        rewards = rewards.transpose(0, 1)
+        values = values.transpose(0, 1)
+        step_types = step_types.transpose(0, 1)
+        importance_ratio = importance_ratio.transpose(0, 1)
+        target_value = target_value.transpose(0, 1)
+
+    assert values.shape[0] >= 2, ("The sequence length needs to be "
+                                  "at least 2. Got {s}".format(
+                                      s=values.shape[0]))
+
+    # Backward recursion for the retrace-weighted advantages.
+    advs = torch.zeros_like(values)
+    is_lasts = (step_types == StepType.LAST).to(dtype=torch.float32)
+    delta = rewards[1:] + discounts[1:] * target_value[1:] - values[:-1]
+
+    weighted_discounts = discounts[1:] * td_lambda * importance_ratio
+    with torch.no_grad():
+        for t in reversed(range(rewards.shape[0] - 1)):
+            advs[t] = (1 - is_lasts[t]) * \
+                (delta[t] + weighted_discounts[t] * advs[t + 1])
+    advs = advs[:-1]
+    if not time_major:
+        advs = advs.transpose(0, 1)
+
+    return advs.detach()
```
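Reading the recursion off the function above (with the exact index alignment of `importance_ratio` left to the code), each step computes a retrace-weighted advantage:

$$\delta_t = r_{t+1} + \gamma_{t+1}\,\bar V(s_{t+1}) - V(s_t), \qquad A_t = (1 - \mathrm{is\_last}_t)\,\bigl(\delta_t + \gamma_{t+1}\,\lambda\,c\,A_{t+1}\bigr),$$

where $\bar V$ is the target value, $c$ is the (clipped) importance ratio supplied by the caller, and the recursion is seeded with $A = 0$ past the end of the sequence. With $c \equiv 1$ and `values == target_value` this reduces to the existing `generalized_advantage_estimation`, matching `returns = advantages + value[:-1]` in `forward()` above.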
Instead of checking whether `train_info` is None, you should add an argument in `__init__` to indicate whether to use retrace. You should also change `SarsaAlgorithm` and `SacAlgorithm` to pass in `train_info`.
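A minimal sketch of that suggestion, under the assumption that the loss class looks roughly like the `forward()` shown in the diff above; the names `TDLoss`, `use_retrace`, and the call site mentioned in the comments are placeholders, not actual ALF identifiers:

```python
class TDLoss(object):  # placeholder name for the loss class changed in this PR
    def __init__(self, gamma, td_lambda, use_retrace=False,
                 debug_summaries=False):
        self._gamma = gamma
        self._lambda = td_lambda
        # Explicit switch instead of inferring retrace from `train_info is None`.
        self._use_retrace = use_retrace
        self._debug_summaries = debug_summaries

    def forward(self, experience, value, target_value, train_info=None):
        if self._use_retrace:
            assert train_info is not None, (
                "retrace needs train_info.action_distribution to compute "
                "the importance ratio")
            # ... importance-ratio + generalized_advantage_estimation_retrace path
        else:
            pass  # ... existing one-step / multi-step paths, unchanged
```

`SacAlgorithm` and `SarsaAlgorithm` would then pass their own `info` into the loss call, e.g. something like `critic_loss(experience, value, target_value, train_info=info)` inside their loss computation (the exact call site is an assumption).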