     get_timeout,
     init_multigpu_helper,
     MultiProcessTestCase,
+    requires_gloo,
     requires_multicast_support,
     requires_nccl,
     requires_nccl_version,
     skip_if_lt_x_gpu,
     skip_if_rocm_multiprocess,
+    sm_is_or_higher_than,
     TEST_SKIPS,
     with_dist_debug_levels,
     with_nccl_blocking_wait,
@@ -282,17 +284,16 @@ def opts(self, high_priority_stream=False):

     def setUp(self):
         super().setUp()
-
-        # These tests are expected to throw SIGABRT(6); adding the negative sign
-        # bc the test return code is actually -6
-        self.special_return_code_checks = {
-            self.test_nan_assert_float16.__wrapped__: -signal.SIGABRT,
-            self.test_nan_assert_float32.__wrapped__: -signal.SIGABRT,
-            self.test_nan_assert_float64.__wrapped__: -signal.SIGABRT,
-            self.test_nan_assert_bfloat16.__wrapped__: -signal.SIGABRT,
-            self.test_nan_assert_float8_e4m3fn.__wrapped__: -signal.SIGABRT,
-            self.test_nan_assert_float8_e5m2.__wrapped__: -signal.SIGABRT,
-        }
+        # Need to skip return code checking for these tests since the child
+        # processes don't exit cleanly in some cuda versions
+        self.skip_return_code_checks = [
+            self.test_nan_assert_float16.__wrapped__,
+            self.test_nan_assert_float32.__wrapped__,
+            self.test_nan_assert_float64.__wrapped__,
+            self.test_nan_assert_bfloat16.__wrapped__,
+            self.test_nan_assert_float8_e4m3fn.__wrapped__,
+            self.test_nan_assert_float8_e5m2.__wrapped__,
+        ]

         # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
         # that use TORCH_NCCL_BLOCKING_WAIT will test it as expected.
@@ -533,14 +534,14 @@ def test_nan_assert(self, type):

         # confirm enable/disable flag works
         backend._set_enable_nan_check(False)
-        # Note: using all-gather here bc some NCCL/SM version does not support
-        # FP8 reduction
-        pg._allgather_base(output, nan_tensor)
+        pg.allreduce(nan_tensor)

         backend._set_enable_nan_check(True)
-        pg._allgather_base(output, nan_tensor)
+        with self.assertRaises(RuntimeError):
+            # Note: using all-gather here bc FP8 types do not support reduce ops
+            # at the moment
+            pg._allgather_base(output, nan_tensor)
         dist.destroy_process_group()
-
         # reset env
         os.environ["TORCH_NCCL_NAN_CHECK"] = "0"

@@ -575,13 +576,16 @@ def test_nan_rank_filter(self):
     def test_nan_check(self):
         # Not expecting an error, NaN check should not make legit code fail
         device = torch.device(f"cuda:{self.rank:d}")
+        if not sm_is_or_higher_than(device, 8, 0):
+            self.skipTest("bf16 requires sm >= 8.0")
+
         os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
         store = c10d.FileStore(self.file_name, self.world_size)
         c10d.init_process_group(
             backend="nccl", store=store, rank=self.rank, world_size=self.world_size
         )
-        x = torch.ones((10,), device=device) * self.rank
-        t = torch.ones(3, 4, device=device)
+        x = torch.ones((10,), dtype=torch.bfloat16, device=device) * self.rank
+        t = torch.ones(3, 4, dtype=torch.bfloat16, device=device)
         c10d.broadcast(x, src=0)
         c10d.all_reduce(t)
         c10d.barrier()
@@ -2771,6 +2775,14 @@ def hook(work_info: torch._C._distributed_c10d.WorkInfo):
 class NcclErrorHandlingTest(MultiProcessTestCase):
     def setUp(self):
         super().setUp()
+        # Need to skip return code checking for these tests since the child
+        # processes don't exit cleanly.
+        self.skip_return_code_checks = [
+            self.test_nccl_errors_blocking_abort.__wrapped__,
+            self.test_nccl_errors_blocking_sigkill.__wrapped__,
+            self.test_nccl_errors_blocking_sigterm.__wrapped__,
+            self.test_nccl_errors_blocking_nonzero_exit.__wrapped__,
+        ]
         # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
         # that use TORCH_NCCL_BLOCKING_WAIT will test it as expected.
         os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
