[Bugfix] Wrong minari download first element #3106

Open · wants to merge 6 commits into main
8 changes: 8 additions & 0 deletions test/test_libs.py
@@ -3722,6 +3722,14 @@ def test_local_minari_dataset_loading(self, tmpdir):
        if MINARI_DATASETS_PATH:
            os.environ["MINARI_DATASETS_PATH"] = MINARI_DATASETS_PATH

    def test_correct_categorical_missions(self):
        exp_replay = MinariExperienceReplay(
            dataset_id="minigrid/BabyAI-Pickup/optimal-v0",
            batch_size=1,
            root=None,
        )
        assert isinstance(exp_replay[0][("observation", "mission")], (bytes, str))


@pytest.mark.slow
class TestRoboset:
85 changes: 79 additions & 6 deletions torchrl/data/datasets/minari_data.py
@@ -16,7 +16,9 @@
from typing import Callable

import torch
from tensordict import is_non_tensor, PersistentTensorDict, TensorDict
from tensordict import (
    is_non_tensor,
    is_tensor_collection,
    NonTensorData,
    NonTensorStack,
    PersistentTensorDict,
    set_list_to_stack,
    TensorDict,
    TensorDictBase,
)

from torchrl._utils import KeyDependentDefaultDict, logger as torchrl_logger
from torchrl.data.datasets.common import BaseDatasetExperienceReplay
from torchrl.data.datasets.utils import _get_root_dir
@@ -281,6 +283,7 @@ def _download_and_preproc(self):
f"loading dataset from local Minari cache at {h5_path}"
)
h5_data = PersistentTensorDict.from_h5(h5_path)
h5_data = h5_data.to_tensordict()
Collaborator

do we need this? It's a bit expensive so if we can avoid it it's better (under the hood it copies the entire dataset in memory)

Contributor Author

There is nothing I would love more than getting rid of that line. The method to change from NonTensorData to NonTensorStack is basically:

with set_list_to_stack(True):
    tensordict[key] = data_list

Unfortunately, if we don't get the h5_data into memory, we face this error upon rewriting each NonTensorData key.

OSError: Can't synchronously write data (no write intent on file)

If anyone knows how to avoid this, I would love to get this thing fixed in a better way.
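
For reference, a minimal self-contained sketch of the workaround discussed above (the file path and key path are illustrative, not taken from the actual dataset):

from tensordict import PersistentTensorDict, set_list_to_stack

# Open the HDF5-backed TensorDict, then materialize it in memory. Without
# .to_tensordict(), the assignment below attempts to write back into the
# read-only HDF5 file and raises:
#     OSError: Can't synchronously write data (no write intent on file)
h5_data = PersistentTensorDict.from_h5("main_data.hdf5")  # illustrative path
td = h5_data.to_tensordict()

# Rewrite a NonTensorData leaf as a NonTensorStack: under
# set_list_to_stack(True), assigning a Python list stacks its elements
# into one entry per step.
leaf = td.get(("episode_0", "observations", "mission"))  # NonTensorData
with set_list_to_stack(True):
    td["episode_0", "observations", "mission"] = list(leaf.data)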


            else:
                # temporarily change the minari cache path
@@ -304,9 +307,11 @@
                h5_data = PersistentTensorDict.from_h5(
                    parent_dir / "main_data.hdf5"
                )
                h5_data = h5_data.to_tensordict()

            # populate the tensordict
            episode_dict = {}
            dataset_has_nontensor = False
            for i, (episode_key, episode) in enumerate(h5_data.items()):
                episode_num = int(episode_key[len("episode_") :])
                episode_len = episode["actions"].shape[0]
@@ -315,9 +320,18 @@
                total_steps += episode_len
                if i == 0:
                    td_data.set("episode", 0)
                    seen = set()
                    for key, val in episode.items():
                        match = _NAME_MATCH[key]
                        if match in seen:
                            continue
                        seen.add(match)
                        if key in ("observations", "state", "infos"):
                            val = episode[key]
                            if any(isinstance(val.get(k), (NonTensorData, NonTensorStack)) for k in val.keys()):
                                non_tensor_probe = val.clone()
                                _extract_nontensor_fields(non_tensor_probe, recursive=True)
                                dataset_has_nontensor = True
                            if (
                                not val.shape
                            ):  # no need for this, we don't need the proper length: or steps != val.shape[0] - 1:
@@ -328,7 +342,7 @@
                                val = _patch_info(val)
                            td_data.set(("next", match), torch.zeros_like(val[0]))
                            td_data.set(match, torch.zeros_like(val[0]))
                        if key not in ("terminations", "truncations", "rewards"):
                        elif key not in ("terminations", "truncations", "rewards"):
                            td_data.set(match, torch.zeros_like(val[0]))
                        else:
                            td_data.set(
@@ -348,6 +362,9 @@
f"creating tensordict data in {self.data_path_root}: "
)
td_data = td_data.memmap_like(self.data_path_root)
td_data = td_data.unlock_()
if dataset_has_nontensor:
_preallocate_nontensor_fields(td_data, episode, total_steps, name_map=_NAME_MATCH)
torchrl_logger.info(f"tensordict structure: {td_data}")

            torchrl_logger.info(
@@ -358,7 +375,7 @@
                # iterate over episodes and populate the tensordict
                for episode_num in sorted(episode_dict):
                    episode_key, steps = episode_dict[episode_num]
                    episode = h5_data.get(episode_key)
                    episode = _patch_nontensor_data_to_stack(h5_data.get(episode_key))
                    idx = slice(index, (index + steps))
                    data_view = td_data[idx]
                    data_view.fill_("episode", episode_num)
@@ -379,8 +396,18 @@
                                raise RuntimeError(
                                    f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0] - 1}."
                                )
                            data_view["next", match].copy_(val[1:])
                            data_view[match].copy_(val[:-1])
                            val_next = val[1:].clone()
                            val_copy = val[:-1].clone()

                            non_tensors_next = _extract_nontensor_fields(val_next)
                            non_tensors_now = _extract_nontensor_fields(val_copy)

                            data_view["next", match].copy_(val_next)
                            data_view[match].copy_(val_copy)

                            data_view["next", match].update_(non_tensors_next)
                            data_view[match].update_(non_tensors_now)

                        elif key not in ("terminations", "truncations", "rewards"):
                            if steps is None:
                                steps = val.shape[0]
@@ -413,7 +440,6 @@ def _download_and_preproc(self):
f"index={index} - episode num {episode_num}"
)
index += steps
h5_data.close()
# Add a "done" entry
if self.split_trajs:
with td_data.unlock_():
@@ -539,3 +565,50 @@ def _patch_info(info_td):
    if not source.is_empty():
        val_td_sel.update(source, update_batch_size=True)
    return val_td_sel


def _patch_nontensor_data_to_stack(tensordict: TensorDictBase):
    """Recursively replaces all NonTensorData fields in the TensorDict with NonTensorStack."""
    for key, val in tensordict.items():
        if isinstance(val, TensorDictBase):
            _patch_nontensor_data_to_stack(val)  # in-place recursive
        elif isinstance(val, NonTensorData):
            data_list = list(val.data)
Collaborator

are we sure this makes sense? Can val be a str? Perhaps we should check what type val.data has?

Contributor Author

As far as I know, val.data is a np.array of length equal to the number of steps in the episode. In the case of 'mission' keys, it is an array that contains what mission the agent has at each step.

            with set_list_to_stack(True):
                tensordict[key] = data_list
    return tensordict
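
To illustrate the exchange above, here is a small sketch (the mission strings are made up) of the difference this function makes: a single NonTensorData wrapping the whole per-step array returns every mission at once, while the NonTensorStack form indexes to one mission per step, which is what the new test asserts for exp_replay[0]:

import numpy as np
from tensordict import NonTensorData, NonTensorStack

# What val.data looks like for a "mission" key: one string per step.
data = np.array(["pick up the red box"] * 3, dtype=object)

# Loaded naively, the whole array is a single NonTensorData leaf.
leaf = NonTensorData(data)
print(leaf.data)   # the full 3-element array

# Re-stored as a NonTensorStack, each step owns a single mission string.
stack = NonTensorStack(*[NonTensorData(m) for m in data])
print(stack[0].data)  # "pick up the red box"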


def _extract_nontensor_fields(tensordict: TensorDictBase, recursive: bool = False) -> TensorDict:
    """Removes the non-tensor fields from ``tensordict`` in place and returns them as a separate TensorDict."""
    extracted = {}
    for key in list(tensordict.keys()):
        val = tensordict.get(key)
        if is_non_tensor(val):
            extracted[key] = val
            del tensordict[key]
        elif recursive and is_tensor_collection(val):
            nested = _extract_nontensor_fields(val, recursive=True)
            if len(nested) > 0:
                extracted[key] = nested
    return TensorDict(extracted, batch_size=tensordict.batch_size)


def _preallocate_nontensor_fields(td_data: TensorDictBase, example: TensorDictBase, total_steps: int, name_map: dict):
    """Preallocates NonTensorStack fields in td_data based on an example TensorDict, applying key remapping."""
    with set_list_to_stack(True):

        def _recurse(src_td: TensorDictBase, dst_td: TensorDictBase, prefix=()):
            for key, val in src_td.items():
                mapped_key = name_map.get(key, key)
                full_dst_key = prefix + (mapped_key,)

                if is_non_tensor(val):
                    # dummy per-step placeholders, overwritten when the episodes are copied in
                    dummy_stack = NonTensorStack(*[total_steps for _ in range(total_steps)])
                    dst_td.set(full_dst_key, dummy_stack)
                    dst_td.set(("next",) + full_dst_key, dummy_stack)

                elif is_tensor_collection(val):
                    _recurse(val, dst_td, full_dst_key)

        _recurse(example, td_data)
