[Bugfix] Wrong minari download first element #3106

Open · wants to merge 6 commits into main

Changes from 5 commits
86 changes: 80 additions & 6 deletions torchrl/data/datasets/minari_data.py
@@ -16,7 +16,9 @@
from typing import Callable

import torch
from tensordict import is_non_tensor, PersistentTensorDict, TensorDict
from tensordict import (PersistentTensorDict, TensorDict, set_list_to_stack,
                        TensorDictBase, NonTensorData, NonTensorStack, is_non_tensor)

from torchrl._utils import KeyDependentDefaultDict, logger as torchrl_logger
from torchrl.data.datasets.common import BaseDatasetExperienceReplay
from torchrl.data.datasets.utils import _get_root_dir
@@ -281,6 +283,7 @@ def _download_and_preproc(self):
f"loading dataset from local Minari cache at {h5_path}"
)
h5_data = PersistentTensorDict.from_h5(h5_path)
h5_data = h5_data.to_tensordict()
Collaborator: do we need this? It's a bit expensive, so if we can avoid it, it's better (under the hood it copies the entire dataset into memory).

Contributor Author: There is nothing I would love more than getting rid of that line. The method to convert a NonTensorData into a NonTensorStack is basically:

with set_list_to_stack(True):
    tensordict[key] = data_list

Unfortunately, if we don't load the h5_data into memory, we hit this error when rewriting each NonTensorData key:

OSError: Can't synchronously write data (no write intent on file)

If anyone knows how to avoid this, I would love to get this thing fixed in a better way.
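
A minimal sketch of the failure mode, assuming a local Minari-style HDF5 cache (the file name and the episode/key path are made up for illustration):

from tensordict import PersistentTensorDict

h5_path = "main_data.hdf5"  # hypothetical path to a local Minari cache

h5_data = PersistentTensorDict.from_h5(h5_path)  # file is opened read-only

# Any write to the file-backed tensordict fails, e.g.:
#   h5_data["episode_0", "infos", "mission"] = ["go to the door"] * 10
#   OSError: Can't synchronously write data (no write intent on file)

# The workaround used in this PR: materialize the dataset in memory first.
h5_data = h5_data.to_tensordict()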


else:
# temporarily change the minari cache path
@@ -304,9 +307,11 @@
h5_data = PersistentTensorDict.from_h5(
parent_dir / "main_data.hdf5"
)
h5_data = h5_data.to_tensordict()

# populate the tensordict
episode_dict = {}
dataset_has_nontensor = False
for i, (episode_key, episode) in enumerate(h5_data.items()):
episode_num = int(episode_key[len("episode_") :])
episode_len = episode["actions"].shape[0]
@@ -315,9 +320,18 @@
total_steps += episode_len
if i == 0:
td_data.set("episode", 0)
seen = set()
for key, val in episode.items():
match = _NAME_MATCH[key]
if match in seen:
continue
seen.add(match)
if key in ("observations", "state", "infos"):
val = episode[key]
if any(isinstance(val.get(k), (NonTensorData, NonTensorStack)) for k in val.keys()):
non_tensor_probe = val.clone()
extract_nontensor_fields(non_tensor_probe, recursive=True)
dataset_has_nontensor = True
if (
not val.shape
): # no need for this, we don't need the proper length: or steps != val.shape[0] - 1:
@@ -328,7 +342,7 @@
val = _patch_info(val)
td_data.set(("next", match), torch.zeros_like(val[0]))
td_data.set(match, torch.zeros_like(val[0]))
if key not in ("terminations", "truncations", "rewards"):
elif key not in ("terminations", "truncations", "rewards"):
td_data.set(match, torch.zeros_like(val[0]))
else:
td_data.set(
@@ -348,6 +362,9 @@
f"creating tensordict data in {self.data_path_root}: "
)
td_data = td_data.memmap_like(self.data_path_root)
td_data = td_data.unlock_()
if dataset_has_nontensor:
preallocate_nontensor_fields(td_data, episode, total_steps, name_map=_NAME_MATCH)
torchrl_logger.info(f"tensordict structure: {td_data}")

torchrl_logger.info(
@@ -358,7 +375,7 @@
# iterate over episodes and populate the tensordict
for episode_num in sorted(episode_dict):
episode_key, steps = episode_dict[episode_num]
episode = h5_data.get(episode_key)
episode = patch_nontensor_data_to_stack(h5_data.get(episode_key))
idx = slice(index, (index + steps))
data_view = td_data[idx]
data_view.fill_("episode", episode_num)
@@ -379,8 +396,18 @@
raise RuntimeError(
f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0] - 1}."
)
data_view["next", match].copy_(val[1:])
data_view[match].copy_(val[:-1])
val_next = val[1:].clone()
val_copy = val[:-1].clone()

non_tensors_next = extract_nontensor_fields(val_next)
non_tensors_now = extract_nontensor_fields(val_copy)

data_view["next", match].copy_(val_next)
data_view[match].copy_(val_copy)

data_view["next", match].update_(non_tensors_next)
data_view[match].update_(non_tensors_now)

elif key not in ("terminations", "truncations", "rewards"):
if steps is None:
steps = val.shape[0]
@@ -413,7 +440,6 @@ def _download_and_preproc(self):
f"index={index} - episode num {episode_num}"
)
index += steps
h5_data.close()
# Add a "done" entry
if self.split_trajs:
with td_data.unlock_():
@@ -539,3 +565,51 @@ def _patch_info(info_td):
if not source.is_empty():
val_td_sel.update(source, update_batch_size=True)
return val_td_sel


def patch_nontensor_data_to_stack(tensordict: TensorDictBase):
    """Recursively replaces all NonTensorData fields in the TensorDict with NonTensorStack."""
    for key in list(tensordict.keys()):
        val = tensordict.get(key)
Collaborator: use items()?

Contributor Author: mmmm, I get this error when using .items():

Traceback (most recent call last):
  File "/Users/O000142/Projects/rl/torchrl/data/datasets/minari_data.py", line 585, in _extract_nontensor_fields
    for key, val in tensordict.items():
RuntimeError: dictionary changed size during iteration

This is because we are deleting keys from the tensordict as we iterate over it. It looks better with items(), but sadly it doesn't work.
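
The same pitfall in a self-contained form (toy keys, purely illustrative):

import torch
from tensordict import NonTensorData, TensorDict

td = TensorDict(
    {"mission": NonTensorData("go to the door"), "obs": torch.zeros(3)},
    batch_size=[],
)

# for key, val in td.items():             # RuntimeError: dictionary
#     if isinstance(val, NonTensorData):  # changed size during iteration
#         td.del_(key)

for key in list(td.keys()):  # snapshot the keys first, then mutate freely
    if isinstance(td.get(key), NonTensorData):
        td.del_(key)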

        if isinstance(val, TensorDictBase):
            patch_nontensor_data_to_stack(val)  # in-place recursive
        elif isinstance(val, NonTensorData):
            data_list = list(val.data)
Collaborator: are we sure this makes sense? Can val be a str? Perhaps we should check what type val.data has?

Contributor Author: As far as I know, val.data is an np.array whose length equals the number of steps in the episode. In the case of 'mission' keys, it is an array containing the mission the agent has at each step.
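
If a bare str (or other scalar) can ever show up there, a defensive conversion along these lines might address the concern (a sketch; _as_list is not part of this PR):

import numpy as np
from tensordict import NonTensorData

def _as_list(val: NonTensorData, episode_len: int) -> list:
    if isinstance(val.data, np.ndarray):
        return list(val.data)  # one entry per step, e.g. missions
    # a bare str would otherwise be split into characters by list()
    return [val.data] * episode_len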

            with set_list_to_stack(True):
                tensordict[key] = data_list
    return tensordict


def extract_nontensor_fields(td: TensorDictBase, recursive: bool = False) -> TensorDict:
Collaborator: ditto

Contributor Author: These functions look similar but do different things: one preallocates keys in a tensordict, one deletes NonTensorData keys, and one converts NonTensorData into NonTensorStack. But they all traverse the tensordict keys, so I can try to refactor them a bit more later.
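
One possible shape for that refactor, sketched here for illustration (_walk_nontensor_leaves is a made-up name, not part of this PR):

from typing import Callable
from tensordict import NonTensorData, NonTensorStack, TensorDictBase

def _walk_nontensor_leaves(td: TensorDictBase, fn: Callable) -> None:
    """Applies fn(parent_td, key, leaf) to every NonTensorData/NonTensorStack leaf."""
    for key in list(td.keys()):  # snapshot: fn may delete or replace keys
        val = td.get(key)
        if isinstance(val, TensorDictBase):
            _walk_nontensor_leaves(val, fn)
        elif isinstance(val, (NonTensorData, NonTensorStack)):
            fn(td, key, val)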

    """Moves all NonTensorData/NonTensorStack fields out of ``td`` in place and returns them as a TensorDict."""
    extracted = {}
    for key in list(td.keys()):
        val = td.get(key)
        if isinstance(val, (NonTensorData, NonTensorStack)):
            extracted[key] = val
            td.del_(key)
        elif recursive and isinstance(val, TensorDictBase):
            nested = extract_nontensor_fields(val, recursive=True)
            if len(nested) > 0:
                extracted[key] = nested
    return TensorDict(extracted, batch_size=td.batch_size)


def preallocate_nontensor_fields(td_data: TensorDictBase, example: TensorDictBase, total_steps: int, name_map: dict):
Collaborator: ditto

"""Preallocates NonTensorStack fields in td_data based on an example TensorDict, applying key remapping."""
with set_list_to_stack(True):
def _recurse(src_td: TensorDictBase, dst_td: TensorDictBase, prefix=()):
for key, val in src_td.items():
mapped_key = name_map.get(key, key)
full_dst_key = prefix + (mapped_key,)

if isinstance(val, NonTensorData):
dummy_val = b"" if isinstance(val.data[0], bytes) else ""
dummy_stack = TensorDict({mapped_key: dummy_val}).expand(total_steps)[mapped_key]
dst_td.set(full_dst_key, dummy_stack)
dst_td.set(("next",) + full_dst_key, dummy_stack)

elif isinstance(val, TensorDictBase):
_recurse(val, dst_td, full_dst_key)

_recurse(example, td_data)
