Skip to content

Commit 69123af

Browse files
authored
Fix hanging metrics tests (#5134)
1 parent eb9cb3c commit 69123af

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

tests/metrics/regression/test_ssim.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,7 @@ def _sk_metric(preds, target, data_range, multichannel):
5353
class TestSSIM(MetricTester):
5454
atol = 6e-5
5555

56-
# TODO: for some reason this test hangs with ddp=True
57-
# @pytest.mark.parametrize("ddp", [True, False])
58-
@pytest.mark.parametrize("ddp", [False])
56+
@pytest.mark.parametrize("ddp", [True, False])
5957
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
6058
def test_ssim(self, preds, target, multichannel, ddp, dist_sync_on_step):
6159
self.run_class_metric_test(

tests/metrics/utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@
1111

1212
from pytorch_lightning.metrics import Metric
1313

14+
try:
15+
set_start_method("spawn")
16+
except RuntimeError:
17+
pass
18+
1419
NUM_PROCESSES = 2
1520
NUM_BATCHES = 10
1621
BATCH_SIZE = 32
@@ -165,10 +170,7 @@ def setup_class(self):
165170
"""Setup the metric class. This will spawn the pool of workers that are
166171
used for metric testing and setup_ddp
167172
"""
168-
try:
169-
set_start_method("spawn")
170-
except RuntimeError:
171-
pass
173+
172174
self.poolSize = NUM_PROCESSES
173175
self.pool = Pool(processes=self.poolSize)
174176
self.pool.starmap(setup_ddp, [(rank, self.poolSize) for rank in range(self.poolSize)])

0 commit comments

Comments
 (0)