Skip to content

Commit 4f641aa

Browse files
aschhabrapytorchmergebot
authored andcommitted
capturing exit codes after sigterm/sigkill from torch elastic. (pytorch#160908)
Summary: **Background** Torch Elastic sends SIGKILL/SIGTERM signals if any process fails while others are still running. However, processes terminated by these signals do not generate termination logs, causing confusion. **Solution** Capture exit codes after SIGTERM signals to ensure complete and accurate termination logging. Test Plan: unit tests https://www.internalfb.com/mlhub/pipelines/runs/mast/f773486907-TrainingApplication__13_D79584569?job_attempt=1&version=0&tab=summary&env=PRODUCTION Rollback Plan: Differential Revision: D79584569 Pull Request resolved: pytorch#160908 Approved by: https://github.com/d4l3k
1 parent 8dbac62 commit 4f641aa

File tree

3 files changed

+19
-4
lines changed

3 files changed

+19
-4
lines changed

test/distributed/elastic/multiprocessing/api_test.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -568,9 +568,8 @@ def test_binary_exit(self):
568568
)
569569

570570
results = pc.wait(period=0.1)
571-
572571
self.assertTrue(results.is_failed())
573-
self.assertEqual(1, len(results.failures))
572+
self.assertEqual(2, len(results.failures))
574573

575574
failure = results.failures[0]
576575
self.assertEqual(138, failure.exitcode)
@@ -583,6 +582,13 @@ def test_binary_exit(self):
583582
self.assertTrue(pc._stderr_tail.stopped())
584583
self.assertTrue(pc._stdout_tail.stopped())
585584

585+
failure = results.failures[1]
586+
self.assertEqual(-15, failure.exitcode)
587+
self.assertEqual("SIGTERM", failure.signal_name())
588+
self.assertEqual("<NONE>", failure.error_file_data["message"])
589+
# Assert that the failure message contains expected substrings
590+
self.assertIn("Signal 15 (SIGTERM) received by PID", failure.message)
591+
586592
def test_binary_raises(self):
587593
pc = start_processes(
588594
name="echo",

test/distributed/elastic/multiprocessing/bin/echo1.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import argparse
1010
import os
1111
import sys
12+
import time
1213

1314

1415
if __name__ == "__main__":
@@ -23,5 +24,6 @@
2324
print(f"exit {exitcode} from {rank}", file=sys.stderr)
2425
sys.exit(exitcode)
2526
else:
27+
time.sleep(1000)
2628
print(f"{args.msg} stdout from {rank}")
2729
print(f"{args.msg} stderr from {rank}", file=sys.stderr)

torch/distributed/elastic/multiprocessing/api.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -875,8 +875,7 @@ def _start(self):
875875
for local_rank in range(self.nprocs)
876876
}
877877

878-
def _poll(self) -> Optional[RunProcsResult]:
879-
done_local_ranks = set()
878+
def _capture_process_failures(self, done_local_ranks: set[int]):
880879
for local_rank in self._running_local_ranks:
881880
handler = self.subprocess_handlers[local_rank]
882881
exitcode = handler.proc.poll()
@@ -891,11 +890,19 @@ def _poll(self) -> Optional[RunProcsResult]:
891890
)
892891
# else: --> succeeded; nothing to do
893892

893+
def _poll(self) -> Optional[RunProcsResult]:
894+
done_local_ranks: set[int] = set()
895+
self._capture_process_failures(done_local_ranks)
896+
894897
self._running_local_ranks.difference_update(done_local_ranks)
895898

896899
# if ALL procs are finished or ANY have failed
897900
if not self._running_local_ranks or self._failures:
898901
self.close() # terminate all running procs
902+
self._capture_process_failures(
903+
done_local_ranks
904+
) # log sigterms and sigkill exit codes in the self._failures for bookkeeping purposes
905+
899906
result = RunProcsResult(
900907
failures=self._failures,
901908
stdouts=self.stdouts,

0 commit comments

Comments
 (0)