Commit 54932d8

Revert "[c10d] Add support for testing SIGABRT return (pytorch#153167)"
This reverts commit 03e102d. Reverted pytorch#153167 on behalf of https://github.com/malfet because it broke lint ([comment](pytorch#153167 (comment))).
1 parent c4ef409 commit 54932d8

File tree

2 files changed: +152, -63 lines changed

test/distributed/test_c10d_nccl.py

Lines changed: 127 additions & 37 deletions
@@ -44,11 +44,13 @@
     get_timeout,
     init_multigpu_helper,
     MultiProcessTestCase,
+    requires_gloo,
     requires_multicast_support,
     requires_nccl,
     requires_nccl_version,
     skip_if_lt_x_gpu,
     skip_if_rocm_multiprocess,
+    sm_is_or_higher_than,
     TEST_SKIPS,
     with_dist_debug_levels,
     with_nccl_blocking_wait,
@@ -282,17 +284,16 @@ def opts(self, high_priority_stream=False):
 
     def setUp(self):
         super().setUp()
-
-        # These tests are expected to throw SIGABRT(6); adding the negative sign
-        # bc the test return code is actually -6
-        self.special_return_code_checks = {
-            self.test_nan_assert_float16.__wrapped__: -signal.SIGABRT,
-            self.test_nan_assert_float32.__wrapped__: -signal.SIGABRT,
-            self.test_nan_assert_float64.__wrapped__: -signal.SIGABRT,
-            self.test_nan_assert_bfloat16.__wrapped__: -signal.SIGABRT,
-            self.test_nan_assert_float8_e4m3fn.__wrapped__: -signal.SIGABRT,
-            self.test_nan_assert_float8_e5m2.__wrapped__: -signal.SIGABRT,
-        }
+        # Need to skip return code checking for these tests since the child
+        # processes don't exit cleanly in some cuda versions
+        self.skip_return_code_checks = [
+            self.test_nan_assert_float16.__wrapped__,
+            self.test_nan_assert_float32.__wrapped__,
+            self.test_nan_assert_float64.__wrapped__,
+            self.test_nan_assert_bfloat16.__wrapped__,
+            self.test_nan_assert_float8_e4m3fn.__wrapped__,
+            self.test_nan_assert_float8_e5m2.__wrapped__,
+        ]
 
         # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
         # that use TORCH_NCCL_BLOCKING_WAIT will test it as expected.
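
Note: the skip list stores each test's `.__wrapped__` because the test methods are wrapped by decorators such as `@requires_nccl()`. A minimal sketch of why that attribute exists and how the membership check works; `requires_nccl_stub` is an illustrative stand-in, not the real helper:

```python
import functools

def requires_nccl_stub(fn):
    # Stand-in for decorators like @requires_nccl(): functools.wraps
    # records the original function on the wrapper as __wrapped__.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper

@requires_nccl_stub
def test_nan_assert_float16():
    pass

# The multiprocess harness compares the undecorated function against the
# skip list, which is why __wrapped__ is stored rather than the decorated
# method itself.
skip_return_code_checks = [test_nan_assert_float16.__wrapped__]
assert test_nan_assert_float16.__wrapped__ in skip_return_code_checks
```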
@@ -533,14 +534,14 @@ def test_nan_assert(self, type):
 
         # confirm enable/disable flag works
         backend._set_enable_nan_check(False)
-        # Note: using all-gather here bc some NCCL/SM version does not support
-        # FP8 reduction
-        pg._allgather_base(output, nan_tensor)
+        pg.allreduce(nan_tensor)
 
         backend._set_enable_nan_check(True)
-        pg._allgather_base(output, nan_tensor)
+        with self.assertRaises(RuntimeError):
+            # Note: using all-gather here bc FP8 types do not support reduce ops
+            # at the moment
+            pg._allgather_base(output, nan_tensor)
         dist.destroy_process_group()
-
         # reset env
         os.environ["TORCH_NCCL_NAN_CHECK"] = "0"

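The re-enabled branch expects the collective to raise once the NaN check is back on. A self-contained sketch of that `assertRaises` structure, using a plain `torch.isnan` probe as a stand-in for the real c10d checker:

```python
import unittest

import torch

class NanCheckSketch(unittest.TestCase):
    def test_nan_input_raises(self):
        t = torch.ones(4)
        t[0] = float("nan")
        with self.assertRaises(RuntimeError):
            # Stand-in for a collective guarded by TORCH_NCCL_NAN_CHECK=1.
            if torch.isnan(t).any():
                raise RuntimeError("NaN found in collective input")

if __name__ == "__main__":
    unittest.main()
```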
@@ -575,13 +576,16 @@ def test_nan_rank_filter(self):
     def test_nan_check(self):
         # Not expecting an error, NaN check should not make legit code fail
         device = torch.device(f"cuda:{self.rank:d}")
+        if not sm_is_or_higher_than(device, 8, 0):
+            self.skipTest("bf16 requires sm >= 8.0")
+
         os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
         store = c10d.FileStore(self.file_name, self.world_size)
         c10d.init_process_group(
             backend="nccl", store=store, rank=self.rank, world_size=self.world_size
         )
-        x = torch.ones((10,), device=device) * self.rank
-        t = torch.ones(3, 4, device=device)
+        x = torch.ones((10,), dtype=torch.bfloat16, device=device) * self.rank
+        t = torch.ones(3, 4, dtype=torch.bfloat16, device=device)
         c10d.broadcast(x, src=0)
         c10d.all_reduce(t)
         c10d.barrier()
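
`sm_is_or_higher_than` gates the new bf16 tensors on compute capability. A rough equivalent, assuming a CUDA build of torch where `torch.cuda.get_device_capability` returns a `(major, minor)` tuple; `sm_is_at_least` is a hypothetical name, not the real helper:

```python
import torch

def sm_is_at_least(device: torch.device, major: int, minor: int) -> bool:
    # get_device_capability returns e.g. (8, 0) on A100; tuples compare
    # lexicographically, so this reads as "sm >= major.minor".
    return torch.cuda.get_device_capability(device) >= (major, minor)

# Usage mirroring the diff:
#   if not sm_is_at_least(device, 8, 0):
#       self.skipTest("bf16 requires sm >= 8.0")
```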
@@ -2771,6 +2775,14 @@ def hook(work_info: torch._C._distributed_c10d.WorkInfo):
 class NcclErrorHandlingTest(MultiProcessTestCase):
     def setUp(self):
         super().setUp()
+        # Need to skip return code checking for these tests since the child
+        # processes don't exit cleanly.
+        self.skip_return_code_checks = [
+            self.test_nccl_errors_blocking_abort.__wrapped__,
+            self.test_nccl_errors_blocking_sigkill.__wrapped__,
+            self.test_nccl_errors_blocking_sigterm.__wrapped__,
+            self.test_nccl_errors_blocking_nonzero_exit.__wrapped__,
+        ]
         # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
         # that use TORCH_NCCL_BLOCKING_WAIT will test it as expected.
         os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
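
The four tests re-added to this skip list kill a child rank in different ways. On POSIX, `multiprocessing.Process.exitcode` reports `sys.exit(n)` as `n` but death by signal `N` as `-N`, which is also why the reverted commit mapped its NaN tests to `-signal.SIGABRT` (-6). A minimal sketch, assuming a POSIX platform:

```python
import multiprocessing as mp
import os
import signal
import sys

def crash(mode: str) -> None:
    if mode == "nonzero_exit":
        sys.exit(1)
    elif mode == "abort":
        os.abort()                            # dies from SIGABRT
    elif mode == "sigkill":
        os.kill(os.getpid(), signal.SIGKILL)  # dies from SIGKILL
    elif mode == "sigterm":
        os.kill(os.getpid(), signal.SIGTERM)  # dies from SIGTERM

if __name__ == "__main__":
    expected = {
        "nonzero_exit": 1,
        "abort": -signal.SIGABRT,    # -6
        "sigkill": -signal.SIGKILL,  # -9
        "sigterm": -signal.SIGTERM,  # -15
    }
    for mode, code in expected.items():
        p = mp.Process(target=crash, args=(mode,))
        p.start()
        p.join()
        assert p.exitcode == code, (mode, p.exitcode)
```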
@@ -2798,19 +2810,12 @@ def blocking_wait_error_msg(self):
     def _run_all_reduce(self, pg):
         pg.allreduce(torch.rand(10).cuda(self.rank))
 
-    def _reduce_timeout(self):
-        # set heartbeat timeout to a small value so that we don't wait too long
-        # for things to shutdown
-        os.environ["TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"] = "4"
-        os.environ["TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC"] = "1000"
-
     @requires_nccl()
     @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
     @skip_if_lt_x_gpu(3)
     @skip_if_rocm_multiprocess
     @skip_but_pass_in_sandcastle("Test does not pass when run locally")
     def test_nccl_errors_nonblocking(self):
-        self._reduce_timeout()
         # Note: we unset and restore TORCH_NCCL_ASYNC_ERROR_HANDLING for this test
         # since test_c10d_common runs with async error handling by default, but this
         # tests behavior when it is not enabled.
@@ -2841,24 +2846,30 @@ def test_nccl_errors_nonblocking(self):
                 "TORCH_NCCL_ASYNC_ERROR_HANDLING"
             ] = prev_nccl_async_error_handling
 
-    @requires_nccl()
-    @skip_if_lt_x_gpu(3)
-    @skip_if_rocm_multiprocess
-    def test_nccl_errors_blocking(self):
-        self._reduce_timeout()
-        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"
+    def _test_nccl_errors_blocking(self, func):
         store = c10d.FileStore(self.file_name, self.world_size)
         process_group = c10d.ProcessGroupNCCL(
             store,
             self.rank,
             self.world_size,
+            timeout=timedelta(seconds=10),
         )
-        x = torch.rand(1024 * 1024).cuda(self.rank)
-        process_group.allreduce(x)
+        process_group.allreduce(torch.rand(10).cuda(self.rank))
         if self.rank == 0:
-            work = process_group.allreduce(x)
+            work = process_group.allreduce(torch.rand(10).cuda(self.rank))
             with self.assertRaisesRegex(dist.DistBackendError, ""):
+                # It seems the error message would be different depending on
+                # whether the test is run on CI machine and devGPU. Skipping
+                # the error message check to make both sides happy.
                 work.wait(timeout=timedelta(seconds=self.op_timeout_sec))
+            # Run some GPU operations to make sure cuda has not gotten stuck.
+            # It was observed cuda could get stuck if NCCL communicators were
+            # not properly aborted before throwing RuntimeError.
+            torch.rand(10).cuda(self.rank)
+        elif self.rank == 1:
+            # Clean up structures (ex: files for FileStore before going down)
+            del process_group
+            func()
 
     def _test_barrier_error(self):
         store = c10d.FileStore(self.file_name, self.world_size)
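
One detail worth calling out in the restored helper: `assertRaisesRegex(dist.DistBackendError, "")` checks only the exception type, because an empty pattern is found in every string; the in-diff comment explains that the message differs between CI machines and devGPU. A small illustration of that regex property:

```python
import re
import unittest

class EmptyPatternSketch(unittest.TestCase):
    def test_empty_pattern_matches_any_message(self):
        # assertRaisesRegex uses re.search, and re.search("", s) matches
        # every string s, so only the exception type is enforced here.
        with self.assertRaisesRegex(RuntimeError, ""):
            raise RuntimeError("any message at all")

assert re.search("", "any message at all") is not None
```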
@@ -2878,19 +2889,60 @@ def _test_barrier_error(self):
                 timeout=timedelta(seconds=self.op_timeout_sec)
             )
 
+    @with_nccl_blocking_wait
+    @requires_nccl()
+    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
+    @skip_if_lt_x_gpu(3)
+    @skip_if_rocm_multiprocess
+    def test_nccl_errors_blocking_clean_exit(self):
+        self._test_nccl_errors_blocking(lambda: sys.exit(0))
+
+    @with_nccl_blocking_wait
+    @requires_nccl()
+    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
+    @skip_if_lt_x_gpu(3)
+    @skip_if_rocm_multiprocess
+    def test_nccl_errors_blocking_nonzero_exit(self):
+        self._test_nccl_errors_blocking(lambda: sys.exit(1))
+
+    @with_nccl_blocking_wait
+    @requires_nccl()
+    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
+    @skip_if_lt_x_gpu(3)
+    @skip_if_rocm_multiprocess
+    @skip_but_pass_in_sandcastle(
+        "Frequently times out see https://github.com/pytorch/pytorch/issues/58920"
+    )
+    def test_nccl_errors_blocking_abort(self):
+        self._test_nccl_errors_blocking(lambda: os.abort())
+
+    @with_nccl_blocking_wait
+    @requires_nccl()
+    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
+    @skip_if_lt_x_gpu(3)
+    @skip_if_rocm_multiprocess
+    def test_nccl_errors_blocking_sigkill(self):
+        self._test_nccl_errors_blocking(lambda: os.kill(os.getpid(), signal.SIGKILL))
+
+    @with_nccl_blocking_wait
+    @requires_nccl()
+    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
+    @skip_if_lt_x_gpu(3)
+    @skip_if_rocm_multiprocess
+    def test_nccl_errors_blocking_sigterm(self):
+        self._test_nccl_errors_blocking(lambda: os.kill(os.getpid(), signal.SIGTERM))
+
     @with_nccl_blocking_wait
     @requires_nccl()
     @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
     @skip_if_lt_x_gpu(3)
     def test_nccl_blocking_wait_with_barrier(self):
-        self._reduce_timeout()
         self._test_barrier_error()
 
     @requires_nccl()
     @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
     @skip_if_lt_x_gpu(3)
     def test_nccl_non_blocking_wait_with_barrier(self):
-        self._reduce_timeout()
         # test the barrier behavior in the non blocking wait setting
         prev_nccl_async_error_handling = os.environ.get(
             "TORCH_NCCL_ASYNC_ERROR_HANDLING", None
@@ -2961,7 +3013,6 @@ def assert_fut_success(fut):
     @skip_if_rocm_multiprocess
     @skip_if_lt_x_gpu(3)
     def test_restart_pg_after_error(self):
-        self._reduce_timeout()
         # test the barrier behavior in the non blocking wait setting
         prev_nccl_async_error_handling = os.environ.get(
             "TORCH_NCCL_ASYNC_ERROR_HANDLING", None
@@ -3051,6 +3102,45 @@ def test_invalid_nccl_blocking_wait_env(self):
         self._run_invalid_nccl_blocking_wait_env("2147483647")
         self._run_invalid_nccl_blocking_wait_env("4294967295")
 
+    @with_nccl_blocking_wait
+    @requires_nccl()
+    @requires_gloo()
+    @skip_if_lt_x_gpu(3)
+    def test_nccl_timeout(self):
+        store = c10d.FileStore(self.file_name, self.world_size)
+
+        # Initialize process_group.
+        process_group = c10d.ProcessGroupNCCL(
+            store, self.rank, self.world_size, timeout=timedelta(seconds=10)
+        )
+        # Control gloo pg used as go-ahead signal/barrier
+        # to coordinate btwn ranks.
+        pg_gloo = c10d.ProcessGroupGloo(store, self.rank, self.world_size)
+        failed_collective_timeout = timedelta(milliseconds=100)
+        process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(
+            timeout=timedelta(seconds=5)
+        )
+
+        if self.rank == 0:
+            # This should timeout in about 1 second.
+            # Watchdog may abort timed out work resulting in NCCL error instead of operation timed out.
+            with self.assertRaisesRegex(
+                dist.DistBackendError, self.blocking_wait_error_msg
+            ):
+                process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(
+                    timeout=failed_collective_timeout
+                )
+            # Now do a barrier to tell other rank to go ahead.
+            pg_gloo.barrier().wait()
+        else:
+            # Wait on rank 0 to fail.
+            try:
+                pg_gloo.barrier().wait()
+            except Exception as e:
+                raise ValueError(
+                    f"Rank {self.rank} barrier timed out waiting for rank 0 with error: {str(e)}"
+                ) from e
+
 
 class NcclUserBufferRegistrationTest(MultiProcessTestCase):
     def setUp(self):
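
The restored `test_nccl_timeout` pairs the NCCL group with a gloo group as a CPU-side control channel: the gloo barrier still works after the NCCL group has timed out or been aborted, so rank 0 can signal the other ranks to proceed. A minimal sketch of that two-group setup, reusing the same calls as the test; `build_groups` is an illustrative helper, not part of the codebase:

```python
from datetime import timedelta

import torch.distributed as c10d

def build_groups(file_name: str, rank: int, world_size: int):
    store = c10d.FileStore(file_name, world_size)
    # Data-plane group: the collectives under test, with a short timeout.
    nccl_pg = c10d.ProcessGroupNCCL(
        store, rank, world_size, timeout=timedelta(seconds=10)
    )
    # Control-plane group: a CPU barrier that survives NCCL failures.
    gloo_pg = c10d.ProcessGroupGloo(store, rank, world_size)
    return nccl_pg, gloo_pg
```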

torch/testing/_internal/common_distributed.py

Lines changed: 25 additions & 26 deletions
@@ -640,15 +640,7 @@ def __init__(
 
     def setUp(self) -> None:
         super().setUp()
-
-        # Used for tests that are expected to return a non-0 exit code, such as
-        # SIGABRT thrown by watchdog.
-        self.special_return_code_checks: dict = {}
-
-        # Used for tests that may return any exit code, which makes it hard to
-        # check. This is rare, use with caution.
-        self.skip_return_code_checks: list = []
-
+        self.skip_return_code_checks = []  # type: ignore[var-annotated]
         self.processes = []  # type: ignore[var-annotated]
         self.rank = self.MAIN_PROCESS_RANK
         self.file_name = tempfile.NamedTemporaryFile(delete=False).name
@@ -873,13 +865,28 @@ def _join_processes(self, fn) -> None:
                 time.sleep(0.1)
 
             elapsed_time = time.time() - start_time
-            self._check_return_codes(fn, elapsed_time)
+
+            if fn in self.skip_return_code_checks:
+                self._check_no_test_errors(elapsed_time)
+            else:
+                self._check_return_codes(elapsed_time)
         finally:
             # Close all pipes
             for pipe in self.pid_to_pipe.values():
                 pipe.close()
 
-    def _check_return_codes(self, fn, elapsed_time) -> None:
+    def _check_no_test_errors(self, elapsed_time) -> None:
+        """
+        Checks that we didn't have any errors thrown in the child processes.
+        """
+        for i, p in enumerate(self.processes):
+            if p.exitcode is None:
+                raise RuntimeError(
+                    f"Process {i} timed out after {elapsed_time} seconds"
+                )
+            self.assertNotEqual(self.TEST_ERROR_EXIT_CODE, p.exitcode)
+
+    def _check_return_codes(self, elapsed_time) -> None:
         """
         Checks that the return codes of all spawned processes match, and skips
         tests if they returned a return code indicating a skipping condition.
@@ -921,11 +928,11 @@ def _check_return_codes(self, fn, elapsed_time) -> None:
                 raise RuntimeError(
                     f"Process {i} terminated or timed out after {elapsed_time} seconds"
                 )
-
-            # Skip the test return code check
-            if fn in self.skip_return_code_checks:
-                return
-
+            self.assertEqual(
+                p.exitcode,
+                first_process.exitcode,
+                msg=f"Expect process {i} exit code to match Process 0 exit code of {first_process.exitcode}, but got {p.exitcode}",
+            )
         for skip in TEST_SKIPS.values():
             if first_process.exitcode == skip.exit_code:
                 if IS_SANDCASTLE:
@@ -941,18 +948,10 @@ def _check_return_codes(self, fn, elapsed_time) -> None:
                     return
                 else:
                     raise unittest.SkipTest(skip.message)
-
-        # In most cases, we expect test to return exit code 0, standing for success.
-        expected_return_code = 0
-        # In some negative tests, we expect test to return non-zero exit code,
-        # such as watchdog throwing SIGABRT.
-        if fn in self.special_return_code_checks:
-            expected_return_code = self.special_return_code_checks[fn]
-
         self.assertEqual(
             first_process.exitcode,
-            expected_return_code,
-            msg=f"Expected exit code {expected_return_code} but got {first_process.exitcode} for pid: {first_process.pid}",
+            0,
+            msg=f"Expected zero exit code but got {first_process.exitcode} for pid: {first_process.pid}",
         )
 
     @property
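
The `TEST_SKIPS` loop above converts reserved child exit codes back into `unittest.SkipTest` in the parent. A self-contained sketch of that protocol; the key names and the exit code here are illustrative, not the real table:

```python
import sys
import unittest
from collections import namedtuple

TestSkip = namedtuple("TestSkip", ["exit_code", "message"])

SKIPS = {
    "no-gpu": TestSkip(75, "Need at least 1 CUDA device"),  # illustrative
}

def child_main(have_gpu: bool) -> None:
    # A child that cannot run the test exits with the reserved code.
    if not have_gpu:
        sys.exit(SKIPS["no-gpu"].exit_code)

def parent_check(first_exitcode: int) -> None:
    # The parent maps the reserved code back to a unittest skip.
    for skip in SKIPS.values():
        if first_exitcode == skip.exit_code:
            raise unittest.SkipTest(skip.message)
```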
