
Commit 7792055

Merge pull request NVIDIA#246 from hexinw-nvidia/signal
fix: Handle signals during rendezvous
2 parents 496e5d1 + 107c84f commit 7792055

File tree

10 files changed: +625 -352 lines changed

src/nvidia_resiliency_ext/fault_tolerance/config.py

Lines changed: 7 additions & 0 deletions
@@ -91,6 +91,10 @@ class FaultToleranceConfig:
     * `install_exception_hook` [bool] if True, installs sys.excepthook to capture uncaught exceptions
       in training worker processes, format and log the traceback, and use os._exit() to exit the
       process reliably. Default: False.
+    * `num_warmup_iterations` [int] number of warmup iterations before monitoring step section and
+      out-of-section timeouts. The first N iterations (relative to cycle start) are excluded from
+      timeout monitoring as they can be significantly slower than steady-state iterations.
+      Default: 5. Can be overridden by workload (e.g., Megatron-LM via init_workload_monitoring).
     * Attribution service (optional):
       - `attrsvc_host` [str] hostname/IP of the attribution service
       - `attrsvc_port` [int] port of the attribution service
@@ -128,6 +132,9 @@ class FaultToleranceConfig:
     min_progress_iterations: int = 200
     progress_update_interval: float = 30.0  # Seconds between sending progress updates to launcher
     install_exception_hook: bool = False
+    num_warmup_iterations: int = (
+        5  # Number of warmup iterations before monitoring step section and out-of-section timeouts
+    )
     # Attribution service configuration (optional)
     attrsvc_host: Optional[str] = None
     attrsvc_port: Optional[int] = None
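The new field is an ordinary config attribute, so it can be set like any other fault-tolerance option. A minimal sketch, assuming the remaining FaultToleranceConfig fields keep the defaults shown above:

    from nvidia_resiliency_ext.fault_tolerance.config import FaultToleranceConfig

    # Exclude the first 10 iterations of each cycle from step-section and
    # out-of-section timeout monitoring (the default is 5).
    cfg = FaultToleranceConfig(num_warmup_iterations=10)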

src/nvidia_resiliency_ext/fault_tolerance/data.py

Lines changed: 10 additions & 1 deletion
@@ -56,6 +56,7 @@ def get_for_current_rank():
             "Could not find the rank of the current process. "
             "Is it a part of a distributed workload?"
         )
+    global_rank = int(global_rank)
     local_rank = int(os.environ.get("LOCAL_RANK", -1))
     host = socket.gethostname()
     pid = os.getpid()
@@ -159,11 +160,19 @@ class InitMsg:
         iteration: Current training iteration if available from workload framework.
             If None, indicates that the workload cannot report iterations,
            and progress tracking should remain disabled.
+        num_warmup_iters: Number of warmup iterations before monitoring step section
+            and out-of-section timeouts. If None, server uses default from config.
     """

-    def __init__(self, rank_info=None, iteration: Optional[int] = None):
+    def __init__(
+        self,
+        rank_info=None,
+        iteration: Optional[int] = None,
+        num_warmup_iters: Optional[int] = None,
+    ):
         self.rank_info = rank_info
         self.iteration = iteration
+        self.num_warmup_iters = num_warmup_iters


 class HeartbeatMsg:
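On the workload side, the extended constructor lets a rank report its warmup budget together with its starting iteration when it announces itself to the rank monitor. A minimal construction sketch (the rank_info value and how the message is sent are placeholders, not shown in this diff):

    from nvidia_resiliency_ext.fault_tolerance.data import InitMsg

    # Report the resumed iteration and a 10-iteration warmup window;
    # num_warmup_iters=None would keep the server-side config default.
    msg = InitMsg(rank_info=None, iteration=1200, num_warmup_iters=10)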

src/nvidia_resiliency_ext/fault_tolerance/ft_rendezvous_barrier.py

Lines changed: 31 additions & 6 deletions
@@ -14,17 +14,20 @@
 import json
 import logging
 import os
+import signal
 import socket
 import threading
 import time
 from collections import defaultdict
 from dataclasses import dataclass
 from datetime import timedelta
 from enum import Enum
+from types import FrameType
 from typing import Any, Dict, List, Optional, Tuple, Union

 from torch.distributed import PrefixStore, Store
 from torch.distributed.elastic.events import NodeState, construct_and_record_rdzv_event
+from torch.distributed.elastic.multiprocessing import SignalException
 from torch.distributed.elastic.rendezvous.api import (
     RendezvousClosedError,
     RendezvousError,
@@ -67,6 +70,24 @@
 log = logging.getLogger(LogConfig.name)


+def _rdzv_signal_exception_handler(sig: int, frame: Optional[FrameType]) -> None:
+    del frame
+    raise SignalException(f"Received signal {sig} during rendezvous", signal.Signals(sig))
+
+
+def _install_rdzv_signal_handlers() -> Dict[signal.Signals, Any]:
+    prev_handlers: Dict[signal.Signals, Any] = {}
+    for sig_to_handle in (signal.SIGTERM, signal.SIGINT):
+        prev_handlers[sig_to_handle] = signal.getsignal(sig_to_handle)
+        signal.signal(sig_to_handle, _rdzv_signal_exception_handler)
+    return prev_handlers
+
+
+def _restore_rdzv_signal_handlers(prev_handlers: Dict[signal.Signals, Any]) -> None:
+    for sig_to_handle, handler in prev_handlers.items():
+        signal.signal(sig_to_handle, handler)
+
+
 def get_method_name(depth=2):
     if len(inspect.stack()) > depth:
         return inspect.stack()[depth].function
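These helpers turn a SIGTERM/SIGINT delivered mid-rendezvous into a SignalException that the surrounding try/except/finally can handle, rather than letting the previous handler terminate the process between store operations. Note that Python only lets the main thread swap signal handlers, which is also where next_rendezvous runs. A standalone sketch of the same pattern, outside the library and simulating the signal with os.kill:

    import os
    import signal

    from torch.distributed.elastic.multiprocessing import SignalException

    def _to_exception(sig, frame):
        # Re-raise the signal as an exception so normal exception handling can run.
        raise SignalException(f"Received signal {sig}", signal.Signals(sig))

    prev = {s: signal.getsignal(s) for s in (signal.SIGTERM, signal.SIGINT)}
    try:
        for s in prev:
            signal.signal(s, _to_exception)
        os.kill(os.getpid(), signal.SIGTERM)  # simulate scancel / Ctrl+C mid-rendezvous
    except SignalException as e:
        print(f"rendezvous interrupted by {e.sigval.name}")
    finally:
        for s, h in prev.items():
            signal.signal(s, h)  # restore whatever handler was installed before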
@@ -853,6 +874,13 @@ def perform_rendezvous(
         # Start timing AFTER Step 0 completes, since hot spares may wait indefinitely at Step 0
         self._rendezvous_start_time = time.monotonic()

+        # Record rendezvous start event - start profiling AFTER waiting for rendezvous to open
+        # This ensures hot spares waiting at Step 0 don't skew the rendezvous performance measurement
+        rendezvous_start_event_id = record_profiling_event(
+            ProfilingEvent.RENDEZVOUS_STARTED,
+            node_id=node_desc,
+        )
+
         # Step 1: Join the rendezvous and get unique identifier
         self._arrived_count = self.store.add(self.arrived_count_key, 1)

@@ -1699,12 +1727,7 @@ def next_rendezvous(self) -> Union[RendezvousInfo, Tuple[Store, int, int]]:
         self._record(message=msg)
         log.info(msg)

-        # Record rendezvous start event
-        rendezvous_start_event_id = record_profiling_event(
-            ProfilingEvent.RENDEZVOUS_STARTED,
-            node_id=self._this_node,
-        )
-
+        prev_signal_handlers = _install_rdzv_signal_handlers()
         try:
             # Check node health and control requests before starting rendezvous
             self.ensure_node_is_healthy()
@@ -1738,6 +1761,8 @@ def next_rendezvous(self) -> Union[RendezvousInfo, Tuple[Store, int, int]]:
                 node_state=NodeState.FAILED,
             )
             raise
+        finally:
+            _restore_rdzv_signal_handlers(prev_signal_handlers)

         msg = (
             f"The node '{self._this_node}' has joined the rendezvous "

src/nvidia_resiliency_ext/fault_tolerance/launcher.py

Lines changed: 21 additions & 12 deletions
@@ -70,14 +70,12 @@
     FT_LAUNCHER_IPC_SOCKET_ENV_VAR,
     FT_RANK_MONITOR_IPC_SOCKET_ENV_VAR,
 )
-from nvidia_resiliency_ext.fault_tolerance.per_cycle_logs import (
-    PerCycleLogsSpecs,
-    PipeBasedLogsSpecs,
-)
+from nvidia_resiliency_ext.fault_tolerance.per_cycle_logs import PipeBasedLogsSpecs
 from nvidia_resiliency_ext.fault_tolerance.progress_tracker import TrainingProgressTracker
 from nvidia_resiliency_ext.fault_tolerance.rank_monitor_server import RankMonitorServer
 from nvidia_resiliency_ext.fault_tolerance.utils import (
     get_processes_by_pgids,
+    is_slurm_job_array,
     patched_method,
     read_obj_from_ipc_stream,
     terminate_mp_processes,
@@ -1657,6 +1655,17 @@ def launch_agent(
         shutdown_rdzv = False
         logger.error(f"Agent .run() raised UnhealthyNodeException: {e}")
         events.record(agent.get_event_failed())
+
+        # Exit behavior depends on deployment mode:
+        # - Job array: raise (exit 1) so replacement job can be launched
+        # - Single job with hot spares: don't raise (instead, exit 0) to avoid killing job
+        #   since --kill-on-bad-exit is the default srun behavior
+        if is_slurm_job_array():
+            logger.info("Job array deployment: exiting with code 1 for replacement.")
+            raise
+        else:
+            logger.info("Single job deployment: exiting with code 0 for hot spare takeover.")
+            # Don't raise - returns None, main() will exit with 0
     except ChildFailedError:
         raise
     except SignalException as e:
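The actual is_slurm_job_array helper lives in fault_tolerance.utils and is not part of this diff; a plausible sketch of such a check, based on the environment variables SLURM exports only for array jobs (an assumption about the implementation, not the library's code):

    import os

    def is_slurm_job_array_sketch() -> bool:
        # SLURM_ARRAY_JOB_ID / SLURM_ARRAY_TASK_ID are set only for jobs
        # submitted with `sbatch --array=...`.
        return "SLURM_ARRAY_JOB_ID" in os.environ and "SLURM_ARRAY_TASK_ID" in os.environ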
@@ -1672,7 +1681,7 @@ def launch_agent(
         else:
             logger.info("All ranks exited gracefully. Launcher exiting without an error.")
     except Exception as e:
-        logger.error(f"Agent .run() raised exception, {e=}", exc_info=True)
+        logger.error(f"Agent .run() raised exception, {e=}")
         events.record(agent.get_event_failed())
         raise
     finally:
@@ -2327,8 +2336,8 @@ def get_args_parser() -> ArgumentParser:
         type=str,
         help="Logging behavior configuration. Options: "
         "(1) None (default): Creates separate log files per rank per restart cycle. "
-        "(2) 'per_cycle': Consolidates all ranks' logs into a single log file per restart cycle. "
-        "(3) Custom entrypoint name from torchrun.logs_specs group for advanced customization.",
+        "(2) Custom entrypoint name from torchrun.logs_specs group for advanced customization. "
+        "Note: For consolidated logging, use --ft-base-logfile instead (automatically uses PipeBasedLogsSpecs).",
     )

     #
@@ -2783,14 +2792,14 @@ def _get_logs_specs_class(logs_specs_name: Optional[str]) -> Type[LogsSpecs]:

     Built-in options:
     - None (default): Uses DefaultLogsSpecs (per-rank log files per cycle)
-    - 'per_cycle': Uses PerCycleLogsSpecs (single log file per cycle for all ranks)
+
+    Note: The legacy 'per_cycle' option has been removed. Use --ft-base-logfile instead,
+    which automatically uses PipeBasedLogsSpecs for consolidated logging.
     """
     logs_specs_cls = None

-    # Handle built-in per_cycle option
-    if logs_specs_name == "per_cycle":
-        logs_specs_cls = PerCycleLogsSpecs
-    elif logs_specs_name is not None:
+    # Try to load from entrypoints
+    if logs_specs_name is not None:
         # Try to load from entrypoints
         eps = metadata.entry_points()
         if hasattr(eps, "select"):  # >= 3.10
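With the built-in 'per_cycle' branch removed, every non-None logs-specs name is resolved through the torchrun.logs_specs entrypoint group. A minimal sketch of that lookup, mirroring the branch above (the group name is standard torchrun machinery; the entrypoint name "my_logs" in the usage comment is hypothetical):

    from importlib import metadata

    def load_logs_specs_class(name: str):
        eps = metadata.entry_points()
        # Python >= 3.10 exposes EntryPoints.select(); older versions return a dict keyed by group.
        if hasattr(eps, "select"):
            group = eps.select(group="torchrun.logs_specs")
        else:
            group = eps.get("torchrun.logs_specs", [])
        for ep in group:
            if ep.name == name:
                return ep.load()  # the registered LogsSpecs subclass
        raise ValueError(f"No torchrun.logs_specs entrypoint named {name!r}")

    # Hypothetical usage: a package registering 'my_logs' under torchrun.logs_specs
    # would be resolved with load_logs_specs_class("my_logs").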
