2 changes: 2 additions & 0 deletions docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -377,6 +377,7 @@ explorer:
dynamic_timeout:
enable: false
ratio: 3.0
runner_state_report_interval: 0
```

- `name`: Name of the explorer. This name will be used as the Ray actor's name, so it must be unique.
@@ -397,6 +398,7 @@ explorer:
- `dynamic_timeout`: [Experimental] Configurations for dynamic timeout mechanism, which adjusts the timeout for each task based on the average time taken for successful tasks.
- `enable`: Whether to enable dynamic timeout. Default is `false`.
- `ratio`: The timeout for each task is dynamically set to `average_time_per_success_task * ratio`. Default is `3.0`.
- `runner_state_report_interval`: Workflow runner report interval (in seconds). If set to a value greater than `0`, each workflow runner periodically reports its status to the main explorer process and prints it to the command line for monitoring. Default is `0`, which disables the feature. When enabling it, a value of `10` seconds or longer is recommended to minimize the performance impact (see the sketch after this list).
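
As a minimal sketch of this option in code, assuming the `ExplorerConfig` dataclass from `trinity/common/config.py` (shown later in this diff) can be constructed with its defaults rather than loaded from a full YAML config:

```python
# Sketch only: `runner_state_report_interval` is the field added in this PR;
# constructing ExplorerConfig in isolation is an assumption for illustration.
from trinity.common.config import ExplorerConfig

explorer_cfg = ExplorerConfig()
explorer_cfg.runner_per_model = 2               # WorkflowRunners per inference engine instance
explorer_cfg.runner_state_report_interval = 10  # report runner state every 10 s; 0 (default) disables it
```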

---

Expand Down
28 changes: 18 additions & 10 deletions docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
@@ -368,26 +368,34 @@ explorer:
tensor_parallel_size: 1
eval_interval: 100
eval_on_startup: True
over_rollout:
ratio: 0.0
wait_after_min: 30.0
dynamic_timeout:
enable: false
ratio: 3.0
runner_state_report_interval: 0
```

- `name`: Name of the explorer. This name will be used as the Ray actor's name, so it must be unique.
- `runner_per_model`: Number of parallel workflow runners per rollout model
- `max_timeout`: Maximum time (in seconds) for a workflow to complete.
- `max_retry_times`: Maximum number of retries for a workflow
- `env_vars`: Environment variables set for each workflow runner
- `runner_per_model`: Number of WorkflowRunners served by each inference engine instance
- `max_timeout`: Maximum time (in seconds) to wait for a Workflow to complete.
- `max_retry_times`: Maximum number of retries when a Workflow fails or times out
- `env_vars`: Environment variables set for each WorkflowRunner
- `rollout_model.engine_type`: Type of the inference engine. Both `vllm_async` and `vllm` are supported; they mean the same thing and both use the asynchronous engine. Future versions will keep only `vllm`.
- `rollout_model.engine_num`: Number of inference engines
- `rollout_model.tensor_parallel_size`: Tensor parallelism degree
- `rollout_model.enable_history`: Whether to enable model call history recording. If set to `True`, the model wrapper automatically records the experiences returned by model calls. Please extract the history periodically via `extract_experience_from_history` to avoid running out of memory. Default is `False`.
- `auxiliary_models`: Additional models used in custom workflows
- `eval_interval`: Interval (in steps) between model evaluations.
- `eval_on_startup`: Whether to evaluate the model at startup. More precisely, the original model is evaluated at step 0, so this is not triggered on restart
- `rollout_model.engine_num`: Number of inference engine instances
- `rollout_model.tensor_parallel_size`: Tensor parallelism degree of each instance
- `rollout_model.enable_history`: Whether to enable the model call history recording feature. If set to `True`, the model automatically records the experiences returned by its calls. Please extract the history periodically via `extract_experience_from_history` to avoid running out of memory. Default is `False`.
- `auxiliary_models`: Auxiliary models used in custom workflows
- `eval_interval`: Interval (in steps) between model evaluations.
- `eval_on_startup`: Whether to evaluate the model at startup. More precisely, the original model is evaluated at step 0, so this is not triggered when restarting after an interrupted training run
- `over_rollout`: [Experimental] Configurations for the over-rollout mechanism, which allows the explorer to proceed at each step with fewer tasks than the full batch size. This effectively improves throughput in scenarios where some tasks take significantly longer. Only applicable when using dynamic synchronization (`synchronizer.sync_style` is not `fixed`).
- `ratio`: At each step, the explorer waits for only `(1 - ratio) * batch_size` tasks. Default is `0.0`, meaning all tasks are waited for.
- `wait_after_min`: After the minimum task threshold is reached, wait this many seconds before proceeding.
- `dynamic_timeout`: [Experimental] Configurations for the dynamic timeout mechanism, which adjusts the timeout of each task based on the average time taken by successful tasks.
- `enable`: Whether to enable dynamic timeout. Default is `false`.
- `ratio`: The timeout for each task is dynamically set to `average_time_per_success_task * ratio`. Default is `3.0`.
- `runner_state_report_interval`: Interval (in seconds) at which the WorkflowRunner reports its own state. If set to a value greater than `0`, the workflow runner periodically reports its state to the main explorer process and prints it to the command line for monitoring. Default is `0`, meaning the feature is disabled. If you want to use it, setting it to `10` seconds or longer is recommended to reduce the performance impact.

---

Expand Down
93 changes: 92 additions & 1 deletion tests/explorer/scheduler_test.py
@@ -1,6 +1,7 @@
import asyncio
import time
import unittest
from collections import defaultdict
from typing import List, Optional

import ray
@@ -11,7 +12,7 @@
from trinity.common.config import ExperienceBufferConfig
from trinity.common.constants import StorageType, SyncStyle
from trinity.common.experience import EID, Experience
from trinity.common.models.model import InferenceModel
from trinity.common.models.model import InferenceModel, ModelWrapper
from trinity.common.workflows import Task
from trinity.common.workflows.workflow import WORKFLOWS, Workflow
from trinity.explorer.scheduler import Scheduler
@@ -134,6 +135,41 @@ def run(self):
raise RuntimeError("This method should not be called")


@WORKFLOWS.register_module("dummy_workflow_with_state")
class DummyWorkflowWithState(Workflow):
can_repeat: bool = True
is_async: bool = True

def __init__(self, *, task, model: ModelWrapper, auxiliary_models):
super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
self.step_num = task.workflow_args.get("step_num", 1)

def set_repeat_times(self, repeat_times, run_id_base):
self.repeat_times = repeat_times
self.run_id_base = run_id_base

async def run_async(self) -> List[Experience]:
exps = []
for i in range(self.repeat_times):
run_level_metrics = {"run_metrics": float(i + self.run_id_base)}
run_level_exps = []
for step in range(self.step_num):
run_level_exps.append(
Experience(
eid=EID(run=i + self.run_id_base, step=step),
tokens=torch.zeros(5),
prompt_length=2,
prompt_text="success",
)
)
run_level_exps[-1].metrics = run_level_metrics
self.logger.info(f"Setting workflow state to repeat_cnt={i}")
await self.model.set_workflow_state({"repeat_cnt": i})
await asyncio.sleep(1)
exps.extend(run_level_exps)
return exps


@ray.remote
class DummyModel(InferenceModel):
def sync_model(self, model_version, update_weight_args_list):
@@ -779,3 +815,58 @@ def tearDown(self):
ray.shutdown()
except Exception:
pass


class TestRunnerStateCollection(unittest.IsolatedAsyncioTestCase):
async def test_runner_state_collection(self):
ray.init(ignore_reinit_error=True)
config = get_template_config()
config.explorer.runner_per_model = 2
config.explorer.runner_state_report_interval = 0.5
config.explorer.max_repeat_times_per_runner = 2
config.check_and_update()
scheduler = Scheduler(config, [DummyModel.remote(), DummyModel.remote()])
# 4 runners inside the scheduler
await scheduler.start()

tasks = [
Task(
workflow=DummyWorkflowWithState, # type: ignore[type-abstract]
workflow_args={"step_num": 2},
repeat_times=4,
raw_task={},
)
for _ in range(4)
]
scheduler.schedule(tasks, batch_id=0)

async def monitor_routine():
runner_0_state_history = defaultdict(set)
await asyncio.sleep(0.5) # wait for first report
for _ in range(16):
await asyncio.sleep(0.3)
states = scheduler.get_all_state()
self.assertEqual(len(states), 4)
for state in states.values():
self.assertIn("workflow_id", state)
self.assertIn("model_version", state)
self.assertIn("begin_time", state)
self.assertIn("terminate_time", state)
self.assertIn("repeat_cnt", state)
ids = scheduler.get_key_state("workflow_id")
self.assertEqual(len(ids), 4)
self.assertEqual(len(set(ids.values())), 4)
runner_0_state = scheduler.get_runner_state(0)
for key, value in runner_0_state.items():
runner_0_state_history[key].add(value)
self.assertEqual(len(runner_0_state_history["repeat_cnt"]), 2) # max_repeat_times is 2
self.assertEqual(len(runner_0_state_history["model_version"]), 1)
self.assertEqual(
len(runner_0_state_history["workflow_id"]), 2
) # split into 2 sub tasks
self.assertEqual(len(runner_0_state_history["begin_time"]), 2)

await asyncio.gather(
monitor_routine(),
scheduler.get_results(batch_id=0),
)
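
The test above drives the whole pipeline; condensed to just the monitoring side, the polling pattern it relies on looks roughly like the sketch below, reusing only the scheduler calls exercised in the test (`get_all_state`, `get_key_state`, `get_runner_state`) and assuming the setup (config, models, scheduled tasks) already exists:

```python
import asyncio

# Sketch condensed from TestRunnerStateCollection: poll runner states while a
# scheduled batch is running. `scheduler` is assumed to be a started Scheduler
# with tasks already submitted, as in the test above.
async def poll_runner_states(scheduler, rounds: int = 10, interval: float = 0.5):
    for _ in range(rounds):
        await asyncio.sleep(interval)
        states = scheduler.get_all_state()                      # one state dict per runner
        workflow_ids = scheduler.get_key_state("workflow_id")   # a single key across all runners
        runner_0 = scheduler.get_runner_state(0)                # full state of runner 0
        print(len(states), len(set(workflow_ids.values())), runner_0.get("repeat_cnt"))
```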
93 changes: 93 additions & 0 deletions tests/explorer/workflow_test.py
@@ -2,6 +2,7 @@
"""Test for the workflow module"""
import asyncio
import unittest
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, Optional
from unittest import mock
@@ -508,6 +509,48 @@ def tearDown(self):
ray.shutdown(_exiting_interpreter=True)


class StateRecordingWorkflow(Workflow):
is_async: bool = True

def __init__(self, *, task, model: ModelWrapper, auxiliary_models):
super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
self.wait_time = task.workflow_args.get("wait_time", 1)

async def run_async(self):
for i in range(self.wait_time):
await self.model.set_workflow_state({"step": i})
await asyncio.sleep(1)
return [Experience(tokens=Tensor([0, 1, 2]), prompt_length=1, reward=1.0)]


class TestWorkflowStateRecording(unittest.IsolatedAsyncioTestCase):
async def test_workflow_state_recording(self):
model = MagicMock()
model_wrapper = ModelWrapper(model, engine_type="vllm")

task = Task(
workflow=StateRecordingWorkflow,
repeat_times=3,
raw_task={},
workflow_args={"wait_time": 3},
)
workflow = task.to_workflow(model_wrapper)

async def monitor_routine():
old_state = {}
count = 0
for i in range(20):
await asyncio.sleep(0.2)
new_state = await model_wrapper.get_workflow_state()
if new_state.get("step") != old_state.get("step"):
old_state = new_state
count += 1
self.assertEqual(count, 3)
return count

await asyncio.gather(*[monitor_routine(), workflow.run_async()])


class TestAgentScopeWorkflowAdapter(unittest.IsolatedAsyncioTestCase):
@unittest.skip("Waiting for agentscope>=0.1.6")
async def test_adapter(self):
@@ -559,6 +602,9 @@ def get_openai_client(self):
def get_openai_async_client(self):
return openai.AsyncOpenAI(api_key="EMPTY")

async def clean_workflow_state(self):
return

@property
async def model_version_async(self):
return 0
@@ -604,3 +650,50 @@ async def test_workflow_runner(self):
self.assertTrue(status.ok)
self.assertIsInstance(exps, list)
self.assertEqual(len(exps), 2)

async def test_workflow_runner_get_state(self):
config = get_template_config()

async def mock_get_api_server_url_remote():
return None

async def mock_get_model_version_remote():
return 1

model = MagicMock()
model.get_api_server_url.remote = MagicMock(side_effect=mock_get_api_server_url_remote)
model.get_model_version.remote = MagicMock(side_effect=mock_get_model_version_remote)

runner = WorkflowRunner(
config,
model=model,
auxiliary_models=[],
runner_id=1,
)
await runner.prepare()

task = Task(
workflow=StateRecordingWorkflow,
raw_task={},
workflow_args={"wait_time": 2},
batch_id=1,
task_id=2,
)

async def monitor_routine():
state_history = defaultdict(set)
count = 0
for i in range(20):
await asyncio.sleep(0.4)
new_state = await runner.get_runner_state()
for k, v in new_state.items():
state_history[k].add(v)
self.assertEqual(len(state_history["model_version"]), 1)
self.assertEqual(len(state_history["workflow_id"]), 3)
self.assertEqual(len(state_history["begin_time"]), 3)
self.assertEqual(len(state_history["step"]), 2)
return count

await asyncio.gather(
*[monitor_routine(), runner.run_task(task, repeat_times=3, run_id_base=0)]
)
2 changes: 2 additions & 0 deletions trinity/common/config.py
@@ -670,6 +670,8 @@ class ExplorerConfig:
# Experimental feature
over_rollout: OverRolloutConfig = field(default_factory=OverRolloutConfig)
dynamic_timeout: DynamicTimeoutConfig = field(default_factory=DynamicTimeoutConfig)
# report runner state every `runner_state_report_interval` seconds, 0 to disable
runner_state_report_interval: int = 0


@dataclass
20 changes: 19 additions & 1 deletion trinity/common/models/model.py
@@ -4,7 +4,7 @@
import socket
from abc import ABC, abstractmethod
from functools import partial
from typing import List, Optional, Sequence, Tuple, Union
from typing import Dict, List, Optional, Sequence, Tuple, Union

import httpx
import numpy as np
@@ -103,7 +103,9 @@ def __init__(
self.enable_history = enable_history
self.history = []
self.status = RunningStatus.RUNNING
self.workflow_state: Dict = {}
self.request_count = 0
self.state_lock = asyncio.Lock()

async def prepare(self) -> None:
"""Prepare the model wrapper."""
@@ -361,6 +363,22 @@ def extract_experience_from_history(self, clear_history: bool = True) -> List[Ex
self.history.clear()
return exps

# Workflow state management methods
async def set_workflow_state(self, state: Dict) -> None:
"""Set the state of workflow using the model."""
async with self.state_lock:
self.workflow_state.update(state)

async def clean_workflow_state(self) -> None:
"""Clean the state of workflow using the model."""
async with self.state_lock:
self.workflow_state = {}

async def get_workflow_state(self) -> Dict:
"""Get the state of workflow using the model."""
async with self.state_lock:
return self.workflow_state.copy()


def convert_api_output_to_experience(
output,
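
Taken together, these methods give each workflow a small lock-protected state dict on its `ModelWrapper` that a runner can snapshot concurrently. A minimal sketch mirroring `StateRecordingWorkflow` in the tests follows; the `MagicMock` engine handle is a stand-in borrowed from the tests, and in production the wrapper would receive a real `InferenceModel` actor handle instead:

```python
import asyncio
from unittest.mock import MagicMock

from trinity.common.models.model import ModelWrapper

async def main():
    # As in tests/explorer/workflow_test.py, wrap a mocked engine handle.
    wrapper = ModelWrapper(MagicMock(), engine_type="vllm")

    async def workflow_side():
        for step in range(3):
            await wrapper.set_workflow_state({"step": step})  # merged into the shared dict
            await asyncio.sleep(0.1)

    async def monitor_side():
        for _ in range(5):
            await asyncio.sleep(0.1)
            print(await wrapper.get_workflow_state())         # returns a copy under the lock

    await asyncio.gather(workflow_side(), monitor_side())
    await wrapper.clean_workflow_state()                      # reset before the next task

asyncio.run(main())
```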
6 changes: 4 additions & 2 deletions trinity/explorer/explorer.py
@@ -380,6 +380,7 @@ async def _finish_explore_step(self, step: int, model_version: int) -> None:
metric.update(pipeline_metrics)
if statuses:
metric.update(gather_metrics([status.metrics[0] for status in statuses], "rollout"))
metric["rollout/finished_task_count"] = len(statuses)
self.monitor.log(metric, step=step)

async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eval") -> None:
@@ -392,10 +393,11 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva
if eval_step != step:
return
self.pending_eval_tasks.popleft()
eval_results, _ = await self.scheduler.get_results(batch_id=f"{step}/{eval_task_name}")
statuses, _ = await self.scheduler.get_results(batch_id=f"{step}/{eval_task_name}")
metric[f"{prefix}/{eval_task_name}/finished_task_count"] = len(statuses)
metric.update(
gather_metrics(
[status.metrics[0] for status in eval_results], f"{prefix}/{eval_task_name}"
[status.metrics[0] for status in statuses], f"{prefix}/{eval_task_name}"
)
)
if self.eval_start_time is not None: