From 731eb8ab4fc1ee243d7aa8d0dcb2ce6b302f3736 Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 20 Nov 2025 15:59:48 +0800 Subject: [PATCH 01/11] fix conflict --- trinity/common/models/model.py | 20 ++++++++++++++++++- trinity/explorer/workflow_runner.py | 31 +++++++++++++++++++++++++---- 2 files changed, 46 insertions(+), 5 deletions(-) diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index a2ecdd90a4..bd0f36c99c 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -4,7 +4,7 @@ import socket from abc import ABC, abstractmethod from functools import partial -from typing import List, Optional, Sequence, Tuple, Union +from typing import Dict, List, Optional, Sequence, Tuple, Union import httpx import numpy as np @@ -103,7 +103,9 @@ def __init__( self.enable_history = enable_history self.history = [] self.status = RunningStatus.RUNNING + self.workflow_state: Dict = {} self.request_count = 0 + self.state_lock = asyncio.Lock() async def prepare(self) -> None: """Prepare the model wrapper.""" @@ -361,6 +363,22 @@ def extract_experience_from_history(self, clear_history: bool = True) -> List[Ex self.history.clear() return exps + # Workflow state management methods + async def set_workflow_state(self, state: Dict) -> None: + """Set the state of workflow using the model.""" + async with self.state_lock: + self.workflow_state.update(state) + + async def clean_workflow_state(self) -> None: + """Clean the state of workflow using the model.""" + async with self.state_lock: + self.workflow_state = {} + + async def get_workflow_state(self) -> Dict: + """Get the state of workflow using the model.""" + async with self.state_lock: + return self.workflow_state.copy() + def convert_api_output_to_experience( output, diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py index 85af23aa1b..9a44518e7b 100644 --- a/trinity/explorer/workflow_runner.py +++ b/trinity/explorer/workflow_runner.py @@ -74,6 +74,14 @@ def __init__( self.auxiliary_model_async_clients = [] self.workflow_instance: Workflow = None self.runner_id = runner_id + self.runner_state = { + "runner_id": self.runner_id, + "running_workflow_id": None, + "model_version": None, + "begin_time": 0, + "terminate_time": 0, + } + self.lock = asyncio.Lock() async def prepare(self) -> None: """Prepare the runner.""" @@ -121,23 +129,37 @@ async def _run_task( ) -> Tuple[List[Experience], List[Dict]]: """Init workflow from the task and run it.""" self._create_workflow_instance(task) + if self.workflow_instance.repeatable: self.workflow_instance.set_repeat_times(repeat_times, run_id_base) st = time.time() + await self.model_wrapper.clean_workflow_state() + self.runner_state[ + "running_workflow_id" + ] = f"{task.batch_id}/{task.task_id}/{run_id_base}" + self.runner_state["terminate_time"] = None + self.runner_state["begin_time"] = st exps = await self._run_workflow(self.workflow_instance) - task_execution_time = time.time() - st + et = time.time() + self.runner_state["terminate_time"] = et # repeatable workflow cannot calculate run level metrics, we use experience level metrics directly run_metrics = [exp.metrics for exp in exps if exp.metrics] for metric in run_metrics: - metric["time/task_execution"] = task_execution_time + metric["time/task_execution"] = et - st else: exps = [] run_metrics = [] for i in range(repeat_times): st = time.time() + await self.model_wrapper.clean_workflow_state() + self.runner_state["running_workflow_id"] = f"{task.batch_id}/{task.task_id}/{i}" + self.runner_state["terminate_time"] = None + self.runner_state["begin_time"] = st new_exps = await self._run_workflow(self.workflow_instance) + et = time.time() + self.runner_state["terminate_time"] = et run_metric = calculate_run_level_metrics(new_exps) - run_metric["time/task_execution"] = time.time() - st + run_metric["time/task_execution"] = et - st run_metrics.append(run_metric) for exp in new_exps: exp.eid.run = run_id_base + i @@ -156,9 +178,10 @@ async def run_task( # TODO: avoid sending the experiences back to the scheduler to reduce the communication overhead try: st = time.time() + model_version = await self.model_wrapper.model_version_async + self.runner_state["model_version"] = model_version exps, metrics = await self._run_task(task, repeat_times, run_id_base) assert exps is not None and len(exps) > 0, "An empty experience is generated" - model_version = await self.model_wrapper.model_version_async # set eid for each experience for exp in exps: exp.eid.batch = task.batch_id From b438bfca25c44b1f0417bb5fd18fd7653ef3fce5 Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 20 Nov 2025 18:45:07 +0800 Subject: [PATCH 02/11] support collecting workflow status --- tests/explorer/scheduler_test.py | 93 ++++++++++++++++++++++++++++- tests/explorer/workflow_test.py | 92 ++++++++++++++++++++++++++++ trinity/common/config.py | 1 + trinity/explorer/scheduler.py | 74 +++++++++++++++++++++++ trinity/explorer/workflow_runner.py | 7 +++ 5 files changed, 266 insertions(+), 1 deletion(-) diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index 9495b86f7a..0a5fbf7714 100644 --- a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -2,6 +2,7 @@ import time import unittest from typing import List, Optional +from collections import defaultdict import ray import torch @@ -11,7 +12,7 @@ from trinity.common.config import ExperienceBufferConfig from trinity.common.constants import StorageType, SyncStyle from trinity.common.experience import EID, Experience -from trinity.common.models.model import InferenceModel +from trinity.common.models.model import InferenceModel, ModelWrapper from trinity.common.workflows import Task from trinity.common.workflows.workflow import WORKFLOWS, Workflow from trinity.explorer.scheduler import Scheduler @@ -134,6 +135,41 @@ def run(self): raise RuntimeError("This method should not be called") +@WORKFLOWS.register_module("dummy_workflow_with_state") +class DummyWorkflowWithState(Workflow): + can_repeat: bool = True + is_async: bool = True + + def __init__(self, *, task, model: ModelWrapper, auxiliary_models): + super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) + self.step_num = task.workflow_args.get("step_num", 1) + + def set_repeat_times(self, repeat_times, run_id_base): + self.repeat_times = repeat_times + self.run_id_base = run_id_base + + async def run_async(self) -> List[Experience]: + exps = [] + for i in range(self.repeat_times): + run_level_metrics = {"run_metrics": float(i + self.run_id_base)} + run_level_exps = [] + for step in range(self.step_num): + run_level_exps.append( + Experience( + eid=EID(run=i + self.run_id_base, step=step), + tokens=torch.zeros(5), + prompt_length=2, + prompt_text="success", + ) + ) + run_level_exps[-1].metrics = run_level_metrics + self.logger.info(f"Setting workflow state to repeat_cnt={i}") + await self.model.set_workflow_state({"repeat_cnt": i}) + await asyncio.sleep(1) + exps.extend(run_level_exps) + return exps + + @ray.remote class DummyModel(InferenceModel): def sync_model(self, model_version, update_weight_args_list): @@ -779,3 +815,58 @@ def tearDown(self): ray.shutdown() except Exception: pass + + +class TestRunnerStateCollection(unittest.IsolatedAsyncioTestCase): + + async def test_runner_state_collection(self): + ray.init(ignore_reinit_error=True) + config = get_template_config() + config.explorer.runner_per_model = 2 + config.explorer.runner_state_report_interval = 1 + config.explorer.max_repeat_times_per_runner = 2 + config.check_and_update() + scheduler = Scheduler(config, [DummyModel.remote(), DummyModel.remote()]) + # 4 runner in side the scheduler + await scheduler.start() + + tasks = [ + Task( + workflow=DummyWorkflowWithState, # type: ignore[type-abstract] + workflow_args={"step_num": 2}, + repeat_times=4, + raw_task={}, + ) + for _ in range(4) + ] + scheduler.schedule(tasks, batch_id=0) + + async def monitor_routine(): + runner_0_state_history = defaultdict(set) + for _ in range(16): + await asyncio.sleep(0.3) + states = scheduler.get_all_state() + self.assertEqual(len(states), 4) + for state in states.values(): + self.assertIn("runner_id", state) + self.assertIn("running_workflow_id", state) + self.assertIn("model_version", state) + self.assertIn("begin_time", state) + self.assertIn("terminate_time", state) + self.assertIn("repeat_cnt", state) + ids = scheduler.get_key_state("running_workflow_id") + self.assertEqual(len(ids), 4) + self.assertEqual(len(set(ids.values())), 4) + runner_0_state = scheduler.get_runner_state(0) + for key, value in runner_0_state.items(): + runner_0_state_history[key].add(value) + self.assertEqual(len(runner_0_state_history["repeat_cnt"]), 2) # max_repeat_times is 2 + self.assertEqual(len(runner_0_state_history["model_version"]), 1) + self.assertEqual(len(runner_0_state_history["running_workflow_id"]), 2) # split into 2 sub tasks + self.assertEqual(len(runner_0_state_history["begin_time"]), 2) + self.assertEqual(len(runner_0_state_history["runner_id"]), 1) + + await asyncio.gather( + monitor_routine(), + scheduler.get_results(batch_id=0), + ) \ No newline at end of file diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py index 81e1afa916..c53f50e80c 100644 --- a/tests/explorer/workflow_test.py +++ b/tests/explorer/workflow_test.py @@ -2,6 +2,7 @@ """Test for the workflow module""" import asyncio import unittest +from collections import defaultdict from dataclasses import dataclass, field from typing import Dict, Optional from unittest import mock @@ -508,6 +509,49 @@ def tearDown(self): ray.shutdown(_exiting_interpreter=True) +class StateRecordingWorkflow(Workflow): + is_async: bool = True + + def __init__(self, *, task, model: ModelWrapper, auxiliary_models): + super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) + self.wait_time = task.workflow_args.get("wait_time", 1) + + async def run_async(self): + for i in range(self.wait_time): + await self.model.set_workflow_state({"step": i}) + await asyncio.sleep(1) + return [Experience(tokens=Tensor([0, 1, 2]), prompt_length=1, reward=1.0)] + + +class TestWorkflowStateRecording(unittest.IsolatedAsyncioTestCase): + async def test_workflow_state_recording(self): + model = MagicMock() + model_wrapper = ModelWrapper(model, engine_type="vllm") + + task = Task( + workflow=StateRecordingWorkflow, + repeat_times=3, + raw_task={}, + workflow_args={"wait_time": 3}, + ) + workflow = task.to_workflow(model_wrapper) + + async def monitor_routine(): + old_state = {} + count = 0 + for i in range(20): + await asyncio.sleep(0.2) + new_state = await model_wrapper.get_workflow_state() + print(new_state) + if new_state.get("step") != old_state.get("step"): + old_state = new_state + count += 1 + self.assertEqual(count, 3) + return count + + await asyncio.gather(*[monitor_routine(), workflow.run_async()]) + + class TestAgentScopeWorkflowAdapter(unittest.IsolatedAsyncioTestCase): @unittest.skip("Waiting for agentscope>=0.1.6") async def test_adapter(self): @@ -604,3 +648,51 @@ async def test_workflow_runner(self): self.assertTrue(status.ok) self.assertIsInstance(exps, list) self.assertEqual(len(exps), 2) + + async def test_workflow_runner_get_state(self): + config = get_template_config() + + async def mock_get_api_server_url_remote(): + return None + + async def mock_get_model_version_remote(): + return 1 + + model = MagicMock() + model.get_api_server_url.remote = MagicMock(side_effect=mock_get_api_server_url_remote) + model.get_model_version.remote = MagicMock(side_effect=mock_get_model_version_remote) + + runner = WorkflowRunner( + config, + model=model, + auxiliary_models=[], + runner_id=1, + ) + await runner.prepare() + + task = Task( + workflow=StateRecordingWorkflow, + raw_task={}, + workflow_args={"wait_time": 2}, + batch_id=1, + task_id=2, + ) + + async def monitor_routine(): + state_history = defaultdict(set) + count = 0 + for i in range(20): + await asyncio.sleep(0.4) + new_state = await runner.get_runner_state() + for k, v in new_state.items(): + state_history[k].add(v) + self.assertEqual(len(state_history["runner_id"]), 1) + self.assertEqual(len(state_history["model_version"]), 1) + self.assertEqual(len(state_history["running_workflow_id"]), 3) + self.assertEqual(len(state_history["begin_time"]), 3) + self.assertEqual(len(state_history["step"]), 2) + return count + + await asyncio.gather( + *[monitor_routine(), runner.run_task(task, repeat_times=3, run_id_base=0)] + ) diff --git a/trinity/common/config.py b/trinity/common/config.py index 302e7556a5..2df9817071 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -670,6 +670,7 @@ class ExplorerConfig: # Experimental feature over_rollout: OverRolloutConfig = field(default_factory=OverRolloutConfig) dynamic_timeout: DynamicTimeoutConfig = field(default_factory=DynamicTimeoutConfig) + runner_state_report_interval: int = 0 # report runner state every N seconds, 0 means disable @dataclass diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py index ae2fdfe4a6..4bb7783256 100644 --- a/trinity/explorer/scheduler.py +++ b/trinity/explorer/scheduler.py @@ -61,6 +61,7 @@ def __init__( self.timeout = config.explorer.max_timeout self.namespace = ray.get_runtime_context().namespace self.runner = self._create_runner() + self.state = {} def _create_runner(self): return ( @@ -79,6 +80,10 @@ def _create_runner(self): async def prepare(self): await self.runner.prepare.remote() + async def update_state(self) -> None: + """Get the runner state.""" + self.state = await self.runner.get_runner_state.remote() + async def run_with_retry( self, task: TaskWrapper, repeat_times: int, run_id_base: int, timeout: float ) -> Tuple[Status, List, int, float]: @@ -184,6 +189,7 @@ def __init__( self.max_repeat_times = config.explorer.max_repeat_times_per_runner self.default_batch_size = config.buffer.batch_size self.running = False + self.runner_states = [] self.runner_num = len(rollout_model) * config.explorer.runner_per_model self.runners: Dict[int, RunnerWrapper] = dict() @@ -253,6 +259,23 @@ async def _scheduler_loop(self) -> None: await asyncio.sleep(0.1) self.logger.info("Scheduler loop stopped.") + async def _monitor_runner_state_loop(self) -> None: + interval = self.config.explorer.runner_state_report_interval + if interval <= 0: + self.logger.info("Runner state monitoring loop disabled.") + return + + self.logger.info("Runner state monitoring loop started.") + while self.running: + try: + await asyncio.gather(*[runner.update_state() for runner in self.runners.values()]) + except Exception: + self.logger.error( + f"Error in runner state monitoring loop:\n{traceback.format_exc()}" + ) + await asyncio.sleep(0.1) + self.logger.info("Runner state monitoring loop stopped.") + async def _schedule_pending_tasks(self) -> None: if not self.idle_runners: return @@ -331,6 +354,7 @@ async def start(self) -> None: self.running = True await asyncio.gather(*[self._create_runner(i) for i in range(self.runner_num)]) self.scheduler_task = asyncio.create_task(self._scheduler_loop()) + self.monitor_task = asyncio.create_task(self._monitor_runner_state_loop()) ready_refs = [runner.runner.__ray_ready__.remote() for runner in self.runners.values()] await asyncio.gather(*ready_refs) self.logger.info(f"Starting Scheduler with {self.runner_num} runners") @@ -354,6 +378,12 @@ async def stop(self) -> None: await self.scheduler_task except asyncio.CancelledError: pass + if self.monitor_task: + self.monitor_task.cancel() + try: + await self.monitor_task + except asyncio.CancelledError: + pass self.logger.info("Scheduler stopped") def schedule(self, tasks: List[Task], batch_id: Union[int, str]) -> None: @@ -525,3 +555,47 @@ async def wait_all( ) raise TimeoutError(error_msg) + + def get_key_state(self, key: str) -> Dict: + """Get the scheduler state. + + Args: + key (`str`): The key of the state to get. + + Returns: + `Dict`: A dictionary of runner ids to their state for the given key. + """ + result = {} + for runner in self.runners.values(): + runner_state = runner.state + if runner_state and key in runner_state: + result[runner.runner_id] = runner_state[key] + return result + + def get_runner_state(self, runner_id: int) -> Dict: + """Get the scheduler state. + + Args: + runner_id (`int`): The id of the runner. + + Returns: + `Dict`: The state of the runner. + """ + runner = self.runners.get(runner_id, None) + if runner: + return runner.state + else: + return {} + + def get_all_state(self) -> Dict: + """Get all runners' state. + + Returns: + `Dict`: The state of all runners. + """ + result = {} + for runner in self.runners.values(): + runner_state = runner.state + if runner_state: + result[runner.runner_id] = runner_state + return result diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py index 9a44518e7b..9358507c05 100644 --- a/trinity/explorer/workflow_runner.py +++ b/trinity/explorer/workflow_runner.py @@ -168,6 +168,13 @@ async def _run_task( self._create_workflow_instance(task) return exps, run_metrics + async def get_runner_state(self) -> Dict: + """Get the runner state.""" + async with self.lock: + runner_state = self.runner_state.copy() + runner_state.update(await self.model_wrapper.get_workflow_state()) + return runner_state + async def run_task( self, task: Task, From c104202d82983ee9b950cfb8efa892864fcf11d5 Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 20 Nov 2025 18:45:31 +0800 Subject: [PATCH 03/11] fix pre-commit --- tests/explorer/scheduler_test.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index 0a5fbf7714..04c5c5c979 100644 --- a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -1,8 +1,8 @@ import asyncio import time import unittest -from typing import List, Optional from collections import defaultdict +from typing import List, Optional import ray import torch @@ -818,7 +818,6 @@ def tearDown(self): class TestRunnerStateCollection(unittest.IsolatedAsyncioTestCase): - async def test_runner_state_collection(self): ray.init(ignore_reinit_error=True) config = get_template_config() @@ -862,11 +861,13 @@ async def monitor_routine(): runner_0_state_history[key].add(value) self.assertEqual(len(runner_0_state_history["repeat_cnt"]), 2) # max_repeat_times is 2 self.assertEqual(len(runner_0_state_history["model_version"]), 1) - self.assertEqual(len(runner_0_state_history["running_workflow_id"]), 2) # split into 2 sub tasks + self.assertEqual( + len(runner_0_state_history["running_workflow_id"]), 2 + ) # split into 2 sub tasks self.assertEqual(len(runner_0_state_history["begin_time"]), 2) self.assertEqual(len(runner_0_state_history["runner_id"]), 1) await asyncio.gather( monitor_routine(), scheduler.get_results(batch_id=0), - ) \ No newline at end of file + ) From b7143ace9fdd07c8f7c2c117dc744f41bed46f1b Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 20 Nov 2025 19:16:50 +0800 Subject: [PATCH 04/11] fix comments --- tests/explorer/workflow_test.py | 1 - trinity/common/config.py | 4 +++- trinity/explorer/scheduler.py | 2 +- trinity/explorer/workflow_runner.py | 4 +--- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py index c53f50e80c..af62c1b0d9 100644 --- a/tests/explorer/workflow_test.py +++ b/tests/explorer/workflow_test.py @@ -542,7 +542,6 @@ async def monitor_routine(): for i in range(20): await asyncio.sleep(0.2) new_state = await model_wrapper.get_workflow_state() - print(new_state) if new_state.get("step") != old_state.get("step"): old_state = new_state count += 1 diff --git a/trinity/common/config.py b/trinity/common/config.py index 2df9817071..9c19014a51 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -670,7 +670,9 @@ class ExplorerConfig: # Experimental feature over_rollout: OverRolloutConfig = field(default_factory=OverRolloutConfig) dynamic_timeout: DynamicTimeoutConfig = field(default_factory=DynamicTimeoutConfig) - runner_state_report_interval: int = 0 # report runner state every N seconds, 0 means disable + runner_state_report_interval: int = ( + 0 # report runner state every `runner_state_report_interval` seconds, 0 to disable + ) @dataclass diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py index 4bb7783256..416c873de6 100644 --- a/trinity/explorer/scheduler.py +++ b/trinity/explorer/scheduler.py @@ -189,7 +189,6 @@ def __init__( self.max_repeat_times = config.explorer.max_repeat_times_per_runner self.default_batch_size = config.buffer.batch_size self.running = False - self.runner_states = [] self.runner_num = len(rollout_model) * config.explorer.runner_per_model self.runners: Dict[int, RunnerWrapper] = dict() @@ -268,6 +267,7 @@ async def _monitor_runner_state_loop(self) -> None: self.logger.info("Runner state monitoring loop started.") while self.running: try: + await asyncio.sleep(interval) await asyncio.gather(*[runner.update_state() for runner in self.runners.values()]) except Exception: self.logger.error( diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py index 9358507c05..5093b015d6 100644 --- a/trinity/explorer/workflow_runner.py +++ b/trinity/explorer/workflow_runner.py @@ -81,7 +81,6 @@ def __init__( "begin_time": 0, "terminate_time": 0, } - self.lock = asyncio.Lock() async def prepare(self) -> None: """Prepare the runner.""" @@ -170,8 +169,7 @@ async def _run_task( async def get_runner_state(self) -> Dict: """Get the runner state.""" - async with self.lock: - runner_state = self.runner_state.copy() + runner_state = self.runner_state.copy() runner_state.update(await self.model_wrapper.get_workflow_state()) return runner_state From 5a44c0381bab7bfe4ee8e1d0922ace24a019226e Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 20 Nov 2025 19:28:03 +0800 Subject: [PATCH 05/11] fix comments --- tests/explorer/scheduler_test.py | 3 ++- trinity/explorer/scheduler.py | 5 ++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index 04c5c5c979..049920a871 100644 --- a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -822,7 +822,7 @@ async def test_runner_state_collection(self): ray.init(ignore_reinit_error=True) config = get_template_config() config.explorer.runner_per_model = 2 - config.explorer.runner_state_report_interval = 1 + config.explorer.runner_state_report_interval = 0.5 config.explorer.max_repeat_times_per_runner = 2 config.check_and_update() scheduler = Scheduler(config, [DummyModel.remote(), DummyModel.remote()]) @@ -842,6 +842,7 @@ async def test_runner_state_collection(self): async def monitor_routine(): runner_0_state_history = defaultdict(set) + await asyncio.sleep(0.5) # wait for first report for _ in range(16): await asyncio.sleep(0.3) states = scheduler.get_all_state() diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py index 416c873de6..c2e9e29589 100644 --- a/trinity/explorer/scheduler.py +++ b/trinity/explorer/scheduler.py @@ -267,13 +267,12 @@ async def _monitor_runner_state_loop(self) -> None: self.logger.info("Runner state monitoring loop started.") while self.running: try: - await asyncio.sleep(interval) await asyncio.gather(*[runner.update_state() for runner in self.runners.values()]) except Exception: self.logger.error( f"Error in runner state monitoring loop:\n{traceback.format_exc()}" ) - await asyncio.sleep(0.1) + await asyncio.sleep(interval) self.logger.info("Runner state monitoring loop stopped.") async def _schedule_pending_tasks(self) -> None: @@ -354,9 +353,9 @@ async def start(self) -> None: self.running = True await asyncio.gather(*[self._create_runner(i) for i in range(self.runner_num)]) self.scheduler_task = asyncio.create_task(self._scheduler_loop()) - self.monitor_task = asyncio.create_task(self._monitor_runner_state_loop()) ready_refs = [runner.runner.__ray_ready__.remote() for runner in self.runners.values()] await asyncio.gather(*ready_refs) + self.monitor_task = asyncio.create_task(self._monitor_runner_state_loop()) self.logger.info(f"Starting Scheduler with {self.runner_num} runners") async def stop(self) -> None: From 14939cf15cd6253341e57734cdeae7aeb2e197da Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 20 Nov 2025 19:54:41 +0800 Subject: [PATCH 06/11] fix tests --- tests/explorer/workflow_test.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py index af62c1b0d9..e8072228ce 100644 --- a/tests/explorer/workflow_test.py +++ b/tests/explorer/workflow_test.py @@ -602,6 +602,9 @@ def get_openai_client(self): def get_openai_async_client(self): return openai.AsyncOpenAI(api_key="EMPTY") + async def clean_workflow_state(self): + return + @property async def model_version_async(self): return 0 From b6d3a272bfc3ed380226dc3040d32a895a4a9eed Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 20 Nov 2025 20:51:51 +0800 Subject: [PATCH 07/11] print runner state --- trinity/common/config.py | 5 ++--- trinity/explorer/scheduler.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/trinity/common/config.py b/trinity/common/config.py index 9c19014a51..153fabb974 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -670,9 +670,8 @@ class ExplorerConfig: # Experimental feature over_rollout: OverRolloutConfig = field(default_factory=OverRolloutConfig) dynamic_timeout: DynamicTimeoutConfig = field(default_factory=DynamicTimeoutConfig) - runner_state_report_interval: int = ( - 0 # report runner state every `runner_state_report_interval` seconds, 0 to disable - ) + # report runner state every `runner_state_report_interval` seconds, 0 to disable + runner_state_report_interval: int = 0 @dataclass diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py index c2e9e29589..d86859dbc1 100644 --- a/trinity/explorer/scheduler.py +++ b/trinity/explorer/scheduler.py @@ -268,6 +268,7 @@ async def _monitor_runner_state_loop(self) -> None: while self.running: try: await asyncio.gather(*[runner.update_state() for runner in self.runners.values()]) + self.print_all_state() except Exception: self.logger.error( f"Error in runner state monitoring loop:\n{traceback.format_exc()}" @@ -598,3 +599,33 @@ def get_all_state(self) -> Dict: if runner_state: result[runner.runner_id] = runner_state return result + + def print_all_state(self) -> None: + """Print all runners' state in a clear, aligned table format.""" + all_keys = set() + for runner in self.runners.values(): + runner_state = runner.state + if runner_state: + all_keys.update(runner_state.keys()) + all_keys = sorted(all_keys) + # Prepare header + header = ["runner_id"] + all_keys # type: ignore [operator] + # Prepare rows + rows = [] + for runner in self.runners.values(): + runner_state = runner.state or {} + row = [str(runner.runner_id)] + for key in all_keys: + value = runner_state.get(key, "-") + row.append(str(value)) + rows.append(row) + # Calculate column widths + col_widths = [max(len(str(x)) for x in col) for col in zip(header, *rows)] + # Print header + header_line = " | ".join(str(h).ljust(w) for h, w in zip(header, col_widths)) + self.logger.info(header_line) + self.logger.info("-+-".join("-" * w for w in col_widths)) + # Print each row + for row in rows: + line = " | ".join(str(cell).ljust(w) for cell, w in zip(row, col_widths)) + self.logger.info(line) From b60ae7fc6cc8fc6e0f347ad824501edbc3031132 Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 20 Nov 2025 21:03:38 +0800 Subject: [PATCH 08/11] update doc --- .../source/tutorial/trinity_configs.md | 2 ++ .../source_zh/tutorial/trinity_configs.md | 29 ++++++++++++------- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index f25738f546..f3ea387118 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -377,6 +377,7 @@ explorer: dynamic_timeout: enable: false ratio: 3.0 + runner_state_report_interval: 0 ``` - `name`: Name of the explorer. This name will be used as the Ray actor's name, so it must be unique. @@ -397,6 +398,7 @@ explorer: - `dynamic_timeout`: [Experimental] Configurations for dynamic timeout mechanism, which adjusts the timeout for each task based on the average time taken for successful tasks. - `enable`: Whether to enable dynamic timeout. Default is `false`. - `ratio`: The timeout for each task is dynamically set to `average_time_per_success_task * ratio`. Default is `3.0`. +- `runner_state_report_interval`: Workflow runner report interval (in seconds). If set to a value greater than `0`, the workflow runner will periodically report its status to the main explorer process and print it in the command line for monitoring. Default is `0`, meaning this feature is disabled. If you want to use this feature, it is recommended to set it to `10` seconds or longer to minimize performance impact. --- diff --git a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md index c6dc344007..362f696ba3 100644 --- a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md @@ -368,27 +368,34 @@ explorer: tensor_parallel_size: 1 eval_interval: 100 eval_on_startup: True + over_rollout: + ratio: 0.0 + wait_after_min: 30.0 + dynamic_timeout: + enable: false + ratio: 3.0 + runner_state_report_interval: 0 ``` - `name`: explorer 的名称。该名称将用作 Ray actor 的名称,因此必须唯一。 -- `runner_per_model`: 每个 rollout 模型的并行工作流执行器数量。 -- `max_timeout`: 工作流完成的最大时间(秒)。 -- `max_retry_times`: 工作流的最大重试次数。 -- `env_vars`: 为每个工作流执行器设置的环境变量。 +- `runner_per_model`: 每个推理引擎实例所服务的 WorkflowRunner 数量。 +- `max_timeout`: 等待 Workflow 完成的最大时间(秒)。 +- `max_retry_times`: Workflow 失败或超时情况下的最大重试次数。 +- `env_vars`: 为每个 WorkflowRunner 设置的环境变量。 - `rollout_model.engine_type`: 推理引擎类型。支持 `vllm_async` 和 `vllm`,二者的含义相同,都使用了异步引擎。后续版本会只保留 `vllm`。 -- `rollout_model.engine_num`: 推理引擎数量。 -- `rollout_model.tensor_parallel_size`: 张量并行度。 -- `rollout_model.enable_history`: 是否启用模型调用历史记录。若设为 `True`,模型包装器会自动记录模型调用返回的 experience。请定期通过 `extract_experience_from_history` 提取历史,以避免内存溢出。默认为 `False`。 -- `auxiliary_models`: 用于自定义工作流的额外模型。 -- `eval_interval`: 模型评估的间隔(以步为单位)。 -- `eval_on_startup`: 是否在启动时评估模型。更准确地说,是在第 0 步使用原始模型评估,因此重启时不会触发。 +- `rollout_model.engine_num`: 推理引擎实例的数量。 +- `rollout_model.tensor_parallel_size`: 每个实例的张量并行度。 +- `rollout_model.enable_history`: 是否启用模型调用历史记录功能。若设为 `True`,模型会自动记录调用返回的 experience。请定期通过 `extract_experience_from_history` 提取历史,以避免内存溢出。默认为 `False`。 +- `auxiliary_models`: 用于自定义工作流的辅助模型。 +- `eval_interval`: 模型评估的间隔(以 step 为单位)。 +- `eval_on_startup`: 是否在启动时评估模型。更准确地说,是在第 0 步使用原始模型评估,因此中断训练后重启时不会触发该行为。 - `over_rollout`: [实验性] 超量 rollout 机制的配置,允许 explorer 在每个步骤中使用少于完整批次大小的任务继续进行。这在某些任务显著耗时较长的场景中能有效地提高吞吐量。仅当使用动态同步(`synchronizer.sync_style` 不是 `fixed`)时适用。 - `ratio`: explorer 在每个步骤中仅等待 `(1 - ratio) * batch_size` 的任务。默认为 `0.0`,表示等待所有任务。 - `wait_after_min`: 达到最小任务阈值后,等待此秒数后再继续。 - `dynamic_timeout`: [实验性] 动态超时机制的配置,根据成功任务的平均耗时调整每个任务的超时时间。 - `enable`: 是否启用动态超时。默认为 `false`。 - `ratio`: 每个任务的超时时间动态设置为 `average_time_per_success_task * ratio`。默认为 `3.0`。 - +- `runner_state_report_interval`: WorkflowRunner 报告自身状态的时间间隔(秒)。若设为大于 0 的值,工作流执行器会定期将其状态报告给 explorer 主进程并打印在命令行中,以便监控其运行状态。默认为 `0`,表示不启用此功能。 --- ## Synchronizer 配置 From 9faf849ac994fb0f6b8684eddfa14827817d533b Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 21 Nov 2025 12:08:00 +0800 Subject: [PATCH 09/11] simplify workflow runner state --- docs/sphinx_doc/source_zh/tutorial/trinity_configs.md | 3 ++- tests/explorer/scheduler_test.py | 8 +++----- tests/explorer/workflow_test.py | 3 +-- trinity/explorer/workflow_runner.py | 9 +++------ 4 files changed, 9 insertions(+), 14 deletions(-) diff --git a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md index 362f696ba3..9ebab50322 100644 --- a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md @@ -395,7 +395,8 @@ explorer: - `dynamic_timeout`: [实验性] 动态超时机制的配置,根据成功任务的平均耗时调整每个任务的超时时间。 - `enable`: 是否启用动态超时。默认为 `false`。 - `ratio`: 每个任务的超时时间动态设置为 `average_time_per_success_task * ratio`。默认为 `3.0`。 -- `runner_state_report_interval`: WorkflowRunner 报告自身状态的时间间隔(秒)。若设为大于 0 的值,工作流执行器会定期将其状态报告给 explorer 主进程并打印在命令行中,以便监控其运行状态。默认为 `0`,表示不启用此功能。 +- `runner_state_report_interval`: WorkflowRunner 报告自身状态的时间间隔(秒)。若设为大于 0 的值,工作流执行器会定期将其状态报告给 explorer 主进程并打印在命令行中,以便监控其运行状态。默认为 `0`,表示不启用此功能。推荐如需使用此功能,将其设置为 `10` 秒或更长时间以减少对性能的影响。 + --- ## Synchronizer 配置 diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index 049920a871..2a132b57ad 100644 --- a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -848,13 +848,12 @@ async def monitor_routine(): states = scheduler.get_all_state() self.assertEqual(len(states), 4) for state in states.values(): - self.assertIn("runner_id", state) - self.assertIn("running_workflow_id", state) + self.assertIn("workflow_id", state) self.assertIn("model_version", state) self.assertIn("begin_time", state) self.assertIn("terminate_time", state) self.assertIn("repeat_cnt", state) - ids = scheduler.get_key_state("running_workflow_id") + ids = scheduler.get_key_state("workflow_id") self.assertEqual(len(ids), 4) self.assertEqual(len(set(ids.values())), 4) runner_0_state = scheduler.get_runner_state(0) @@ -863,10 +862,9 @@ async def monitor_routine(): self.assertEqual(len(runner_0_state_history["repeat_cnt"]), 2) # max_repeat_times is 2 self.assertEqual(len(runner_0_state_history["model_version"]), 1) self.assertEqual( - len(runner_0_state_history["running_workflow_id"]), 2 + len(runner_0_state_history["workflow_id"]), 2 ) # split into 2 sub tasks self.assertEqual(len(runner_0_state_history["begin_time"]), 2) - self.assertEqual(len(runner_0_state_history["runner_id"]), 1) await asyncio.gather( monitor_routine(), diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py index e8072228ce..e0a86da6f1 100644 --- a/tests/explorer/workflow_test.py +++ b/tests/explorer/workflow_test.py @@ -688,9 +688,8 @@ async def monitor_routine(): new_state = await runner.get_runner_state() for k, v in new_state.items(): state_history[k].add(v) - self.assertEqual(len(state_history["runner_id"]), 1) self.assertEqual(len(state_history["model_version"]), 1) - self.assertEqual(len(state_history["running_workflow_id"]), 3) + self.assertEqual(len(state_history["workflow_id"]), 3) self.assertEqual(len(state_history["begin_time"]), 3) self.assertEqual(len(state_history["step"]), 2) return count diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py index 5093b015d6..314f8b94e0 100644 --- a/trinity/explorer/workflow_runner.py +++ b/trinity/explorer/workflow_runner.py @@ -75,8 +75,7 @@ def __init__( self.workflow_instance: Workflow = None self.runner_id = runner_id self.runner_state = { - "runner_id": self.runner_id, - "running_workflow_id": None, + "workflow_id": None, "model_version": None, "begin_time": 0, "terminate_time": 0, @@ -133,9 +132,7 @@ async def _run_task( self.workflow_instance.set_repeat_times(repeat_times, run_id_base) st = time.time() await self.model_wrapper.clean_workflow_state() - self.runner_state[ - "running_workflow_id" - ] = f"{task.batch_id}/{task.task_id}/{run_id_base}" + self.runner_state["workflow_id"] = f"{task.batch_id}/{task.task_id}/{run_id_base}" self.runner_state["terminate_time"] = None self.runner_state["begin_time"] = st exps = await self._run_workflow(self.workflow_instance) @@ -151,7 +148,7 @@ async def _run_task( for i in range(repeat_times): st = time.time() await self.model_wrapper.clean_workflow_state() - self.runner_state["running_workflow_id"] = f"{task.batch_id}/{task.task_id}/{i}" + self.runner_state["workflow_id"] = f"{task.batch_id}/{task.task_id}/{i}" self.runner_state["terminate_time"] = None self.runner_state["begin_time"] = st new_exps = await self._run_workflow(self.workflow_instance) From 54179beb1e01b16a2698a6bea50df2081ccb7869 Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 21 Nov 2025 17:52:51 +0800 Subject: [PATCH 10/11] remove confusing log --- trinity/trainer/verl/fsdp_workers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py index 3e508c9124..dd977206c8 100644 --- a/trinity/trainer/verl/fsdp_workers.py +++ b/trinity/trainer/verl/fsdp_workers.py @@ -464,7 +464,7 @@ def _build_model_optimizer( # noqa: C901 num_warmup_steps = int(num_warmup_steps_ratio * total_steps) if self.rank == 0: - print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}") + print(f"num_warmup_steps: {num_warmup_steps}") if warmup_style == "constant": actor_lr_scheduler = get_constant_schedule_with_warmup( @@ -1180,7 +1180,7 @@ def _build_critic_model_optimizer(self, config): # noqa: C901 num_warmup_steps = int(num_warmup_steps_ratio * total_steps) if self.rank == 0: - print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}") + print(f"num_warmup_steps: {num_warmup_steps}") from verl.utils.torch_functional import ( get_constant_schedule_with_warmup, From 062dba771f99365cf935717636eb5ec4e929fba1 Mon Sep 17 00:00:00 2001 From: pxc Date: Tue, 25 Nov 2025 14:53:26 +0800 Subject: [PATCH 11/11] add more metrics --- trinity/explorer/explorer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 686b08f080..30ece621d3 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -380,6 +380,7 @@ async def _finish_explore_step(self, step: int, model_version: int) -> None: metric.update(pipeline_metrics) if statuses: metric.update(gather_metrics([status.metrics[0] for status in statuses], "rollout")) + metric["rollout/finished_task_count"] = len(statuses) self.monitor.log(metric, step=step) async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eval") -> None: @@ -392,10 +393,11 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva if eval_step != step: return self.pending_eval_tasks.popleft() - eval_results, _ = await self.scheduler.get_results(batch_id=f"{step}/{eval_task_name}") + statuses, _ = await self.scheduler.get_results(batch_id=f"{step}/{eval_task_name}") + metric[f"{prefix}/{eval_task_name}/finished_task_count"] = len(statuses) metric.update( gather_metrics( - [status.metrics[0] for status in eval_results], f"{prefix}/{eval_task_name}" + [status.metrics[0] for status in statuses], f"{prefix}/{eval_task_name}" ) ) if self.eval_start_time is not None: