Skip to content

Commit acc0cf0

Browse files
authored
Refinements to the num-workers warning (#18737)
1 parent a26424e commit acc0cf0

File tree

4 files changed

+9
-24
lines changed

4 files changed

+9
-24
lines changed

src/lightning/fabric/utilities/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ def suggested_max_num_workers(local_world_size: int) -> int:
447447
if local_world_size < 1:
448448
raise ValueError(f"`local_world_size` should be >= 1, got {local_world_size}.")
449449
cpu_count = _num_cpus_available()
450-
return max(1, cpu_count // local_world_size)
450+
return max(1, cpu_count // local_world_size - 1) # -1 to leave some resources for main process
451451

452452

453453
def _num_cpus_available() -> int:

src/lightning/pytorch/trainer/connectors/data_connector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,7 @@ def _worker_check(trainer: "pl.Trainer", dataloader: object, name: str) -> None:
436436
rank_zero_warn(
437437
f"Consider setting `persistent_workers=True` in '{name}' to speed up the dataloader worker initialization."
438438
)
439-
elif dataloader.num_workers <= 2 < upper_bound or dataloader.num_workers < 2 <= upper_bound:
439+
elif dataloader.num_workers < 2:
440440
# if changed, update the `filterwarnings` snippet in 'speed.html#num-workers'
441441
rank_zero_warn(
442442
f"The '{name}' does not have many workers which may be a bottleneck. Consider increasing the value of the"

tests/tests_fabric/utilities/test_data.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -588,13 +588,13 @@ def test_set_sampler_epoch():
588588
[
589589
(0, 1, 1),
590590
(1, 1, 1),
591-
(2, 1, 2),
591+
(2, 1, 2 - 1),
592592
(1, 2, 1),
593593
(2, 2, 1),
594594
(3, 2, 1),
595-
(4, 2, 2),
595+
(4, 2, 2 - 1),
596596
(4, 3, 1),
597-
(4, 1, 4),
597+
(4, 1, 4 - 1),
598598
],
599599
)
600600
@pytest.mark.parametrize(

tests/tests_pytorch/trainer/connectors/test_data_connector.py

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -139,31 +139,16 @@ def test_dataloader_persistent_workers_performance_warning(num_workers, tmp_path
139139
trainer.fit(model, dataloader)
140140

141141

142-
@pytest.mark.parametrize(
143-
("num_devices", "num_workers", "cpu_count", "expected_warning"),
144-
[
145-
(1, 0, 1, False),
146-
(8, 0, 1, False),
147-
(8, 0, None, False),
148-
(1, 1, None, False),
149-
(1, 2, 2, False),
150-
(1, 1, 8, True),
151-
(1, 2, 8, True),
152-
(1, 3, 8, False),
153-
(4, 1, 8, True),
154-
(4, 2, 8, False),
155-
(8, 2, 8, False),
156-
],
157-
)
142+
@pytest.mark.parametrize(("num_workers", "expected_warning"), [(0, True), (1, True), (2, False), (3, False)])
158143
@mock.patch("lightning.fabric.utilities.data.os.cpu_count")
159144
@mock.patch("lightning.pytorch.trainer.connectors.data_connector.mp.get_start_method", return_value="not_spawn")
160-
def test_worker_check(_, cpu_count_mock, num_devices, num_workers, cpu_count, expected_warning, monkeypatch):
145+
def test_worker_check(_, cpu_count_mock, num_workers, expected_warning, monkeypatch):
161146
monkeypatch.delattr(lightning.fabric.utilities.data.os, "sched_getaffinity", raising=False)
162147
trainer = Mock(spec=Trainer)
163148
dataloader = Mock(spec=DataLoader, persistent_workers=False)
164-
trainer.num_devices = num_devices
149+
trainer.num_devices = 2
165150
dataloader.num_workers = num_workers
166-
cpu_count_mock.return_value = cpu_count
151+
cpu_count_mock.return_value = 8
167152

168153
if expected_warning:
169154
ctx = pytest.warns(UserWarning, match="Consider increasing the value of the `num_workers` argument`")

0 commit comments

Comments
 (0)