diff --git a/test/test_libs.py b/test/test_libs.py
index 054d6ca0240..a20d391d57a 100644
--- a/test/test_libs.py
+++ b/test/test_libs.py
@@ -3722,6 +3722,14 @@ def test_local_minari_dataset_loading(self, tmpdir):
         if MINARI_DATASETS_PATH:
             os.environ["MINARI_DATASETS_PATH"] = MINARI_DATASETS_PATH
 
+    def test_correct_categorical_missions(self):
+        exp_replay = MinariExperienceReplay(
+            dataset_id="minigrid/BabyAI-Pickup/optimal-v0",
+            batch_size=1,
+            root=None,
+        )
+        assert isinstance(exp_replay[0][("observation", "mission")], (bytes, str))
+
 
 @pytest.mark.slow
 class TestRoboset:
diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py
index fa715d03f46..7e11221a1b9 100644
--- a/torchrl/data/datasets/minari_data.py
+++ b/torchrl/data/datasets/minari_data.py
@@ -16,7 +16,16 @@ from typing import Callable
 
 import torch
-from tensordict import is_non_tensor, PersistentTensorDict, TensorDict
+from tensordict import (
+    is_non_tensor,
+    is_tensor_collection,
+    NonTensorData,
+    NonTensorStack,
+    PersistentTensorDict,
+    set_list_to_stack,
+    TensorDict,
+    TensorDictBase,
+)
 
 from torchrl._utils import KeyDependentDefaultDict, logger as torchrl_logger
 from torchrl.data.datasets.common import BaseDatasetExperienceReplay
 from torchrl.data.datasets.utils import _get_root_dir
@@ -281,6 +290,7 @@ def _download_and_preproc(self):
                     f"loading dataset from local Minari cache at {h5_path}"
                 )
                 h5_data = PersistentTensorDict.from_h5(h5_path)
+                h5_data = h5_data.to_tensordict()
             else:
                 # temporarily change the minari cache path
 
@@ -304,9 +314,11 @@ def _download_and_preproc(self):
                     h5_data = PersistentTensorDict.from_h5(
                         parent_dir / "main_data.hdf5"
                     )
+                    h5_data = h5_data.to_tensordict()
 
             # populate the tensordict
             episode_dict = {}
+            dataset_has_nontensor = False
             for i, (episode_key, episode) in enumerate(h5_data.items()):
                 episode_num = int(episode_key[len("episode_") :])
                 episode_len = episode["actions"].shape[0]
@@ -315,9 +327,18 @@ def _download_and_preproc(self):
                 total_steps += episode_len
                 if i == 0:
                     td_data.set("episode", 0)
+                    seen = set()
                     for key, val in episode.items():
                         match = _NAME_MATCH[key]
+                        if match in seen:
+                            continue
+                        seen.add(match)
                         if key in ("observations", "state", "infos"):
+                            if any(
+                                isinstance(val.get(k), (NonTensorData, NonTensorStack))
+                                for k in val.keys()
+                            ):
+                                dataset_has_nontensor = True
                             if (
                                 not val.shape
                             ):  # no need for this, we don't need the proper length: or steps != val.shape[0] - 1:
@@ -328,7 +349,7 @@ def _download_and_preproc(self):
                                 val = _patch_info(val)
                             td_data.set(("next", match), torch.zeros_like(val[0]))
                             td_data.set(match, torch.zeros_like(val[0]))
-                        if key not in ("terminations", "truncations", "rewards"):
+                        elif key not in ("terminations", "truncations", "rewards"):
                             td_data.set(match, torch.zeros_like(val[0]))
                         else:
                             td_data.set(
@@ -348,6 +369,10 @@ def _download_and_preproc(self):
                 f"creating tensordict data in {self.data_path_root}: "
             )
             td_data = td_data.memmap_like(self.data_path_root)
+            td_data.unlock_()
+            if dataset_has_nontensor:
+                # the last `episode` read above serves as a structural template
+                _preallocate_nontensor_fields(td_data, episode, total_steps, name_map=_NAME_MATCH)
 
             torchrl_logger.info(f"tensordict structure: {td_data}")
             torchrl_logger.info(
@@ -358,7 +383,7 @@ def _download_and_preproc(self):
             # iterate over episodes and populate the tensordict
             for episode_num in sorted(episode_dict):
                 episode_key, steps = episode_dict[episode_num]
-                episode = h5_data.get(episode_key)
+                episode = _patch_nontensor_data_to_stack(h5_data.get(episode_key))
                 idx = slice(index, (index + steps))
                 data_view = td_data[idx]
                 data_view.fill_("episode", episode_num)
@@ -379,8 +404,18 @@ def _download_and_preproc(self):
                             raise RuntimeError(
                                 f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0] - 1}."
                             )
-                        data_view["next", match].copy_(val[1:])
-                        data_view[match].copy_(val[:-1])
+                        val_next = val[1:].clone()
+                        val_now = val[:-1].clone()
+
+                        # split out non-tensor entries: copy_ handles the tensor
+                        # data, update_ writes the non-tensor data back
+                        non_tensors_next = _extract_nontensor_fields(val_next)
+                        non_tensors_now = _extract_nontensor_fields(val_now)
+
+                        data_view["next", match].copy_(val_next)
+                        data_view[match].copy_(val_now)
+                        data_view["next", match].update_(non_tensors_next)
+                        data_view[match].update_(non_tensors_now)
                     elif key not in ("terminations", "truncations", "rewards"):
                         if steps is None:
                             steps = val.shape[0]
@@ -413,7 +448,6 @@ def _download_and_preproc(self):
                         f"index={index} - episode num {episode_num}"
                     )
                 index += steps
-        h5_data.close()
         # Add a "done" entry
         if self.split_trajs:
             with td_data.unlock_():
@@ -539,3 +573,58 @@ def _patch_info(info_td):
     if not source.is_empty():
         val_td_sel.update(source, update_batch_size=True)
     return val_td_sel
+
+
+def _patch_nontensor_data_to_stack(tensordict: TensorDictBase) -> TensorDictBase:
+    """Recursively replaces every NonTensorData field in the tensordict with a NonTensorStack."""
+    for key, val in tensordict.items():
+        if isinstance(val, TensorDictBase):
+            _patch_nontensor_data_to_stack(val)  # recurses in-place
+        elif isinstance(val, NonTensorData):
+            # assigning a list under set_list_to_stack(True) stores a NonTensorStack
+            data_list = list(val.data)
+            with set_list_to_stack(True):
+                tensordict[key] = data_list
+    return tensordict
+
+
+def _extract_nontensor_fields(
+    tensordict: TensorDictBase, recursive: bool = False
+) -> TensorDict:
+    """Removes the non-tensor fields from tensordict and returns them in a new TensorDict."""
+    extracted = {}
+    for key in list(tensordict.keys()):
+        val = tensordict.get(key)
+        if is_non_tensor(val):
+            extracted[key] = val
+            del tensordict[key]
+        elif recursive and is_tensor_collection(val):
+            nested = _extract_nontensor_fields(val, recursive=True)
+            # is_empty() checks for keys; len() would report the batch size instead
+            if not nested.is_empty():
+                extracted[key] = nested
+    return TensorDict(extracted, batch_size=tensordict.batch_size)
+
+
+def _preallocate_nontensor_fields(
+    td_data: TensorDictBase,
+    example: TensorDictBase,
+    total_steps: int,
+    name_map: dict,
+) -> None:
+    """Preallocates NonTensorStack fields in td_data based on an example tensordict, applying key remapping."""
+
+    def _recurse(src_td: TensorDictBase, dst_td: TensorDictBase, prefix=()):
+        for key, val in src_td.items():
+            mapped_key = name_map.get(key, key)
+            full_dst_key = prefix + (mapped_key,)
+            if is_non_tensor(val):
+                # placeholder values only; they are overwritten episode by episode
+                dummy_stack = NonTensorStack(*[total_steps for _ in range(total_steps)])
+                dst_td.set(full_dst_key, dummy_stack)
+                dst_td.set(("next",) + full_dst_key, dummy_stack)
+            elif is_tensor_collection(val):
+                _recurse(val, dst_td, full_dst_key)
+
+    with set_list_to_stack(True):
+        _recurse(example, td_data)
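
Note for reviewers: a minimal sketch (not part of the patch; names are illustrative)
of the tensordict behaviour the new helpers rely on. Assigning a list under
set_list_to_stack(True) stores it as a NonTensorStack, and indexing returns the
plain Python object, which is what test_correct_categorical_missions asserts for
the "mission" field:

    import torch
    from tensordict import TensorDict, set_list_to_stack

    td = TensorDict({"obs": torch.zeros(3, 4)}, batch_size=[3])
    with set_list_to_stack(True):
        td["mission"] = ["go to the red door"] * 3  # stored as a NonTensorStack
    assert td[0]["mission"] == "go to the red door"  # plain str after indexing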