Skip to content

Commit 93fcb02

Browse files
[BugFix] Patching not applied to NonTensorData observations, like Atari's (#3091)
Co-authored-by: vmoens <[email protected]>
1 parent 6be5510 commit 93fcb02

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

torchrl/data/datasets/minari_data.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from typing import Callable
1717

1818
import torch
19-
from tensordict import PersistentTensorDict, TensorDict
19+
from tensordict import PersistentTensorDict, TensorDict, is_non_tensor
2020
from torchrl._utils import KeyDependentDefaultDict, logger as torchrl_logger
2121
from torchrl.data.datasets.common import BaseDatasetExperienceReplay
2222
from torchrl.data.datasets.utils import _get_root_dir
@@ -323,6 +323,8 @@ def _download_and_preproc(self):
323323
): # no need for this, we don't need the proper length: or steps != val.shape[0] - 1:
324324
if val.is_empty():
325325
continue
326+
if is_non_tensor(val):
327+
continue
326328
val = _patch_info(val)
327329
td_data.set(("next", match), torch.zeros_like(val[0]))
328330
td_data.set(match, torch.zeros_like(val[0]))
@@ -370,6 +372,8 @@ def _download_and_preproc(self):
370372
if not val.shape or steps != val.shape[0] - 1:
371373
if val.is_empty():
372374
continue
375+
if is_non_tensor(val):
376+
continue
373377
val = _patch_info(val)
374378
if steps != val.shape[0] - 1:
375379
raise RuntimeError(

0 commit comments

Comments
 (0)