31 changes: 28 additions & 3 deletions alf/algorithms/td_loss.py
@@ -70,13 +70,16 @@ def __init__(self,
self._lambda = td_lambda
self._debug_summaries = debug_summaries

def forward(self, experience, value, target_value):
def forward(self, experience, value, target_value, train_info=None):
"""Cacluate the loss.

The first dimension of all the tensors is time dimension and the second
dimesion is the batch dimension.

Args:
train_info (sac_info or sarsa_info): in order to calculate the importance ratio
from info.action_distribution. If no input of train info and lambda is not
0 and 1,it will use multistep method instead of retrace
experience (Experience): experience collected from ``unroll()`` or
a replay buffer. All tensors are time-major.
value (torch.Tensor): the time-major tensor for the value at each time
@@ -99,15 +102,37 @@ def forward(self, experience, value, target_value):
values=target_value,
step_types=experience.step_type,
discounts=experience.discount * self._gamma)
else:
elif train_info is None:
Contributor: Instead of checking whether train_info is None, you should add an argument in __init__ to indicate whether to use retrace. You should also change SarsaAlgorithm and SacAlgorithm to pass in train_info.
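
A minimal sketch of that suggestion (the ``use_retrace`` flag and the abbreviated signatures are hypothetical, not part of this PR):

```python
# Hypothetical sketch of the reviewer's suggestion: select retrace via an
# __init__ flag instead of by checking train_info at call time.
class TDLoss(object):
    def __init__(self, gamma=0.99, td_lambda=0.95, use_retrace=False,
                 debug_summaries=False):
        self._gamma = gamma
        self._lambda = td_lambda
        self._use_retrace = use_retrace
        self._debug_summaries = debug_summaries

    def forward(self, experience, value, target_value, train_info=None):
        if self._use_retrace:
            # SacAlgorithm / SarsaAlgorithm would need to pass train_info.
            assert train_info is not None, (
                "use_retrace=True requires train_info")
            # ... compute retrace returns ...
        # ... otherwise fall back to the one-step / multi-step paths ...
```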

advantages = value_ops.generalized_advantage_estimation(
rewards=experience.reward,
values=target_value,
step_types=experience.step_type,
discounts=experience.discount * self._gamma,
td_lambda=self._lambda)
returns = advantages + target_value[:-1]

else:
scope = alf.summary.scope(self.__class__.__name__)
importance_ratio, importance_ratio_clipped = value_ops.action_importance_ratio(
Contributor: add space after ,

action_distribution=train_info.action_distribution,
Contributor: format, line is too long.

Contributor: Not fixed?

collect_action_distribution=experience.rollout_info.action_distribution,
action=experience.action,
clipping_mode='capping',
importance_ratio_clipping=0.0,
log_prob_clipping=0.0,
scope=scope,
check_numerics=False,
debug_summaries=True)
Contributor: debug_summaries= debug_summaries (i.e., pass the flag through instead of hard-coding True).
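
For intuition, Retrace's trace coefficient uses the capped ratio ``min(1, pi/mu)``; a stand-alone sketch of computing it from two distributions (illustration only, not ALF's ``action_importance_ratio``):

```python
import torch
import torch.distributions as td

def capped_importance_ratio(curr_dist: td.Distribution,
                            behavior_dist: td.Distribution,
                            action: torch.Tensor) -> torch.Tensor:
    # rho = pi(a|s) / mu(a|s), computed in log space for numerical stability,
    # then capped at 1 as in Retrace.
    log_rho = curr_dist.log_prob(action) - behavior_dist.log_prob(action)
    return log_rho.exp().clamp(max=1.0)
```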

advantages = value_ops.generalized_advantage_estimation_retrace(
importance_ratio=importance_ratio_clipped,
rewards=experience.reward,
values=value,
target_value=target_value,
step_types=experience.step_type,
discounts=experience.discount * self._gamma,
time_major=True,
td_lambda=self._lambda)
returns = advantages + value[:-1]
# The returns are regression targets, so gradients must not flow through them.
returns = returns.detach()
value = value[:-1]

if self._debug_summaries and alf.summary.should_record_summaries():
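For reference, both branches of ``forward`` run the same backward recursion and differ only in the TD error and the trace coefficient ``c[t]``: the multi-step branch uses ``c[t] = 1`` with ``delta[t] = r[t+1] + gamma * V_target[t+1] - V_target[t]``, while the retrace branch uses the capped ratio ``c[t] = min(1, rho[t])`` with ``delta[t] = r[t+1] + gamma * V_target[t+1] - V[t]``. A reference-only sketch of the shared recursion (names are illustrative; episode-boundary masking omitted for brevity):

```python
import torch

def lambda_advantages(delta: torch.Tensor,
                      weighted_discounts: torch.Tensor) -> torch.Tensor:
    # delta, weighted_discounts: [T-1, B].  weighted_discounts[t] is
    # gamma * lambda * c[t] (c == 1 for multi-step, capped rho for retrace).
    advs = torch.zeros_like(delta)
    adv = torch.zeros_like(delta[0])
    for t in reversed(range(delta.shape[0])):
        adv = delta[t] + weighted_discounts[t] * adv
        advs[t] = adv
    return advs
```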
33 changes: 33 additions & 0 deletions alf/utils/value_ops.py
@@ -255,3 +255,36 @@ def generalized_advantage_estimation(rewards,
advs = advs.transpose(0, 1)

return advs.detach()
def generalized_advantage_estimation_retrace(importance_ratio, discounts,
rewards, td_lambda, time_major, values, target_value, step_types):
"""Retrace-style advantage estimation.

Same lambda-return recursion as ``generalized_advantage_estimation``, except
that each lambda weight is scaled by the capped importance ratio between the
current policy and the behavior policy, which corrects for off-policy data
(Retrace; Munos et al., "Safe and Efficient Off-Policy Reinforcement
Learning", 2016).

Args:
importance_ratio (torch.Tensor): capped importance ratio for the first
``T - 1`` steps.
discounts (torch.Tensor): discount at each time step.
rewards (torch.Tensor): reward at each time step.
td_lambda (float): lambda parameter of the lambda-return.
time_major (bool): whether the tensors are time-major ``[T, B]``.
values (torch.Tensor): values computed by the current value network.
target_value (torch.Tensor): values computed by the target value network,
used for bootstrapping.
step_types (torch.Tensor): step type at each time step.

Returns:
torch.Tensor: advantages, with the time dimension shortened by one.
"""
Contributor: please comment following the way of other functions.

Contributor: Also need unittest for this function.

Contributor:
1. line too long
2. add space after ,
3. comments for the function need to be added

# Retrace caps the trace coefficient at 1 (c_t = lambda * min(1, rho_t)); if
# the caller has not already capped the ratio, uncomment:
# importance_ratio = torch.min(importance_ratio, torch.tensor(1.))
# If the inputs are batch-major, transpose them to time-major so the
# recursion below runs along the time dimension.
if not time_major:
discounts = discounts.transpose(0, 1)
rewards = rewards.transpose(0, 1)
values = values.transpose(0, 1)
step_types = step_types.transpose(0, 1)
importance_ratio = importance_ratio.transpose(0, 1)
target_value = target_value.transpose(0, 1)

assert values.shape[0] >= 2, ("The sequence length needs to be "
"at least 2. Got {s}".format(
s=values.shape[0]))

# Backward recursion: accumulate the importance-weighted one-step TD errors
# into lambda-advantages.
advs = torch.zeros_like(values)
is_lasts = (step_types == StepType.LAST).to(dtype=torch.float32)
delta = (rewards[1:] + discounts[1:] * target_value[1:] - values[:-1])


weighted_discounts = discounts[1:] * td_lambda * importance_ratio
with torch.no_grad():
for t in reversed(range(rewards.shape[0] - 1)):
advs[t] = (1 - is_lasts[t]) * \
(delta[t] + weighted_discounts[t] * advs[t + 1])
advs = advs[:-1]
if not time_major:
advs = advs.transpose(0, 1)

return advs.detach()
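
Regarding the unittest request above: one easy property to check is that with an all-ones importance ratio and ``values == target_value``, retrace should reduce to ordinary GAE. A rough sketch (the ``StepType`` import path and tensor dtypes are assumptions):

```python
import torch
from alf.data_structures import StepType
from alf.utils import value_ops

def test_retrace_reduces_to_gae():
    T, B = 6, 4
    rewards = torch.randn(T, B)
    values = torch.randn(T, B)
    step_types = torch.full((T, B), int(StepType.MID), dtype=torch.int32)
    discounts = torch.full((T, B), 0.99)
    retrace_advs = value_ops.generalized_advantage_estimation_retrace(
        importance_ratio=torch.ones(T - 1, B),
        discounts=discounts,
        rewards=rewards,
        td_lambda=0.95,
        time_major=True,
        values=values,
        target_value=values,
        step_types=step_types)
    gae_advs = value_ops.generalized_advantage_estimation(
        rewards=rewards,
        values=values,
        step_types=step_types,
        discounts=discounts,
        td_lambda=0.95)
    assert torch.allclose(retrace_advs, gae_advs, atol=1e-5)
```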