Skip to content

Commit 8bbe907

Browse files
committed
update
1 parent d9a620a commit 8bbe907

File tree

3 files changed

+3
-6
lines changed

3 files changed

+3
-6
lines changed

tests/tests_pytorch/callbacks/test_prediction_writer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def test_prediction_writer_batch_indices(num_workers, tmp_path):
     DummyPredictionWriter.write_on_batch_end = Mock()
     DummyPredictionWriter.write_on_epoch_end = Mock()

-    dataloader = DataLoader(RandomDataset(32, 64), batch_size=4, num_workers=num_workers, persistent_workers=True)
+    dataloader = DataLoader(RandomDataset(32, 64), batch_size=4, num_workers=num_workers, persistent_workers=num_workers > 0)
     model = BoringModel()
     writer = DummyPredictionWriter("batch_and_epoch")
     trainer = Trainer(default_root_dir=tmp_path, logger=False, limit_predict_batches=4, callbacks=writer)

tests/tests_pytorch/trainer/connectors/test_data_connector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def test_dataloader_persistent_workers_performance_warning(num_workers, tmp_path
         barebones=True,
     )
     model = TestSpawnBoringModel(warning_expected=(num_workers > 0))
-    dataloader = DataLoader(RandomDataset(32, 64), num_workers=num_workers, persistent_workers=True)
+    dataloader = DataLoader(RandomDataset(32, 64), num_workers=num_workers, persistent_workers=num_workers > 0)
     trainer.fit(model, dataloader)

tests/tests_pytorch/trainer/test_dataloaders.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -658,10 +658,7 @@ def on_train_epoch_end(self):
 def test_auto_add_worker_init_fn_distributed(tmp_path, monkeypatch):
     """Test that the lightning worker_init_fn takes care of dataloaders in multi-gpu/multi-node training."""
     dataset = NumpyRandomDataset()
-    num_workers = 2
-    batch_size = 2
-
-    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, persistent_workers=True)
+    dataloader = DataLoader(dataset, batch_size=2, num_workers=2, persistent_workers=True)
     seed_everything(0, workers=True)
     trainer = Trainer(default_root_dir=tmp_path, max_epochs=1, accelerator="gpu", devices=2, strategy="ddp_spawn")
     model = MultiProcessModel()

0 commit comments

Comments
 (0)