[Bugfix] Wrong minari download first element #3106
base: main
Changes from all commits
003b701
1e975a8
73de83f
78797bf
0316e4d
1c0f87f
@@ -16,7 +16,9 @@
from typing import Callable

import torch
from tensordict import is_non_tensor, PersistentTensorDict, TensorDict
from tensordict import (PersistentTensorDict, TensorDict, set_list_to_stack,
    TensorDictBase, NonTensorData, NonTensorStack, is_non_tensor, is_tensor_collection)

from torchrl._utils import KeyDependentDefaultDict, logger as torchrl_logger
from torchrl.data.datasets.common import BaseDatasetExperienceReplay
from torchrl.data.datasets.utils import _get_root_dir
@@ -281,6 +283,7 @@ def _download_and_preproc(self):
        f"loading dataset from local Minari cache at {h5_path}"
    )
    h5_data = PersistentTensorDict.from_h5(h5_path)
    h5_data = h5_data.to_tensordict()

else:
    # temporarily change the minari cache path
@@ -304,9 +307,11 @@ def _download_and_preproc(self):
        h5_data = PersistentTensorDict.from_h5(
            parent_dir / "main_data.hdf5"
        )
        h5_data = h5_data.to_tensordict()

# populate the tensordict
episode_dict = {}
dataset_has_nontensor = False
for i, (episode_key, episode) in enumerate(h5_data.items()):
    episode_num = int(episode_key[len("episode_") :])
    episode_len = episode["actions"].shape[0]
@@ -315,9 +320,18 @@
total_steps += episode_len
if i == 0:
    td_data.set("episode", 0)
    seen = set()
    for key, val in episode.items():
        match = _NAME_MATCH[key]
        if match in seen:
            continue
        seen.add(match)
        if key in ("observations", "state", "infos"):
            val = episode[key]
            if any(isinstance(val.get(k), (NonTensorData, NonTensorStack)) for k in val.keys()):
                non_tensor_probe = val.clone()
                _extract_nontensor_fields(non_tensor_probe, recursive=True)
                dataset_has_nontensor = True
            if (
                not val.shape
            ):  # no need for this, we don't need the proper length: or steps != val.shape[0] - 1:
@@ -328,7 +342,7 @@
        val = _patch_info(val)
    td_data.set(("next", match), torch.zeros_like(val[0]))
    td_data.set(match, torch.zeros_like(val[0]))
if key not in ("terminations", "truncations", "rewards"):
elif key not in ("terminations", "truncations", "rewards"):
    td_data.set(match, torch.zeros_like(val[0]))
else:
    td_data.set(
@@ -348,6 +362,9 @@
    f"creating tensordict data in {self.data_path_root}: "
)
td_data = td_data.memmap_like(self.data_path_root)
td_data = td_data.unlock_()
if dataset_has_nontensor:
    _preallocate_nontensor_fields(td_data, episode, total_steps, name_map=_NAME_MATCH)
torchrl_logger.info(f"tensordict structure: {td_data}")

torchrl_logger.info(
@@ -358,7 +375,7 @@
# iterate over episodes and populate the tensordict
for episode_num in sorted(episode_dict):
    episode_key, steps = episode_dict[episode_num]
    episode = h5_data.get(episode_key)
    episode = _patch_nontensor_data_to_stack(h5_data.get(episode_key))
    idx = slice(index, (index + steps))
    data_view = td_data[idx]
    data_view.fill_("episode", episode_num)
@@ -379,8 +396,18 @@
        raise RuntimeError(
            f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0] - 1}."
        )
    data_view["next", match].copy_(val[1:])
    data_view[match].copy_(val[:-1])
    val_next = val[1:].clone()
    val_copy = val[:-1].clone()

    non_tensors_next = _extract_nontensor_fields(val_next)
    non_tensors_now = _extract_nontensor_fields(val_copy)

    data_view["next", match].copy_(val_next)
    data_view[match].copy_(val_copy)

    data_view["next", match].update_(non_tensors_next)
    data_view[match].update_(non_tensors_now)

elif key not in ("terminations", "truncations", "rewards"):
    if steps is None:
        steps = val.shape[0]
@@ -413,7 +440,6 @@ def _download_and_preproc(self):
        f"index={index} - episode num {episode_num}"
    )
    index += steps
h5_data.close()
# Add a "done" entry
if self.split_trajs:
    with td_data.unlock_():
@@ -539,3 +565,50 @@ def _patch_info(info_td):
    if not source.is_empty():
        val_td_sel.update(source, update_batch_size=True)
    return val_td_sel

def _patch_nontensor_data_to_stack(tensordict: TensorDictBase):
    """Recursively replaces all NonTensorData fields in the TensorDict with NonTensorStack."""
    for key, val in tensordict.items():
        if isinstance(val, TensorDictBase):
            _patch_nontensor_data_to_stack(val)  # in-place recursive
        elif isinstance(val, NonTensorData):
            data_list = list(val.data)
            with set_list_to_stack(True):
                tensordict[key] = data_list
    return tensordict

Review comment: are we sure this makes sense? Can val be a
Reply: As far as I know,
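For illustration only (not part of the diff), here is a minimal sketch of the conversion this helper performs, on a made-up TensorDict with a string field; it assumes a tensordict version exposing NonTensorData, NonTensorStack and set_list_to_stack, as the imports above do:

```python
import torch
from tensordict import NonTensorData, TensorDict, set_list_to_stack

# toy episode of 3 steps whose string field arrived as a single NonTensorData
td = TensorDict(
    {
        "obs": torch.zeros(3, 4),
        "mission": NonTensorData(["go left", "go right", "stop"], batch_size=[3]),
    },
    batch_size=[3],
)
print(type(td.get("mission")).__name__)  # NonTensorData

# same trick as the helper: pull the wrapped list out and reassign it while
# set_list_to_stack is active, so it is stored as a per-step NonTensorStack
data_list = list(td.get("mission").data)
with set_list_to_stack(True):
    td["mission"] = data_list
print(type(td.get("mission")).__name__)  # NonTensorStack
print(td[0]["mission"])                  # "go left"
```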

def _extract_nontensor_fields(tensordict: TensorDictBase, recursive: bool = False) -> TensorDict:
    """Removes the non-tensor fields from ``tensordict`` and returns them in a new TensorDict."""
    extracted = {}
    for key in list(tensordict.keys()):
        val = tensordict.get(key)
        if is_non_tensor(val):
            extracted[key] = val
            del tensordict[key]
        elif recursive and is_tensor_collection(val):
            nested = _extract_nontensor_fields(val, recursive=True)
            if len(nested) > 0:
                extracted[key] = nested
    return TensorDict(extracted, batch_size=tensordict.batch_size)
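As an illustration (not part of the diff), this is how the helper fits around copy_/update_ in the hunk further up; the field names are made up and the helper defined above is assumed to be in scope:

```python
import torch
from tensordict import NonTensorData, TensorDict

# hypothetical 3-step slice with one tensor leaf and one non-tensor leaf
val = TensorDict(
    {
        "observation": torch.zeros(3, 2),
        "mission": torch.stack(
            [NonTensorData("go left"), NonTensorData("go right"), NonTensorData("stop")]
        ),
    },
    batch_size=[3],
)

non_tensors = _extract_nontensor_fields(val)
print(list(val.keys()))            # only "observation" is left for copy_ to handle
print(non_tensors.get("mission"))  # the NonTensorStack that was split off
# after data_view[match].copy_(val), the non-tensor entries are written back
# with data_view[match].update_(non_tensors), as in the hunk above
```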

def _preallocate_nontensor_fields(td_data: TensorDictBase, example: TensorDictBase, total_steps: int, name_map: dict):
    """Preallocates NonTensorStack fields in td_data based on an example TensorDict, applying key remapping."""
    with set_list_to_stack(True):
        def _recurse(src_td: TensorDictBase, dst_td: TensorDictBase, prefix=()):
            for key, val in src_td.items():
                mapped_key = name_map.get(key, key)
                full_dst_key = prefix + (mapped_key,)

                if is_non_tensor(val):
                    dummy_stack = NonTensorStack(*[total_steps for _ in range(total_steps)])
                    dst_td.set(full_dst_key, dummy_stack)
                    dst_td.set(("next",) + full_dst_key, dummy_stack)

                elif is_tensor_collection(val):
                    _recurse(val, dst_td, full_dst_key)

        _recurse(example, td_data)
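A small usage sketch (not part of the diff) of the preallocation helper on made-up data; the key names and name_map are hypothetical and only meant to show the key remapping and the ("next", ...) mirror:

```python
from tensordict import NonTensorData, TensorDict

# hypothetical example episode: a nested string field under "observations"
example = TensorDict(
    {"observations": TensorDict({"mission": NonTensorData("dummy")}, batch_size=[])},
    batch_size=[],
)
td_data = TensorDict({}, batch_size=[])

_preallocate_nontensor_fields(
    td_data, example, total_steps=5, name_map={"observations": "observation"}
)
# td_data should now hold placeholder NonTensorStack entries of length 5 at
# ("observation", "mission") and ("next", "observation", "mission")
print(td_data.get(("observation", "mission")))
print(td_data.get(("next", "observation", "mission")))
```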

Review comment:
do we need this? It's a bit expensive so if we can avoid it it's better (under the hood it copies the entire dataset in memory)
Reply:
There is nothing I would love more than getting rid of that line. The method to change from NonTensorData to NonTensorStack is basically what _patch_nontensor_data_to_stack above does: read the wrapped value with list(val.data) and reassign it to the same key under set_list_to_stack(True).
Unfortunately, if we don't get the h5_data into memory, we face this error upon rewriting each NonTensorData key.
OSError: Can't synchronously write data (no write intent on file)
If anyone knows how to avoid this, I would love to get this thing fixed in a better way.
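For reference, a minimal sketch of the pattern described in this reply, reusing _patch_nontensor_data_to_stack from this PR; h5_path is a placeholder for a local Minari main_data.hdf5 file:

```python
from tensordict import PersistentTensorDict

h5_data = PersistentTensorDict.from_h5(h5_path)  # H5-backed view, opened without write intent
h5_data = h5_data.to_tensordict()                # materialise in memory first

for episode_key in list(h5_data.keys()):
    # rewriting NonTensorData leaves as NonTensorStack works on the in-memory copy;
    # doing the same assignment on the H5-backed PersistentTensorDict instead raises
    # OSError: Can't synchronously write data (no write intent on file)
    _patch_nontensor_data_to_stack(h5_data.get(episode_key))
```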