@@ -76,6 +76,7 @@ def __init__(
7676 label_filter : Optional [int ] = None ,
7777 balance_after_filter : Optional [float ] = None ,
7878 num_workers : int = 1 ,
79+ persistent_workers = True ,
7980 chebi_version : int = 200 ,
8081 inner_k_folds : int = - 1 , # use inner cross-validation if > 1
8182 fold_index : Optional [int ] = None ,
@@ -99,6 +100,7 @@ def __init__(
99100 ), "Filter balancing requires a filter"
100101 self .balance_after_filter = balance_after_filter
101102 self .num_workers = num_workers
103+ self .persistent_workers : bool = bool (persistent_workers )
102104 self .chebi_version = chebi_version
103105 assert type (inner_k_folds ) is int
104106 self .inner_k_folds = inner_k_folds
@@ -360,7 +362,7 @@ def train_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader
360362 "train" ,
361363 shuffle = True ,
362364 num_workers = self .num_workers ,
363- persistent_workers = True ,
365+ persistent_workers = self . persistent_workers ,
364366 ** kwargs ,
365367 )
366368
@@ -379,7 +381,7 @@ def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]
379381 "validation" ,
380382 shuffle = False ,
381383 num_workers = self .num_workers ,
382- persistent_workers = True ,
384+ persistent_workers = self . persistent_workers ,
383385 ** kwargs ,
384386 )
385387
0 commit comments