[Bugfix] Wrong minari download first element #3106
base: main
Changes from 5 commits
Commits: 003b701, 1e975a8, 73de83f, 78797bf, 0316e4d, 1c0f87f

torchrl/data/datasets/minari_data.py

@@ -16,7 +16,9 @@
 from typing import Callable

 import torch
-from tensordict import is_non_tensor, PersistentTensorDict, TensorDict
+from tensordict import (PersistentTensorDict, TensorDict, set_list_to_stack,
+    TensorDictBase, NonTensorData, NonTensorStack, is_non_tensor)

 from torchrl._utils import KeyDependentDefaultDict, logger as torchrl_logger
 from torchrl.data.datasets.common import BaseDatasetExperienceReplay
 from torchrl.data.datasets.utils import _get_root_dir
@@ -281,6 +283,7 @@ def _download_and_preproc(self):
         f"loading dataset from local Minari cache at {h5_path}"
     )
     h5_data = PersistentTensorDict.from_h5(h5_path)
+    h5_data = h5_data.to_tensordict()

 else:
     # temporarily change the minari cache path
@@ -304,9 +307,11 @@
 h5_data = PersistentTensorDict.from_h5(
     parent_dir / "main_data.hdf5"
 )
+h5_data = h5_data.to_tensordict()

 # populate the tensordict
 episode_dict = {}
+dataset_has_nontensor = False
 for i, (episode_key, episode) in enumerate(h5_data.items()):
     episode_num = int(episode_key[len("episode_") :])
     episode_len = episode["actions"].shape[0]
@@ -315,9 +320,18 @@
 total_steps += episode_len
 if i == 0:
     td_data.set("episode", 0)
+    seen = set()
     for key, val in episode.items():
         match = _NAME_MATCH[key]
+        if match in seen:
+            continue
+        seen.add(match)
         if key in ("observations", "state", "infos"):
+            val = episode[key]
+            if any(isinstance(val.get(k), (NonTensorData, NonTensorStack)) for k in val.keys()):
+                non_tensor_probe = val.clone()
+                extract_nontensor_fields(non_tensor_probe, recursive=True)
+                dataset_has_nontensor = True
             if (
                 not val.shape
             ):  # no need for this, we don't need the proper length: or steps != val.shape[0] - 1:
@@ -328,7 +342,7 @@
             val = _patch_info(val)
         td_data.set(("next", match), torch.zeros_like(val[0]))
         td_data.set(match, torch.zeros_like(val[0]))
-    if key not in ("terminations", "truncations", "rewards"):
+    elif key not in ("terminations", "truncations", "rewards"):
         td_data.set(match, torch.zeros_like(val[0]))
     else:
         td_data.set(
@@ -348,6 +362,9 @@
     f"creating tensordict data in {self.data_path_root}: "
 )
 td_data = td_data.memmap_like(self.data_path_root)
+td_data = td_data.unlock_()
+if dataset_has_nontensor:
+    preallocate_nontensor_fields(td_data, episode, total_steps, name_map=_NAME_MATCH)
 torchrl_logger.info(f"tensordict structure: {td_data}")

 torchrl_logger.info(
@@ -358,7 +375,7 @@
 # iterate over episodes and populate the tensordict
 for episode_num in sorted(episode_dict):
     episode_key, steps = episode_dict[episode_num]
-    episode = h5_data.get(episode_key)
+    episode = patch_nontensor_data_to_stack(h5_data.get(episode_key))
     idx = slice(index, (index + steps))
     data_view = td_data[idx]
     data_view.fill_("episode", episode_num)
@@ -379,8 +396,18 @@
         raise RuntimeError(
             f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0] - 1}."
         )
-    data_view["next", match].copy_(val[1:])
-    data_view[match].copy_(val[:-1])
+    val_next = val[1:].clone()
+    val_copy = val[:-1].clone()
+
+    non_tensors_next = extract_nontensor_fields(val_next)
+    non_tensors_now = extract_nontensor_fields(val_copy)
+
+    data_view["next", match].copy_(val_next)
+    data_view[match].copy_(val_copy)
+
+    data_view["next", match].update_(non_tensors_next)
+    data_view[match].update_(non_tensors_now)
+
 elif key not in ("terminations", "truncations", "rewards"):
     if steps is None:
         steps = val.shape[0]
@@ -413,7 +440,6 @@
             f"index={index} - episode num {episode_num}"
         )
         index += steps
-    h5_data.close()
     # Add a "done" entry
     if self.split_trajs:
         with td_data.unlock_():
@@ -539,3 +565,51 @@ def _patch_info(info_td):
     if not source.is_empty():
         val_td_sel.update(source, update_batch_size=True)
     return val_td_sel
+
+
+def patch_nontensor_data_to_stack(tensordict: TensorDictBase):
+    """Recursively replaces all NonTensorData fields in the TensorDict with NonTensorStack."""
+    for key in list(tensordict.keys()):
+        val = tensordict.get(key)

Reviewer: use .items()

Author: mmmm I get this error when using .items():

    Traceback (most recent call last):
      File "/Users/O000142/Projects/rl/torchrl/data/datasets/minari_data.py", line 585, in _extract_nontensor_fields
        for key, val in tensordict.items():
    RuntimeError: dictionary changed size during iteration

This is due to the fact that we are deleting keys in the dataset as we are iterating over them. Looks better with list(tensordict.keys()).
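For reference, a minimal standalone sketch of the failure mode discussed in this thread, using a throwaway TensorDict (behaviour as observed with recent tensordict versions, toy keys and values made up for this sketch): mutating the tensordict while iterating its live items() view raises, whereas iterating over a snapshot of the keys, as the helper in this diff does, is safe.

```python
import torch
from tensordict import NonTensorData, TensorDict

# Illustrative toy tensordict; the keys and values are made up for this sketch.
td = TensorDict({"obs": torch.zeros(3), "mission": NonTensorData("go left")}, batch_size=[])

try:
    for key, val in td.items():        # live view over the underlying dict
        if isinstance(val, NonTensorData):
            td.del_(key)               # deleting mid-iteration changes the dict size
except RuntimeError as err:
    print(f"RuntimeError: {err}")      # dictionary changed size during iteration

for key in list(td.keys()):            # snapshot of the keys: safe to delete inside the loop
    if isinstance(td.get(key), NonTensorData):
        td.del_(key)
```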
+        if isinstance(val, TensorDictBase):
+            patch_nontensor_data_to_stack(val)  # in-place recursive
+        elif isinstance(val, NonTensorData):
+            data_list = list(val.data)

Reviewer: are we sure this makes sense? Can val be a …

Author: As far as I know, …
+            with set_list_to_stack(True):
+                tensordict[key] = data_list
+    return tensordict
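On the question above about what val can be: the sketch below (hypothetical data, assuming recent tensordict semantics for set_list_to_stack) shows the conversion patch_nontensor_data_to_stack performs. A NonTensorData carries one Python object for the whole batch, while a NonTensorStack keeps one entry per step and can be indexed row by row.

```python
from tensordict import NonTensorData, NonTensorStack, TensorDict, set_list_to_stack

# Hypothetical stand-in for an episode field loaded from HDF5, where a whole
# column of strings arrives as one NonTensorData whose .data is a list.
td = TensorDict(
    {"mission": NonTensorData(["go left", "go right"], batch_size=[2])},
    batch_size=[2],
)
print(isinstance(td.get("mission"), NonTensorData))    # True: one object for the whole batch

# Same re-assignment trick as in patch_nontensor_data_to_stack above.
with set_list_to_stack(True):
    td["mission"] = list(td.get("mission").data)

print(isinstance(td.get("mission"), NonTensorStack))   # True: one entry per step
print(td[0]["mission"])                                # "go left"
```

Note that if val ever arrived as a NonTensorStack already, the helper above would leave it untouched, since only the NonTensorData branch rewrites the entry.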
+
+
+def extract_nontensor_fields(td: TensorDictBase, recursive: bool = False) -> TensorDict:

Reviewer: ditto

Author: These functions look similar but do different things: one preallocates keys in a tensordict, another deletes NonTensorData keys, and another transforms NonTensorData into NonTensorStack. But they all traverse the tensordict keys. I can try to refactor them a bit more later.
+    extracted = {}
+    for key in list(td.keys()):
+        val = td.get(key)
+        if isinstance(val, (NonTensorData, NonTensorStack)):
+            extracted[key] = val
+            td.del_(key)
+        elif recursive and isinstance(val, TensorDictBase):
+            nested = extract_nontensor_fields(val, recursive=True)
+            if len(nested) > 0:
+                extracted[key] = nested
+    return TensorDict(extracted, batch_size=td.batch_size)
+
+
+def preallocate_nontensor_fields(td_data: TensorDictBase, example: TensorDictBase, total_steps: int, name_map: dict):

Reviewer: ditto
+    """Preallocates NonTensorStack fields in td_data based on an example TensorDict, applying key remapping."""
+    with set_list_to_stack(True):
+        def _recurse(src_td: TensorDictBase, dst_td: TensorDictBase, prefix=()):
+            for key, val in src_td.items():
+                mapped_key = name_map.get(key, key)
+                full_dst_key = prefix + (mapped_key,)
+
+                if isinstance(val, NonTensorData):
+                    dummy_val = b"" if isinstance(val.data[0], bytes) else ""
+                    dummy_stack = TensorDict({mapped_key: dummy_val}).expand(total_steps)[mapped_key]
+
+                    dst_td.set(full_dst_key, dummy_stack)
+                    dst_td.set(("next",) + full_dst_key, dummy_stack)
+
+                elif isinstance(val, TensorDictBase):
+                    _recurse(val, dst_td, full_dst_key)
+
+        _recurse(example, td_data)
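
On the refactor mentioned a few comments up: since all three helpers walk the tensordict keys the same way, one possible shared traversal is sketched below. walk_nontensor_leaves is a hypothetical name, not part of this PR; it only illustrates factoring the recursion out while keeping the snapshot-of-keys iteration that avoids the mutation-during-iteration error.

```python
from tensordict import NonTensorData, NonTensorStack, TensorDictBase


def walk_nontensor_leaves(td: TensorDictBase, visit, prefix=()):
    """Call ``visit(parent, key, val, path)`` on every NonTensorData/NonTensorStack leaf.

    Hypothetical sketch: iterating over a snapshot of the keys keeps the traversal
    safe even when ``visit`` deletes or replaces the entry it is handed.
    """
    for key in list(td.keys()):
        val = td.get(key)
        path = prefix + (key,)
        if isinstance(val, (NonTensorData, NonTensorStack)):
            visit(td, key, val, path)
        elif isinstance(val, TensorDictBase):
            walk_nontensor_leaves(val, visit, path)
```

extract_nontensor_fields would then pass a callback that pops the leaf into a side dict, patch_nontensor_data_to_stack one that re-assigns it under set_list_to_stack(True), and preallocation one that writes dummy stacks into the destination tensordict instead of mutating the source.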

Reviewer: Do we need this? It's a bit expensive, so if we can avoid it, it's better (under the hood it copies the entire dataset in memory).

Author: There is nothing I would love more than getting rid of that line. The method to change from NonTensorData to NonTensorStack is basically:

Unfortunately, if we don't get the h5_data into memory, we face this error upon rewriting each NonTensorData key:

    OSError: Can't synchronously write data (no write intent on file)

If anyone knows how to avoid this, I would love to get this thing fixed in a better way.
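
For context on the OSError quoted above, here is a standalone reproduction with plain h5py (no tensordict involved): writing through a handle opened read-only fails for lack of write intent, which is the same constraint the h5-backed PersistentTensorDict hits, while to_tensordict() sidesteps it by copying everything into memory first. The file name and exact error wording below are illustrative.

```python
import os
import tempfile

import h5py
import numpy as np

path = os.path.join(tempfile.gettempdir(), "write_intent_demo.h5")

# Create a small throwaway file, then reopen it read-only like the Minari cache.
with h5py.File(path, "w") as f:
    f.create_dataset("episode_0/mission", data=np.zeros(3))

with h5py.File(path, "r") as f:                # read-only: no write intent
    try:
        f["episode_0/mission"][0] = 1.0        # any in-place write fails here
    except Exception as err:                   # surfaced as an OSError in practice
        print(type(err).__name__, err)         # e.g. OSError: ... (no write intent on file)
```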