@@ -3342,33 +3342,57 @@ def test_d4rl_iteration(self, task, split_trajs):
3342
3342
_MINARI_DATASETS = []
3343
3343
3344
3344
3345
- def _minari_selected_datasets ():
3346
- if not _has_minari or not _has_gymnasium :
3347
- return
3345
+ def _minari_init ():
3346
+ """Initialize Minari datasets list. Returns True if already initialized."""
3348
3347
global _MINARI_DATASETS
3349
- import minari
3348
+ if _MINARI_DATASETS and not all (
3349
+ isinstance (x , str ) and x .isdigit () for x in _MINARI_DATASETS
3350
+ ):
3351
+ return True # Already initialized with real dataset names
3350
3352
3351
- torch .manual_seed (0 )
3353
+ if not _has_minari or not _has_gymnasium :
3354
+ return False
3352
3355
3353
- total_keys = sorted (
3354
- minari .list_remote_datasets (latest_version = True , compatible_minari_version = True )
3355
- )
3356
- indices = torch .randperm (len (total_keys ))[:20 ]
3357
- keys = [total_keys [idx ] for idx in indices ]
3356
+ try :
3357
+ import minari
3358
+
3359
+ torch .manual_seed (0 )
3358
3360
3359
- assert len (keys ) > 5 , keys
3360
- _MINARI_DATASETS += keys
3361
+ total_keys = sorted (
3362
+ minari .list_remote_datasets (
3363
+ latest_version = True , compatible_minari_version = True
3364
+ )
3365
+ )
3366
+ indices = torch .randperm (len (total_keys ))[:20 ]
3367
+ keys = [total_keys [idx ] for idx in indices ]
3361
3368
3369
+ assert len (keys ) > 5 , keys
3370
+ _MINARI_DATASETS [:] = keys # Replace the placeholder values
3371
+ return True
3372
+ except Exception :
3373
+ return False
3362
3374
3363
- _minari_selected_datasets ()
3375
+
3376
+ # Initialize with placeholder values for parametrization
3377
+ # These will be replaced with actual dataset names when the first Minari test runs
3378
+ _MINARI_DATASETS = [str (i ) for i in range (20 )]
3364
3379
3365
3380
3366
3381
@pytest .mark .skipif (not _has_minari or not _has_gymnasium , reason = "Minari not found" )
3367
3382
@pytest .mark .slow
3368
3383
class TestMinari :
3369
3384
@pytest .mark .parametrize ("split" , [False , True ])
3370
- @pytest .mark .parametrize ("selected_dataset" , _MINARI_DATASETS )
3371
- def test_load (self , selected_dataset , split ):
3385
+ @pytest .mark .parametrize ("dataset_idx" , range (20 ))
3386
+ def test_load (self , dataset_idx , split ):
3387
+ # Initialize Minari datasets if not already done
3388
+ if not _minari_init ():
3389
+ pytest .skip ("Failed to initialize Minari datasets" )
3390
+
3391
+ # Get the actual dataset name from the initialized list
3392
+ if dataset_idx >= len (_MINARI_DATASETS ):
3393
+ pytest .skip (f"Dataset index { dataset_idx } out of range" )
3394
+
3395
+ selected_dataset = _MINARI_DATASETS [dataset_idx ]
3372
3396
torchrl_logger .info (f"dataset { selected_dataset } " )
3373
3397
data = MinariExperienceReplay (
3374
3398
selected_dataset , batch_size = 32 , split_trajs = split
0 commit comments