@@ -2798,19 +2810,12 @@ def blocking_wait_error_msg(self):
     def _run_all_reduce(self, pg):
         pg.allreduce(torch.rand(10).cuda(self.rank))

-    def _reduce_timeout(self):
-        # set heartbeat timeout to a small value so that we don't wait too long
-        # for things to shutdown
-        os.environ["TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"] = "4"
-        os.environ["TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC"] = "1000"
-
     @requires_nccl()
     @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
     @skip_if_lt_x_gpu(3)
     @skip_if_rocm_multiprocess
     @skip_but_pass_in_sandcastle("Test does not pass when run locally")
     def test_nccl_errors_nonblocking(self):
-        self._reduce_timeout()
         # Note: we unset and restore TORCH_NCCL_ASYNC_ERROR_HANDLING for this test
         # since test_c10d_common runs with async error handling by default, but this
         # tests behavior when it is not enabled.
@@ -2841,24 +2846,30 @@ def test_nccl_errors_nonblocking(self):
28412846 "TORCH_NCCL_ASYNC_ERROR_HANDLING"
28422847 ] = prev_nccl_async_error_handling
28432848
2844- @requires_nccl ()
2845- @skip_if_lt_x_gpu (3 )
2846- @skip_if_rocm_multiprocess
2847- def test_nccl_errors_blocking (self ):
2848- self ._reduce_timeout ()
2849- os .environ ["TORCH_NCCL_ASYNC_ERROR_HANDLING" ] = "0"
2849+ def _test_nccl_errors_blocking (self , func ):
28502850 store = c10d .FileStore (self .file_name , self .world_size )
28512851 process_group = c10d .ProcessGroupNCCL (
28522852 store ,
28532853 self .rank ,
28542854 self .world_size ,
2855+ timeout = timedelta (seconds = 10 ),
28552856 )
2856- x = torch .rand (1024 * 1024 ).cuda (self .rank )
2857- process_group .allreduce (x )
2857+ process_group .allreduce (torch .rand (10 ).cuda (self .rank ))
28582858 if self .rank == 0 :
2859- work = process_group .allreduce (x )
2859+ work = process_group .allreduce (torch . rand ( 10 ). cuda ( self . rank ) )
28602860 with self .assertRaisesRegex (dist .DistBackendError , "" ):
2861+ # It seems the error message would be different depending on
2862+ # whether the test is run on CI machine and devGPU. Skipping
2863+ # the error message check to make both sides happy.
28612864 work .wait (timeout = timedelta (seconds = self .op_timeout_sec ))
2865+ # Run some GPU operations to make sure cuda has not gotten stuck.
2866+ # It was observed cuda could get stuck if NCCL communicators were
2867+ # not properly aborted before throwing RuntimeError.
2868+ torch .rand (10 ).cuda (self .rank )
2869+ elif self .rank == 1 :
2870+ # Clean up structures (ex: files for FileStore before going down)
2871+ del process_group
2872+ func ()
28622873
28632874 def _test_barrier_error (self ):
28642875 store = c10d .FileStore (self .file_name , self .world_size )
@@ -2878,19 +2889,60 @@ def _test_barrier_error(self):
             timeout=timedelta(seconds=self.op_timeout_sec)
         )

