File tree Expand file tree Collapse file tree 1 file changed +5
-1
lines changed Expand file tree Collapse file tree 1 file changed +5
-1
lines changed Original file line number Diff line number Diff line change 16
16
from typing import Callable
17
17
18
18
import torch
19
- from tensordict import PersistentTensorDict , TensorDict
19
+ from tensordict import PersistentTensorDict , TensorDict , is_non_tensor
20
20
from torchrl ._utils import KeyDependentDefaultDict , logger as torchrl_logger
21
21
from torchrl .data .datasets .common import BaseDatasetExperienceReplay
22
22
from torchrl .data .datasets .utils import _get_root_dir
@@ -323,6 +323,8 @@ def _download_and_preproc(self):
323
323
): # no need for this, we don't need the proper length: or steps != val.shape[0] - 1:
324
324
if val .is_empty ():
325
325
continue
326
+ if is_non_tensor (val ):
327
+ continue
326
328
val = _patch_info (val )
327
329
td_data .set (("next" , match ), torch .zeros_like (val [0 ]))
328
330
td_data .set (match , torch .zeros_like (val [0 ]))
@@ -370,6 +372,8 @@ def _download_and_preproc(self):
370
372
if not val .shape or steps != val .shape [0 ] - 1 :
371
373
if val .is_empty ():
372
374
continue
375
+ if is_non_tensor (val ):
376
+ continue
373
377
val = _patch_info (val )
374
378
if steps != val .shape [0 ] - 1 :
375
379
raise RuntimeError (
You can’t perform that action at this time.
0 commit comments