File tree Expand file tree Collapse file tree 3 files changed +19
-4
lines changed
test/distributed/elastic/multiprocessing
torch/distributed/elastic/multiprocessing Expand file tree Collapse file tree 3 files changed +19
-4
lines changed Original file line number Diff line number Diff line change @@ -568,9 +568,8 @@ def test_binary_exit(self):
568
568
)
569
569
570
570
results = pc .wait (period = 0.1 )
571
-
572
571
self .assertTrue (results .is_failed ())
573
- self .assertEqual (1 , len (results .failures ))
572
+ self .assertEqual (2 , len (results .failures ))
574
573
575
574
failure = results .failures [0 ]
576
575
self .assertEqual (138 , failure .exitcode )
@@ -583,6 +582,13 @@ def test_binary_exit(self):
583
582
self .assertTrue (pc ._stderr_tail .stopped ())
584
583
self .assertTrue (pc ._stdout_tail .stopped ())
585
584
585
+ failure = results .failures [1 ]
586
+ self .assertEqual (- 15 , failure .exitcode )
587
+ self .assertEqual ("SIGTERM" , failure .signal_name ())
588
+ self .assertEqual ("<NONE>" , failure .error_file_data ["message" ])
589
+ # Assert that the failure message contains expected substrings
590
+ self .assertIn ("Signal 15 (SIGTERM) received by PID" , failure .message )
591
+
586
592
def test_binary_raises (self ):
587
593
pc = start_processes (
588
594
name = "echo" ,
Original file line number Diff line number Diff line change 9
9
import argparse
10
10
import os
11
11
import sys
12
+ import time
12
13
13
14
14
15
if __name__ == "__main__" :
23
24
print (f"exit { exitcode } from { rank } " , file = sys .stderr )
24
25
sys .exit (exitcode )
25
26
else :
27
+ time .sleep (1000 )
26
28
print (f"{ args .msg } stdout from { rank } " )
27
29
print (f"{ args .msg } stderr from { rank } " , file = sys .stderr )
Original file line number Diff line number Diff line change @@ -875,8 +875,7 @@ def _start(self):
875
875
for local_rank in range (self .nprocs )
876
876
}
877
877
878
- def _poll (self ) -> Optional [RunProcsResult ]:
879
- done_local_ranks = set ()
878
+ def _capture_process_failures (self , done_local_ranks : set [int ]):
880
879
for local_rank in self ._running_local_ranks :
881
880
handler = self .subprocess_handlers [local_rank ]
882
881
exitcode = handler .proc .poll ()
@@ -891,11 +890,19 @@ def _poll(self) -> Optional[RunProcsResult]:
891
890
)
892
891
# else: --> succeeded; nothing to do
893
892
893
+ def _poll (self ) -> Optional [RunProcsResult ]:
894
+ done_local_ranks : set [int ] = set ()
895
+ self ._capture_process_failures (done_local_ranks )
896
+
894
897
self ._running_local_ranks .difference_update (done_local_ranks )
895
898
896
899
# if ALL procs are finished or ANY have failed
897
900
if not self ._running_local_ranks or self ._failures :
898
901
self .close () # terminate all running procs
902
+ self ._capture_process_failures (
903
+ done_local_ranks
904
+ ) # record SIGTERM and SIGKILL exit codes in self._failures for bookkeeping purposes
905
+
899
906
result = RunProcsResult (
900
907
failures = self ._failures ,
901
908
stdouts = self .stdouts ,
You can’t perform that action at this time.
0 commit comments