Commit ee05b3e

Support recording workflow running status (#397)
1 parent 01423d6 commit ee05b3e

File tree

10 files changed: +365 -20 lines changed


docs/sphinx_doc/source/tutorial/trinity_configs.md

Lines changed: 2 additions & 0 deletions
@@ -379,6 +379,7 @@ explorer:
   dynamic_timeout:
     enable: false
     ratio: 3.0
+  runner_state_report_interval: 0
 ```

 - `name`: Name of the explorer. This name will be used as the Ray actor's name, so it must be unique.
@@ -399,6 +400,7 @@ explorer:
 - `dynamic_timeout`: [Experimental] Configurations for the dynamic timeout mechanism, which adjusts the timeout for each task based on the average time taken for successful tasks.
   - `enable`: Whether to enable dynamic timeout. Default is `false`.
   - `ratio`: The timeout for each task is dynamically set to `average_time_per_success_task * ratio`. Default is `3.0`.
+- `runner_state_report_interval`: Workflow runner report interval (in seconds). If set to a value greater than `0`, the workflow runner will periodically report its status to the main explorer process and print it on the command line for monitoring. Default is `0`, meaning this feature is disabled. If you want to use this feature, it is recommended to set it to `10` seconds or longer to minimize the performance impact.

 ---
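Putting the new key in context, an `explorer` section that enables the reporting feature might look like the sketch below. The values are illustrative only (`10` follows the documented recommendation); the surrounding keys mirror the doc snippet in this file.

```yaml
explorer:
  dynamic_timeout:
    enable: false
    ratio: 3.0
  # >0 enables periodic runner state reporting; 10 s or longer is recommended
  runner_state_report_interval: 10
```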

docs/sphinx_doc/source_zh/tutorial/trinity_configs.md

Lines changed: 18 additions & 10 deletions
@@ -370,26 +370,34 @@ explorer:
     tensor_parallel_size: 1
   eval_interval: 100
   eval_on_startup: True
+  over_rollout:
+    ratio: 0.0
+    wait_after_min: 30.0
+  dynamic_timeout:
+    enable: false
+    ratio: 3.0
+  runner_state_report_interval: 0
 ```

 - `name`: Name of the explorer. This name will be used as the Ray actor's name, so it must be unique.
-- `runner_per_model`: Number of parallel workflow runners per rollout model
-- `max_timeout`: Maximum time (in seconds) for a workflow to complete.
-- `max_retry_times`: Maximum number of retries for a workflow
-- `env_vars`: Environment variables set for each workflow runner
+- `runner_per_model`: Number of WorkflowRunners served by each inference engine instance
+- `max_timeout`: Maximum time (in seconds) to wait for a workflow to complete.
+- `max_retry_times`: Maximum number of retries when a workflow fails or times out
+- `env_vars`: Environment variables set for each WorkflowRunner
 - `rollout_model.engine_type`: Inference engine type. Supports `vllm_async` and `vllm`; the two are equivalent and both use the async engine. Later versions will keep only `vllm`.
-- `rollout_model.engine_num`: Number of inference engines
-- `rollout_model.tensor_parallel_size`: Tensor parallelism degree
-- `rollout_model.enable_history`: Whether to enable recording of model call history. If set to `True`, the model wrapper automatically records the experiences returned by model calls. Extract the history periodically via `extract_experience_from_history` to avoid running out of memory. Default is `False`.
-- `auxiliary_models`: Additional models for custom workflows
-- `eval_interval`: Interval between model evaluations (in steps).
-- `eval_on_startup`: Whether to evaluate the model on startup. More precisely, evaluation runs at step 0 with the original model, so it is not triggered on restart
+- `rollout_model.engine_num`: Number of inference engine instances
+- `rollout_model.tensor_parallel_size`: Tensor parallelism degree of each instance
+- `rollout_model.enable_history`: Whether to enable the model call history recording feature. If set to `True`, the model automatically records the experiences returned by calls. Extract the history periodically via `extract_experience_from_history` to avoid running out of memory. Default is `False`.
+- `auxiliary_models`: Auxiliary models for custom workflows
+- `eval_interval`: Interval between model evaluations (in steps).
+- `eval_on_startup`: Whether to evaluate the model on startup. More precisely, evaluation runs at step 0 with the original model, so it is not triggered when restarting after an interruption
 - `over_rollout`: [Experimental] Configuration for the over-rollout mechanism, which allows the explorer to proceed at each step with fewer than a full batch of tasks. This effectively improves throughput in scenarios where some tasks take significantly longer. Only applicable when dynamic synchronization is used (`synchronizer.sync_style` is not `fixed`).
   - `ratio`: At each step, the explorer waits for only `(1 - ratio) * batch_size` tasks. Default is `0.0`, meaning it waits for all tasks.
   - `wait_after_min`: After the minimum task threshold is reached, wait this many seconds before proceeding.
 - `dynamic_timeout`: [Experimental] Configuration for the dynamic timeout mechanism, which adjusts each task's timeout based on the average time taken by successful tasks.
   - `enable`: Whether to enable dynamic timeout. Default is `false`.
   - `ratio`: Each task's timeout is dynamically set to `average_time_per_success_task * ratio`. Default is `3.0`.
+- `runner_state_report_interval`: Interval (in seconds) at which each WorkflowRunner reports its own state. If set to a value greater than 0, the workflow runner periodically reports its state to the main explorer process and prints it on the command line for monitoring. Default is `0`, meaning the feature is disabled. If you want to use this feature, setting it to `10` seconds or longer is recommended to reduce the performance impact.

 ---

tests/explorer/scheduler_test.py

Lines changed: 92 additions & 1 deletion
@@ -1,6 +1,7 @@
 import asyncio
 import time
 import unittest
+from collections import defaultdict
 from typing import List, Optional

 import ray
@@ -11,7 +12,7 @@
 from trinity.common.config import ExperienceBufferConfig
 from trinity.common.constants import StorageType, SyncStyle
 from trinity.common.experience import EID, Experience
-from trinity.common.models.model import InferenceModel
+from trinity.common.models.model import InferenceModel, ModelWrapper
 from trinity.common.workflows import Task
 from trinity.common.workflows.workflow import WORKFLOWS, Workflow
 from trinity.explorer.scheduler import Scheduler
@@ -134,6 +135,41 @@ def run(self):
         raise RuntimeError("This method should not be called")


+@WORKFLOWS.register_module("dummy_workflow_with_state")
+class DummyWorkflowWithState(Workflow):
+    can_repeat: bool = True
+    is_async: bool = True
+
+    def __init__(self, *, task, model: ModelWrapper, auxiliary_models):
+        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
+        self.step_num = task.workflow_args.get("step_num", 1)
+
+    def set_repeat_times(self, repeat_times, run_id_base):
+        self.repeat_times = repeat_times
+        self.run_id_base = run_id_base
+
+    async def run_async(self) -> List[Experience]:
+        exps = []
+        for i in range(self.repeat_times):
+            run_level_metrics = {"run_metrics": float(i + self.run_id_base)}
+            run_level_exps = []
+            for step in range(self.step_num):
+                run_level_exps.append(
+                    Experience(
+                        eid=EID(run=i + self.run_id_base, step=step),
+                        tokens=torch.zeros(5),
+                        prompt_length=2,
+                        prompt_text="success",
+                    )
+                )
+            run_level_exps[-1].metrics = run_level_metrics
+            self.logger.info(f"Setting workflow state to repeat_cnt={i}")
+            await self.model.set_workflow_state({"repeat_cnt": i})
+            await asyncio.sleep(1)
+            exps.extend(run_level_exps)
+        return exps
+
+
 @ray.remote
 class DummyModel(InferenceModel):
     def sync_model(self, model_version, update_weight_args_list):
@@ -779,3 +815,58 @@ def tearDown(self):
             ray.shutdown()
         except Exception:
             pass
+
+
+class TestRunnerStateCollection(unittest.IsolatedAsyncioTestCase):
+    async def test_runner_state_collection(self):
+        ray.init(ignore_reinit_error=True)
+        config = get_template_config()
+        config.explorer.runner_per_model = 2
+        config.explorer.runner_state_report_interval = 0.5
+        config.explorer.max_repeat_times_per_runner = 2
+        config.check_and_update()
+        scheduler = Scheduler(config, [DummyModel.remote(), DummyModel.remote()])
+        # 4 runners inside the scheduler
+        await scheduler.start()
+
+        tasks = [
+            Task(
+                workflow=DummyWorkflowWithState,  # type: ignore[type-abstract]
+                workflow_args={"step_num": 2},
+                repeat_times=4,
+                raw_task={},
+            )
+            for _ in range(4)
+        ]
+        scheduler.schedule(tasks, batch_id=0)
+
+        async def monitor_routine():
+            runner_0_state_history = defaultdict(set)
+            await asyncio.sleep(0.5)  # wait for first report
+            for _ in range(16):
+                await asyncio.sleep(0.3)
+                states = scheduler.get_all_state()
+                self.assertEqual(len(states), 4)
+                for state in states.values():
+                    self.assertIn("workflow_id", state)
+                    self.assertIn("model_version", state)
+                    self.assertIn("begin_time", state)
+                    self.assertIn("terminate_time", state)
+                    self.assertIn("repeat_cnt", state)
+                ids = scheduler.get_key_state("workflow_id")
+                self.assertEqual(len(ids), 4)
+                self.assertEqual(len(set(ids.values())), 4)
+                runner_0_state = scheduler.get_runner_state(0)
+                for key, value in runner_0_state.items():
+                    runner_0_state_history[key].add(value)
+            self.assertEqual(len(runner_0_state_history["repeat_cnt"]), 2)  # max_repeat_times is 2
+            self.assertEqual(len(runner_0_state_history["model_version"]), 1)
+            self.assertEqual(
+                len(runner_0_state_history["workflow_id"]), 2
+            )  # split into 2 sub tasks
+            self.assertEqual(len(runner_0_state_history["begin_time"]), 2)
+
+        await asyncio.gather(
+            monitor_routine(),
+            scheduler.get_results(batch_id=0),
+        )
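The test above pairs a state-publishing workflow with a polling monitor via `asyncio.gather`. A minimal standalone sketch of that pattern follows; all names here are illustrative stand-ins, not Trinity APIs, and the sleep intervals are arbitrary.

```python
import asyncio

# Shared dict standing in for the runner's reported state.
state: dict = {}

async def worker(steps: int) -> None:
    """Publish progress into the shared state, like set_workflow_state."""
    for i in range(steps):
        state["step"] = i
        await asyncio.sleep(0.05)  # simulate one unit of workflow work

async def monitor(samples: int) -> set:
    """Poll the shared state concurrently, like get_runner_state."""
    seen = set()
    for _ in range(samples):
        await asyncio.sleep(0.02)  # poll faster than the worker progresses
        if "step" in state:
            seen.add(state["step"])
    return seen

async def main() -> set:
    # Run publisher and poller concurrently; gather preserves order of results.
    _, seen = await asyncio.gather(worker(3), monitor(10))
    return seen

print(sorted(asyncio.run(main())))
```

Because the monitor samples more often than the worker advances, it observes every intermediate value, which is the property the scheduler test asserts on its state histories.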

tests/explorer/workflow_test.py

Lines changed: 93 additions & 0 deletions
@@ -2,6 +2,7 @@
 """Test for the workflow module"""
 import asyncio
 import unittest
+from collections import defaultdict
 from dataclasses import dataclass, field
 from typing import Dict, Optional
 from unittest import mock
@@ -508,6 +509,48 @@ def tearDown(self):
         ray.shutdown(_exiting_interpreter=True)


+class StateRecordingWorkflow(Workflow):
+    is_async: bool = True
+
+    def __init__(self, *, task, model: ModelWrapper, auxiliary_models):
+        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
+        self.wait_time = task.workflow_args.get("wait_time", 1)
+
+    async def run_async(self):
+        for i in range(self.wait_time):
+            await self.model.set_workflow_state({"step": i})
+            await asyncio.sleep(1)
+        return [Experience(tokens=Tensor([0, 1, 2]), prompt_length=1, reward=1.0)]
+
+
+class TestWorkflowStateRecording(unittest.IsolatedAsyncioTestCase):
+    async def test_workflow_state_recording(self):
+        model = MagicMock()
+        model_wrapper = ModelWrapper(model, engine_type="vllm")
+
+        task = Task(
+            workflow=StateRecordingWorkflow,
+            repeat_times=3,
+            raw_task={},
+            workflow_args={"wait_time": 3},
+        )
+        workflow = task.to_workflow(model_wrapper)
+
+        async def monitor_routine():
+            old_state = {}
+            count = 0
+            for i in range(20):
+                await asyncio.sleep(0.2)
+                new_state = await model_wrapper.get_workflow_state()
+                if new_state.get("step") != old_state.get("step"):
+                    old_state = new_state
+                    count += 1
+            self.assertEqual(count, 3)
+            return count
+
+        await asyncio.gather(*[monitor_routine(), workflow.run_async()])
+
+
 class TestAgentScopeWorkflowAdapter(unittest.IsolatedAsyncioTestCase):
     @unittest.skip("Waiting for agentscope>=0.1.6")
     async def test_adapter(self):
@@ -559,6 +602,9 @@ def get_openai_client(self):
     def get_openai_async_client(self):
         return openai.AsyncOpenAI(api_key="EMPTY")

+    async def clean_workflow_state(self):
+        return
+
     @property
     async def model_version_async(self):
         return 0
@@ -604,3 +650,50 @@ async def test_workflow_runner(self):
         self.assertTrue(status.ok)
         self.assertIsInstance(exps, list)
         self.assertEqual(len(exps), 2)
+
+    async def test_workflow_runner_get_state(self):
+        config = get_template_config()
+
+        async def mock_get_api_server_url_remote():
+            return None
+
+        async def mock_get_model_version_remote():
+            return 1
+
+        model = MagicMock()
+        model.get_api_server_url.remote = MagicMock(side_effect=mock_get_api_server_url_remote)
+        model.get_model_version.remote = MagicMock(side_effect=mock_get_model_version_remote)
+
+        runner = WorkflowRunner(
+            config,
+            model=model,
+            auxiliary_models=[],
+            runner_id=1,
+        )
+        await runner.prepare()
+
+        task = Task(
+            workflow=StateRecordingWorkflow,
+            raw_task={},
+            workflow_args={"wait_time": 2},
+            batch_id=1,
+            task_id=2,
+        )
+
+        async def monitor_routine():
+            state_history = defaultdict(set)
+            count = 0
+            for i in range(20):
+                await asyncio.sleep(0.4)
+                new_state = await runner.get_runner_state()
+                for k, v in new_state.items():
+                    state_history[k].add(v)
+            self.assertEqual(len(state_history["model_version"]), 1)
+            self.assertEqual(len(state_history["workflow_id"]), 3)
+            self.assertEqual(len(state_history["begin_time"]), 3)
+            self.assertEqual(len(state_history["step"]), 2)
+            return count
+
+        await asyncio.gather(
+            *[monitor_routine(), runner.run_task(task, repeat_times=3, run_id_base=0)]
+        )

trinity/common/config.py

Lines changed: 2 additions & 0 deletions
@@ -674,6 +674,8 @@ class ExplorerConfig:
     # Experimental feature
     over_rollout: OverRolloutConfig = field(default_factory=OverRolloutConfig)
     dynamic_timeout: DynamicTimeoutConfig = field(default_factory=DynamicTimeoutConfig)
+    # report runner state every `runner_state_report_interval` seconds, 0 to disable
+    runner_state_report_interval: int = 0


 @dataclass

trinity/common/models/model.py

Lines changed: 19 additions & 1 deletion
@@ -4,7 +4,7 @@
 import socket
 from abc import ABC, abstractmethod
 from functools import partial
-from typing import List, Optional, Sequence, Tuple, Union
+from typing import Dict, List, Optional, Sequence, Tuple, Union

 import httpx
 import numpy as np
@@ -103,7 +103,9 @@ def __init__(
         self.enable_history = enable_history
         self.history = []
         self.status = RunningStatus.RUNNING
+        self.workflow_state: Dict = {}
         self.request_count = 0
+        self.state_lock = asyncio.Lock()

     async def prepare(self) -> None:
         """Prepare the model wrapper."""
@@ -364,6 +366,22 @@ def extract_experience_from_history(self, clear_history: bool = True) -> List[Ex
         self.history.clear()
         return exps

+    # Workflow state management methods
+    async def set_workflow_state(self, state: Dict) -> None:
+        """Set the state of the workflow using the model."""
+        async with self.state_lock:
+            self.workflow_state.update(state)
+
+    async def clean_workflow_state(self) -> None:
+        """Clean the state of the workflow using the model."""
+        async with self.state_lock:
+            self.workflow_state = {}
+
+    async def get_workflow_state(self) -> Dict:
+        """Get the state of the workflow using the model."""
+        async with self.state_lock:
+            return self.workflow_state.copy()


 def convert_api_output_to_experience(
     output,
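The new `ModelWrapper` methods follow a simple lock-protected shared-dict pattern: writers merge keys under an `asyncio.Lock`, and readers get a defensive copy so concurrent callers never mutate the live dict. A standalone sketch of that pattern (the class name `WorkflowStateHolder` is hypothetical, not from this codebase):

```python
import asyncio
from typing import Dict

class WorkflowStateHolder:
    """Sketch of the lock-protected workflow-state pattern used above."""

    def __init__(self) -> None:
        self.workflow_state: Dict = {}
        self.state_lock = asyncio.Lock()

    async def set_workflow_state(self, state: Dict) -> None:
        # Merge new keys into the shared state under the lock.
        async with self.state_lock:
            self.workflow_state.update(state)

    async def clean_workflow_state(self) -> None:
        # Reset the state between workflow runs.
        async with self.state_lock:
            self.workflow_state = {}

    async def get_workflow_state(self) -> Dict:
        # Return a copy so callers cannot mutate the shared dict.
        async with self.state_lock:
            return self.workflow_state.copy()

async def main() -> Dict:
    holder = WorkflowStateHolder()
    await holder.set_workflow_state({"repeat_cnt": 0})
    await holder.set_workflow_state({"step": 3})  # merges, does not replace
    snapshot = await holder.get_workflow_state()
    await holder.clean_workflow_state()
    return snapshot

print(asyncio.run(main()))  # {'repeat_cnt': 0, 'step': 3}
```

Because `set_workflow_state` uses `dict.update`, a workflow can report different keys from different places (e.g. `repeat_cnt` from the runner loop, `step` from inside a run) without clobbering each other.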

trinity/explorer/explorer.py

Lines changed: 4 additions & 2 deletions
@@ -380,6 +380,7 @@ async def _finish_explore_step(self, step: int, model_version: int) -> None:
         metric.update(pipeline_metrics)
         if statuses:
             metric.update(gather_metrics([status.metrics[0] for status in statuses], "rollout"))
+            metric["rollout/finished_task_count"] = len(statuses)
         self.monitor.log(metric, step=step)

     async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eval") -> None:
@@ -392,10 +393,11 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva
         if eval_step != step:
             return
         self.pending_eval_tasks.popleft()
-        eval_results, _ = await self.scheduler.get_results(batch_id=f"{step}/{eval_task_name}")
+        statuses, _ = await self.scheduler.get_results(batch_id=f"{step}/{eval_task_name}")
+        metric[f"{prefix}/{eval_task_name}/finished_task_count"] = len(statuses)
         metric.update(
             gather_metrics(
-                [status.metrics[0] for status in eval_results], f"{prefix}/{eval_task_name}"
+                [status.metrics[0] for status in statuses], f"{prefix}/{eval_task_name}"
             )
         )
         if self.eval_start_time is not None:
