diff --git a/tests/unit/runtime/test_ds_initialize.py b/tests/unit/runtime/test_ds_initialize.py index b1c69c45f230..99e17612b7d8 100644 --- a/tests/unit/runtime/test_ds_initialize.py +++ b/tests/unit/runtime/test_ds_initialize.py @@ -27,7 +27,7 @@ @pytest.mark.parametrize('method', ['spawn', 'fork', 'forkserver']) def test_start_method_safety(method): import torch.multiprocessing as mp - mp.set_start_method(method) + mp.set_start_method(method, force=True) @pytest.mark.parametrize('zero_stage', [0, 3])