First draft for modular Hindsight Experience Replay Transform #2667
base: main
@@ -40,6 +40,7 @@
     TensorDictBase,
     unravel_key,
     unravel_key_list,
+    pad_sequence,
 )
 from tensordict.nn import dispatch, TensorDictModuleBase
 from tensordict.utils import (
@@ -9264,3 +9265,165 @@ def transform_observation_spec(self, observation_spec: Composite) -> Composite:
             high=torch.iinfo(torch.int64).max,
         )
         return super().transform_observation_spec(observation_spec)
class HERSubGoalSampler(Transform):
    """Returns a TensorDict with a key `subgoal_idx` of shape [batch_size, num_samples] representing the subgoal indices.

    Available strategies are `last` and `future`. The `last` strategy assigns the
    last state as the subgoal. The `future` strategy samples up to `num_samples`
    subgoals from the future states of the trajectory.

    Args:
        num_samples (int): Number of subgoals to sample from each trajectory. Defaults to 4.
        out_keys (str): The key to store the subgoal index. Defaults to "subgoal_idx".
    """
    def __init__(
        self,
        num_samples: int = 4,
        subgoal_idx_key: str = "subgoal_idx",
        strategy: str = "future",
    ):
        super().__init__(
            in_keys=None,
            in_keys_inv=None,
            out_keys_inv=None,
        )
        self.num_samples = num_samples
        self.subgoal_idx_key = subgoal_idx_key
        self.strategy = strategy
    def forward(self, trajectories: TensorDictBase) -> TensorDictBase:
        if len(trajectories.shape) == 1:
            trajectories = trajectories.unsqueeze(0)

        batch_size, trajectory_len = trajectories.shape
Reviewer: maybe

    *batch_size, trajectory_len = trajectories.shape

to account for batch size > 2.

Author: At the moment I assume that we have a single trajectory or a batch of trajectories [b, t]. I am not sure what other cases there may be, but we can think about it.

Reviewer: At least we should capture if the shape has more or fewer than 2 dims, let people know that 2 is the minimum, and if they want more they should ask for the feature on GitHub.
        if self.strategy == "last":
            return TensorDict(
                {"subgoal_idx": torch.full((batch_size, 1), -1)},
                batch_size=batch_size,
            )
        else:
            subgoal_idxs = []
            for i in range(batch_size):
Reviewer suggested change:

    -            for i in range(batch_size):
    +            for i in range(batch_size.numel()):

for batch_size with more than one dim.
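The diff is truncated inside the `future` branch. Assuming the strategy matches the docstring (sample up to `num_samples` distinct future-state indices per trajectory), the per-trajectory logic might look like this plain-Python sketch; `sample_future_subgoals` is a hypothetical helper, not the PR's code:

```python
import random

def sample_future_subgoals(trajectory_len, num_samples):
    # Hypothetical helper: pick up to num_samples distinct future-state
    # indices (the initial state is excluded as a subgoal candidate).
    k = min(num_samples, trajectory_len - 1)
    return sorted(random.sample(range(1, trajectory_len), k))
```

Trajectories shorter than `num_samples + 1` simply yield fewer indices, which is where `pad_sequence` would come in when stacking the per-trajectory results.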
"""This module assigns the subgoal to the trajectory according to a given subgoal index. | |
Args: | |
subgoal_idx_name (str): The key to the subgoal index. Defaults to "subgoal_idx". | |
subgoal_name (str): The key to assign the observation of the subgoal to the goal. Defaults to "goal". | |
""" | |
"""This module assigns the subgoal to the trajectory according to a given subgoal index. | |
Args: | |
SHOULD BE achieved_goal_key??? ===> subgoal_idx_name (str): The key to the subgoal index. Defaults to "subgoal_idx". | |
SHOULD BE desired_goal_key?? ===> subgoal_name (str): The key to assign the observation of the subgoal to the goal. Defaults to "goal". | |
""" |
Reviewer: I would add a .. seealso:: with other related classes.
Reviewer: I wonder if there's a vectorized version of this? The ops seem simple enough to be executed in a vectorized way.

Author: I think I had given it a shot with vmap, but indexing is not well supported with vmap. Once we pin down the API, I can give it a shot again.
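On the vectorization question: one loop-free (and vmap-free) way to draw distinct future indices for a whole batch at once is the uniform-scores-plus-topk trick. This is a sketch under the assumption that all trajectories in the batch share one length; `sample_future_idx` is a hypothetical name, not part of the PR:

```python
import torch

def sample_future_idx(batch_size: int, trajectory_len: int, num_samples: int) -> torch.Tensor:
    # Draw a uniform score per candidate state and keep the top-k per row:
    # this selects k distinct indices per trajectory with no Python loop.
    k = min(num_samples, trajectory_len - 1)
    scores = torch.rand(batch_size, trajectory_len - 1)  # candidates exclude index 0
    return scores.topk(k, dim=-1).indices + 1  # shape [batch_size, k], values in [1, T-1]
```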
Reviewer: If we keep the loop, I'd rather have trajectories.unbind(0) than indexing every element along dim 0; it will be faster.

Author: Once we finalize the API, I will optimize things further.
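For reference, `unbind(0)` returns views along the first dim in a single call, which is what the reviewer proposes instead of issuing a fresh indexing op on every loop iteration:

```python
import torch

# unbind(0) splits a tensor along dim 0 into a tuple of views in one call.
t = torch.arange(6).reshape(3, 2)
rows = t.unbind(0)
assert len(rows) == 3
assert rows[0].tolist() == [0, 1]
```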
Reviewer: Don't we need to modify the specs? Does this work with a replay buffer (static data) or only envs? If the latter, we should not be using forward. If you look at Compose, there are a bunch of things that need to be implemented when nesting transforms, like clone, cache erasure, etc. Perhaps we could inherit from Compose and rewrite forward, _apply_transform, _call, _reset, etc., such that the logic holds but the extra features are included automatically?

Author: I think it's a method that we do not need to attach to an environment; it's a data augmentation method. The gist of the augmentation is: given a trajectory, we sample some intermediate states and assume that they are the goal instead. Thus, we can get some positive rewards for hard cases.
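The relabeling idea the author describes can be sketched in plain Python. The keys `achieved`, `goal`, `reward` and the helper `relabel` are illustrative assumptions (stand-ins for TensorDict entries), not the PR's API:

```python
# HER relabeling sketch: pick an intermediate achieved state, pretend it was
# the goal all along, and recompute rewards so reaching it counts as success.
def relabel(trajectory, subgoal_idx):
    new_goal = trajectory[subgoal_idx]["achieved"]
    return [
        {**step,
         "goal": new_goal,
         "reward": 1.0 if step["achieved"] == new_goal else 0.0}
        for step in trajectory
    ]
```

Even a trajectory that never reached its original goal now contains at least one rewarded transition, which is the point of the augmentation for sparse-reward tasks.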
Reviewer suggested change:

    -        batch_size, trajectory_length = trajectories.shape
    +        *batch_size, trajectory_length = trajectories.shape

Reviewer suggested change:

    -            for i in range(batch_size):
    +            for i in range(batch_size.numel()):

which also works with batch_size=torch.Size([])!
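A quick check of why `numel()` covers the scalar-batch case the reviewer mentions:

```python
import torch

# torch.Size.numel() is the product of the dims, so it yields the number of
# trajectories for any batch shape, including an empty (scalar) batch shape.
assert torch.Size([]).numel() == 1
assert torch.Size([4]).numel() == 4
assert torch.Size([4, 2]).numel() == 8
```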
Reviewer: maybe let's create a dedicated file for these?

Author: Let me know where you would like me to put these and I will do it.

Reviewer: envs/transforms/her.py?