Skip to content

Commit fed9783

Browse files
committed
try: persistent_workers=True
1 parent 897b2af commit fed9783

File tree

7 files changed

+9
-7
lines changed

7 files changed

+9
-7
lines changed

tests/parity_fabric/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def get_dataloader(self):
             dataset,
             batch_size=self.batch_size,
             num_workers=2,
+            persistent_workers=True,
         )

     def get_loss_function(self):

tests/parity_pytorch/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,4 +59,5 @@ def train_dataloader(self):
             CIFAR10(root=_PATH_DATASETS, train=True, download=True, transform=self.transform),
             batch_size=32,
             num_workers=1,
+            persistent_workers=True,
         )

tests/tests_fabric/utilities/test_data.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -638,9 +638,9 @@ def test_suggested_max_num_workers_not_triggering_torch_warning(local_world_size

     # The dataloader runs a check in `DataLoader.check_worker_number_rationality`
     with pytest.warns(UserWarning, match="This DataLoader will create"):
-        DataLoader(range(2), num_workers=(cpu_count + 1))
+        DataLoader(range(2), num_workers=(cpu_count + 1), persistent_workers=True)
     with no_warning_call():
-        DataLoader(range(2), num_workers=suggested_max_num_workers(local_world_size))
+        DataLoader(range(2), num_workers=suggested_max_num_workers(local_world_size), persistent_workers=True)


 def test_state():

tests/tests_pytorch/callbacks/test_prediction_writer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def test_prediction_writer_batch_indices(num_workers, tmp_path):
     DummyPredictionWriter.write_on_batch_end = Mock()
     DummyPredictionWriter.write_on_epoch_end = Mock()

-    dataloader = DataLoader(RandomDataset(32, 64), batch_size=4, num_workers=num_workers)
+    dataloader = DataLoader(RandomDataset(32, 64), batch_size=4, num_workers=num_workers, persistent_workers=True)
     model = BoringModel()
     writer = DummyPredictionWriter("batch_and_epoch")
     trainer = Trainer(default_root_dir=tmp_path, logger=False, limit_predict_batches=4, callbacks=writer)

tests/tests_pytorch/helpers/advanced_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,4 +218,4 @@ def configure_optimizers(self):
         return torch.optim.Adam(self.parameters(), lr=0.02)

     def train_dataloader(self):
-        return DataLoader(MNIST(root=_PATH_DATASETS, train=True, download=True), batch_size=128, num_workers=1)
+        return DataLoader(MNIST(root=_PATH_DATASETS, train=True, download=True), batch_size=128, num_workers=1, persistent_workers=True)

tests/tests_pytorch/trainer/connectors/test_data_connector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def test_dataloader_persistent_workers_performance_warning(num_workers, tmp_path
         barebones=True,
     )
     model = TestSpawnBoringModel(warning_expected=(num_workers > 0))
-    dataloader = DataLoader(RandomDataset(32, 64), num_workers=num_workers)
+    dataloader = DataLoader(RandomDataset(32, 64), num_workers=num_workers, persistent_workers=True)
     trainer.fit(model, dataloader)


@@ -252,7 +252,7 @@ def test_update_dataloader_with_multiprocessing_context():
     """This test verifies that `use_distributed_sampler` conserves multiprocessing context."""
     train = RandomDataset(32, 64)
     context = "spawn"
-    train = DataLoader(train, batch_size=32, num_workers=2, multiprocessing_context=context, shuffle=True)
+    train = DataLoader(train, batch_size=32, num_workers=2, multiprocessing_context=context, shuffle=True, persistent_workers=True)
     new_data_loader = _update_dataloader(train, SequentialSampler(train.dataset))
     assert new_data_loader.multiprocessing_context == train.multiprocessing_context

tests/tests_pytorch/trainer/test_dataloaders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -661,7 +661,7 @@ def test_auto_add_worker_init_fn_distributed(tmp_path, monkeypatch):
     num_workers = 2
     batch_size = 2

-    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
+    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, persistent_workers=True)
     seed_everything(0, workers=True)
     trainer = Trainer(default_root_dir=tmp_path, max_epochs=1, accelerator="gpu", devices=2, strategy="ddp_spawn")
     model = MultiProcessModel()

Comments (0)