We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
done = done | truncated
1 parent 0063741 commit cdcd9f3Copy full SHA for cdcd9f3
torchrl/collectors/collectors.py
@@ -1121,7 +1121,10 @@ def _maybe_set_truncated(self, final_rollout):
1121
truncated = final_rollout["next", truncated_key]
1122
truncated[last_step] = True
1123
final_rollout["next", truncated_key] = truncated
1124
- final_rollout["next", _replace_last(truncated_key, "done")] = truncated
+ done = final_rollout["next", _replace_last(truncated_key, "done")]
1125
+ final_rollout["next", _replace_last(truncated_key, "done")] = (
1126
+ done | truncated
1127
+ )
1128
return final_rollout
1129
1130
@torch.no_grad()
0 commit comments