+    @with_nccl_blocking_wait
+    @requires_nccl()
+    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
+    @skip_if_lt_x_gpu(3)
+    @skip_if_rocm_multiprocess
+    def test_nccl_errors_blocking_clean_exit(self):
+        self._test_nccl_errors_blocking(lambda: sys.exit(0))
+
+    @with_nccl_blocking_wait
+    @requires_nccl()
+    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
+    @skip_if_lt_x_gpu(3)
+    @skip_if_rocm_multiprocess
+    def test_nccl_errors_blocking_nonzero_exit(self):
+        self._test_nccl_errors_blocking(lambda: sys.exit(1))
+
+    @with_nccl_blocking_wait
+    @requires_nccl()
+    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
+    @skip_if_lt_x_gpu(3)
+    @skip_if_rocm_multiprocess
+    @skip_but_pass_in_sandcastle(
+        "Frequently times out see https://github.com/pytorch/pytorch/issues/58920"
+    )
+    def test_nccl_errors_blocking_abort(self):
+        self._test_nccl_errors_blocking(lambda: os.abort())
+
+    @with_nccl_blocking_wait
+    @requires_nccl()
+    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
+    @skip_if_lt_x_gpu(3)
+    @skip_if_rocm_multiprocess
+    def test_nccl_errors_blocking_sigkill(self):
+        self._test_nccl_errors_blocking(lambda: os.kill(os.getpid(), signal.SIGKILL))
+
+    @with_nccl_blocking_wait
+    @requires_nccl()
+    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
+    @skip_if_lt_x_gpu(3)
+    @skip_if_rocm_multiprocess
+    def test_nccl_errors_blocking_sigterm(self):
+        self._test_nccl_errors_blocking(lambda: os.kill(os.getpid(), signal.SIGTERM))
+
     @with_nccl_blocking_wait
     @requires_nccl()
     @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
     @skip_if_lt_x_gpu(3)
     def test_nccl_blocking_wait_with_barrier(self):
-        self._reduce_timeout()
         self._test_barrier_error()

     @requires_nccl()
     @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
     @skip_if_lt_x_gpu(3)
     def test_nccl_non_blocking_wait_with_barrier(self):
-        self._reduce_timeout()
         # test the barrier behavior in the non blocking wait setting
         prev_nccl_async_error_handling = os.environ.get(
             "TORCH_NCCL_ASYNC_ERROR_HANDLING", None
@@ -2961,7 +3013,6 @@ def assert_fut_success(fut):
     @skip_if_rocm_multiprocess
     @skip_if_lt_x_gpu(3)
     def test_restart_pg_after_error(self):
-        self._reduce_timeout()
         # test the barrier behavior in the non blocking wait setting
         prev_nccl_async_error_handling = os.environ.get(
             "TORCH_NCCL_ASYNC_ERROR_HANDLING", None
@@ -3051,6 +3102,45 @@ def test_invalid_nccl_blocking_wait_env(self):
         self._run_invalid_nccl_blocking_wait_env("2147483647")
         self._run_invalid_nccl_blocking_wait_env("4294967295")

+    @with_nccl_blocking_wait
+    @requires_nccl()
+    @requires_gloo()
+    @skip_if_lt_x_gpu(3)
+    def test_nccl_timeout(self):
+        store = c10d.FileStore(self.file_name, self.world_size)
+
+        # Initialize process_group.
+        process_group = c10d.ProcessGroupNCCL(
+            store, self.rank, self.world_size, timeout=timedelta(seconds=10)
+        )
+        # Control gloo pg used as go-ahead signal/barrier
+        # to coordinate btwn ranks.
+        pg_gloo = c10d.ProcessGroupGloo(store, self.rank, self.world_size)
+        failed_collective_timeout = timedelta(milliseconds=100)
+        process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(
+            timeout=timedelta(seconds=5)
+        )
+
+        if self.rank == 0:
+            # This should timeout in about 1 second.
+            # Watchdog may abort timed out work resulting in NCCL error instead of operation timed out.
+            with self.assertRaisesRegex(
+                dist.DistBackendError, self.blocking_wait_error_msg
+            ):
+                process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(
+                    timeout=failed_collective_timeout
+                )
+            # Now do a barrier to tell other rank to go ahead.
+            pg_gloo.barrier().wait()
+        else:
+            # Wait on rank 0 to fail.
+            try:
+                pg_gloo.barrier().wait()
+            except Exception as e:
+                raise ValueError(
+                    f"Rank {self.rank} barrier timed out waiting for rank 0 with error: {str(e)}"
+                ) from e
+

 class NcclUserBufferRegistrationTest(MultiProcessTestCase):
     def setUp(self):