Skip to content

Commit 83fe459

Browse files
committed
persistent workers can be set through CLI for GNI
1 parent 06a1869 commit 83fe459

File tree

1 file changed

+4
-2
lines changed
  • chebai/preprocessing/datasets

1 file changed

+4
-2
lines changed

chebai/preprocessing/datasets/base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def __init__(
7676
label_filter: Optional[int] = None,
7777
balance_after_filter: Optional[float] = None,
7878
num_workers: int = 1,
79+
persistent_workers=True,
7980
chebi_version: int = 200,
8081
inner_k_folds: int = -1, # use inner cross-validation if > 1
8182
fold_index: Optional[int] = None,
@@ -99,6 +100,7 @@ def __init__(
99100
), "Filter balancing requires a filter"
100101
self.balance_after_filter = balance_after_filter
101102
self.num_workers = num_workers
103+
self.persistent_workers: bool = bool(persistent_workers)
102104
self.chebi_version = chebi_version
103105
assert type(inner_k_folds) is int
104106
self.inner_k_folds = inner_k_folds
@@ -360,7 +362,7 @@ def train_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader
360362
"train",
361363
shuffle=True,
362364
num_workers=self.num_workers,
363-
persistent_workers=True,
365+
persistent_workers=self.persistent_workers,
364366
**kwargs,
365367
)
366368

@@ -379,7 +381,7 @@ def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]
379381
"validation",
380382
shuffle=False,
381383
num_workers=self.num_workers,
382-
persistent_workers=True,
384+
persistent_workers=self.persistent_workers,
383385
**kwargs,
384386
)
385387

0 commit comments

Comments
 (0)