1 change: 1 addition & 0 deletions slime/ray/rollout.py
@@ -129,6 +129,7 @@ def get_num_rollout_per_epoch(self):
     def generate(self, rollout_id):
         start_time = time.time()
         self.rollout_id = rollout_id
+        self.recover_rollout_engines()
         self.health_monitoring_resume()
         if self.args.ci_test and self.args.use_fault_tolerance and rollout_id >= 2:
             self._try_ci_fault_injection()
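Reviewer note: the single added call pairs with the pause() change further down — pause() may kill unhealthy engines while the system is offloaded, so generate() now restores them before monitoring resumes. A minimal sketch of the intended ordering, with hypothetical stand-ins for the two manager methods (their real implementations are not part of this diff):

```python
import time


class RolloutManagerSketch:
    """Illustrative only: mirrors the call order the diff adds to generate()."""

    def __init__(self, rollout_engines):
        # Engine slots; the health monitor clears a slot when it kills an engine.
        self.rollout_engines = rollout_engines

    def _launch_engine(self, idx):
        # Hypothetical stand-in for however slime actually restarts an engine.
        return object()

    def recover_rollout_engines(self):
        # Restart any slot that was killed while generation was idle.
        for i, engine in enumerate(self.rollout_engines):
            if engine is None:
                self.rollout_engines[i] = self._launch_engine(i)

    def health_monitoring_resume(self):
        # Clear the monitor's pause flag so periodic checks start again (stubbed here).
        pass

    def generate(self, rollout_id):
        start_time = time.time()
        self.rollout_id = rollout_id
        # Recover first, so the first post-resume health sweep probes live engines
        # instead of slots that are still being rebuilt.
        self.recover_rollout_engines()
        self.health_monitoring_resume()
```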
73 changes: 55 additions & 18 deletions slime/utils/health_monitor.py
@@ -83,10 +83,17 @@ def stop(self) -> None:
         self._is_checking_enabled = False
 
     def pause(self) -> None:
-        """Pause health checking. Called when engines are offloaded."""
+        """Pause health checking. Called when engines are offloaded.
+
+        Before pausing, performs a final health check to ensure all engines are healthy.
+        Any unhealthy engines will be killed before pausing.
+        """
         if self._pause_event is None:
             return
-        logger.info("Pausing health monitor...")
+        logger.info("Pausing health monitor (running final health check first)...")
+        # Run a final health check before pausing to catch any unhealthy engines
+        if self._is_checking_enabled:
+            self._run_health_checks()
         self._pause_event.set()
         self._is_checking_enabled = False
 
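Reviewer note: with this hunk, every offload becomes a health barrier — one last sweep runs while the engines are still up, and anything unhealthy is killed before monitoring goes quiet. A minimal, self-contained sketch of that pause/resume pattern (threading.Event plus a final check), with a plain callable standing in for _run_health_checks(); slime's actual monitor loop and any locking are not shown in this hunk, so treat this as an assumption-laden illustration:

```python
import threading
import time


class PausableMonitor:
    """Illustrative sketch of pause-with-final-check; not slime's HealthMonitor."""

    def __init__(self, check_fn, interval=5.0):
        self._check_fn = check_fn              # stands in for _run_health_checks()
        self._interval = interval
        self._pause_event = threading.Event()
        self._stop_event = threading.Event()
        self._is_checking_enabled = True
        self._thread = threading.Thread(target=self._loop, daemon=True)

    def start(self):
        self._thread.start()

    def pause(self):
        # Run one last sweep while the engines are still live, then pause the loop.
        # (This sketch does not serialize against an in-flight loop iteration.)
        if self._is_checking_enabled:
            self._check_fn()
        self._pause_event.set()
        self._is_checking_enabled = False

    def resume(self):
        self._pause_event.clear()
        self._is_checking_enabled = True

    def stop(self):
        self._stop_event.set()

    def _loop(self):
        while not self._stop_event.is_set():
            if not self._pause_event.is_set():
                self._check_fn()
            time.sleep(self._interval)
```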
@@ -136,27 +143,57 @@ def _health_monitor_loop(self) -> None:
                 break
 
     def _run_health_checks(self) -> None:
-        for rollout_engine_id, engine in enumerate(self._rollout_manager.rollout_engines):
-            if self._stop_event is not None and self._stop_event.is_set():
-                break
-            if self._pause_event is not None and self._pause_event.is_set():
-                break
-            self._check_engine_health(rollout_engine_id, engine)
+        """Run health checks for all engines in parallel."""
+        engines = self._rollout_manager.rollout_engines
+        if not engines:
+            return
+
+        # Collect all valid engines with their indices
+        engine_tasks = [
+            (i, engine, engine.health_generate.remote(timeout=self._check_timeout))
+            for i, engine in enumerate(engines)
+            if engine is not None
+        ]
 
-    def _check_engine_health(self, rollout_engine_id, engine) -> None:
-        if engine is None:
-            logger.info(f"Skipping health check for engine {rollout_engine_id} (None)")
+        if not engine_tasks:
             return
+
+        # Wait for all health checks in parallel
+        refs = [task for _, _, task in engine_tasks]
         try:
-            ray.get(engine.health_generate.remote(timeout=self._check_timeout))
+            results = ray.get(refs, timeout=self._check_timeout + 5)
+            # All succeeded
+            for (engine_id, _, _), result in zip(engine_tasks, results, strict=True):
+                if result is not True:
+                    logger.error(f"Health check returned non-True for engine {engine_id}: {result}. Killing actor.")
+                    self._kill_engine(rollout_engine_id=engine_id)
+                else:
+                    logger.debug(f"Health check passed for rollout engine {engine_id}")
+        except ray.exceptions.GetTimeoutError:
+            # Timeout - need to check which ones failed
+            logger.warning("Some health checks timed out, checking individual results...")
+            self._check_individual_results(engine_tasks)
         except Exception as e:
-            logger.error(
-                f"Health check failed for rollout engine {rollout_engine_id} (ray timeout or error). Killing actor. Exception: {e}"
-            )
-            self._kill_engine(rollout_engine_id=rollout_engine_id)
-        else:
-            logger.debug(f"Health check passed for rollout engine {rollout_engine_id}")
+            # Some other error - check each one individually
+            logger.warning(f"Batch health check failed with error: {e}, checking individually...")
+            self._check_individual_results(engine_tasks)
 
+    def _check_individual_results(self, engine_tasks: list) -> None:
+        """Check health check results individually after batch failure."""
+        for engine_id, _engine, ref in engine_tasks:
+            try:
+                result = ray.get(ref, timeout=0)  # Non-blocking check
+                if result is not True:
+                    logger.error(f"Health check returned non-True for engine {engine_id}: {result}. Killing actor.")
+                    self._kill_engine(rollout_engine_id=engine_id)
+                else:
+                    logger.debug(f"Health check passed for rollout engine {engine_id}")
+            except ray.exceptions.GetTimeoutError:
+                logger.error(f"Health check timed out for rollout engine {engine_id}. Killing actor.")
+                self._kill_engine(rollout_engine_id=engine_id)
+            except Exception as e:
+                logger.error(f"Health check failed for rollout engine {engine_id}: {e}. Killing actor.")
+                self._kill_engine(rollout_engine_id=engine_id)
+
     def _kill_engine(self, rollout_engine_id: int):
         logger.info(f"Killing engine group {rollout_engine_id}...")