Skip to content

Commit e221c8a

Browse files
committed
not redundent compute
1 parent a2ce985 commit e221c8a

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

trinity/algorithm/policy_loss_fn/ppo_policy_loss.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ def __call__( # type: ignore
9999
pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
100100

101101
# Apply sequence mask to the losses
102-
pg_losses = pg_losses * sequence_mask
102+
if self.enable_sequence_masking:
103+
pg_losses = pg_losses * sequence_mask
103104

104105
pg_loss = aggregate_loss(pg_losses, action_mask, loss_agg_mode=self.loss_agg_mode)
105106
metrics = {

0 commit comments

Comments
 (0)