diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index 12eb634c..3ac0a803 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -76,6 +76,7 @@ def __init__( label_filter: Optional[int] = None, balance_after_filter: Optional[float] = None, num_workers: int = 1, + persistent_workers: bool = True, chebi_version: int = 200, inner_k_folds: int = -1, # use inner cross-validation if > 1 fold_index: Optional[int] = None, @@ -99,6 +100,7 @@ def __init__( ), "Filter balancing requires a filter" self.balance_after_filter = balance_after_filter self.num_workers = num_workers + self.persistent_workers: bool = bool(persistent_workers) self.chebi_version = chebi_version assert type(inner_k_folds) is int self.inner_k_folds = inner_k_folds @@ -360,7 +362,7 @@ def train_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader "train", shuffle=True, num_workers=self.num_workers, - persistent_workers=True, + persistent_workers=self.persistent_workers, **kwargs, ) @@ -379,7 +381,7 @@ def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]] "validation", shuffle=False, num_workers=self.num_workers, - persistent_workers=True, + persistent_workers=self.persistent_workers, **kwargs, )