94 changes: 93 additions & 1 deletion tests/explorer/scheduler_test.py
@@ -1,6 +1,7 @@
import asyncio
import time
import unittest
from collections import defaultdict
from typing import List, Optional

import ray
@@ -11,7 +12,7 @@
from trinity.common.config import ExperienceBufferConfig
from trinity.common.constants import StorageType, SyncStyle
from trinity.common.experience import EID, Experience
from trinity.common.models.model import InferenceModel
from trinity.common.models.model import InferenceModel, ModelWrapper
from trinity.common.workflows import Task
from trinity.common.workflows.workflow import WORKFLOWS, Workflow
from trinity.explorer.scheduler import Scheduler
@@ -134,6 +135,41 @@ def run(self):
raise RuntimeError("This method should not be called")


@WORKFLOWS.register_module("dummy_workflow_with_state")
class DummyWorkflowWithState(Workflow):
can_repeat: bool = True
is_async: bool = True

def __init__(self, *, task, model: ModelWrapper, auxiliary_models):
super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
self.step_num = task.workflow_args.get("step_num", 1)

def set_repeat_times(self, repeat_times, run_id_base):
self.repeat_times = repeat_times
self.run_id_base = run_id_base

async def run_async(self) -> List[Experience]:
exps = []
for i in range(self.repeat_times):
run_level_metrics = {"run_metrics": float(i + self.run_id_base)}
run_level_exps = []
for step in range(self.step_num):
run_level_exps.append(
Experience(
eid=EID(run=i + self.run_id_base, step=step),
tokens=torch.zeros(5),
prompt_length=2,
prompt_text="success",
)
)
run_level_exps[-1].metrics = run_level_metrics
self.logger.info(f"Setting workflow state to repeat_cnt={i}")
await self.model.set_workflow_state({"repeat_cnt": i})
await asyncio.sleep(1)
exps.extend(run_level_exps)
return exps


@ray.remote
class DummyModel(InferenceModel):
def sync_model(self, model_version, update_weight_args_list):
@@ -779,3 +815,59 @@ def tearDown(self):
ray.shutdown()
except Exception:
pass


class TestRunnerStateCollection(unittest.IsolatedAsyncioTestCase):
async def test_runner_state_collection(self):
ray.init(ignore_reinit_error=True)
config = get_template_config()
config.explorer.runner_per_model = 2
config.explorer.runner_state_report_interval = 1
config.explorer.max_repeat_times_per_runner = 2
config.check_and_update()
scheduler = Scheduler(config, [DummyModel.remote(), DummyModel.remote()])
# 4 runners inside the scheduler (2 models x 2 runners per model)
await scheduler.start()

tasks = [
Task(
workflow=DummyWorkflowWithState, # type: ignore[type-abstract]
workflow_args={"step_num": 2},
repeat_times=4,
raw_task={},
)
for _ in range(4)
]
scheduler.schedule(tasks, batch_id=0)

async def monitor_routine():
runner_0_state_history = defaultdict(set)
for _ in range(16):
await asyncio.sleep(0.3)
states = scheduler.get_all_state()
self.assertEqual(len(states), 4)
for state in states.values():
self.assertIn("runner_id", state)
self.assertIn("running_workflow_id", state)
self.assertIn("model_version", state)
self.assertIn("begin_time", state)
self.assertIn("terminate_time", state)
self.assertIn("repeat_cnt", state)
ids = scheduler.get_key_state("running_workflow_id")
self.assertEqual(len(ids), 4)
self.assertEqual(len(set(ids.values())), 4)
runner_0_state = scheduler.get_runner_state(0)
for key, value in runner_0_state.items():
runner_0_state_history[key].add(value)
self.assertEqual(len(runner_0_state_history["repeat_cnt"]), 2) # max_repeat_times_per_runner is 2
self.assertEqual(len(runner_0_state_history["model_version"]), 1)
self.assertEqual(
len(runner_0_state_history["running_workflow_id"]), 2
) # repeat_times=4 is split into 2 sub-tasks
self.assertEqual(len(runner_0_state_history["begin_time"]), 2)
self.assertEqual(len(runner_0_state_history["runner_id"]), 1)

await asyncio.gather(
monitor_routine(),
scheduler.get_results(batch_id=0),
)
92 changes: 92 additions & 0 deletions tests/explorer/workflow_test.py
@@ -2,6 +2,7 @@
"""Test for the workflow module"""
import asyncio
import unittest
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, Optional
from unittest import mock
@@ -508,6 +509,49 @@ def tearDown(self):
ray.shutdown(_exiting_interpreter=True)


class StateRecordingWorkflow(Workflow):
is_async: bool = True

def __init__(self, *, task, model: ModelWrapper, auxiliary_models):
super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
self.wait_time = task.workflow_args.get("wait_time", 1)

async def run_async(self):
for i in range(self.wait_time):
await self.model.set_workflow_state({"step": i})
await asyncio.sleep(1)
return [Experience(tokens=Tensor([0, 1, 2]), prompt_length=1, reward=1.0)]


class TestWorkflowStateRecording(unittest.IsolatedAsyncioTestCase):
async def test_workflow_state_recording(self):
model = MagicMock()
model_wrapper = ModelWrapper(model, engine_type="vllm")

task = Task(
workflow=StateRecordingWorkflow,
repeat_times=3,
raw_task={},
workflow_args={"wait_time": 3},
)
workflow = task.to_workflow(model_wrapper)

async def monitor_routine():
old_state = {}
count = 0
for i in range(20):
await asyncio.sleep(0.2)
new_state = await model_wrapper.get_workflow_state()
print(new_state)
if new_state.get("step") != old_state.get("step"):
old_state = new_state
count += 1
self.assertEqual(count, 3)
return count

await asyncio.gather(*[monitor_routine(), workflow.run_async()])


class TestAgentScopeWorkflowAdapter(unittest.IsolatedAsyncioTestCase):
@unittest.skip("Waiting for agentscope>=0.1.6")
async def test_adapter(self):
@@ -604,3 +648,51 @@ async def test_workflow_runner(self):
self.assertTrue(status.ok)
self.assertIsInstance(exps, list)
self.assertEqual(len(exps), 2)

async def test_workflow_runner_get_state(self):
config = get_template_config()

async def mock_get_api_server_url_remote():
return None

async def mock_get_model_version_remote():
return 1

model = MagicMock()
model.get_api_server_url.remote = MagicMock(side_effect=mock_get_api_server_url_remote)
model.get_model_version.remote = MagicMock(side_effect=mock_get_model_version_remote)

runner = WorkflowRunner(
config,
model=model,
auxiliary_models=[],
runner_id=1,
)
await runner.prepare()

task = Task(
workflow=StateRecordingWorkflow,
raw_task={},
workflow_args={"wait_time": 2},
batch_id=1,
task_id=2,
)

async def monitor_routine():
state_history = defaultdict(set)
count = 0
for i in range(20):
await asyncio.sleep(0.4)
new_state = await runner.get_runner_state()
for k, v in new_state.items():
state_history[k].add(v)
self.assertEqual(len(state_history["runner_id"]), 1)
self.assertEqual(len(state_history["model_version"]), 1)
self.assertEqual(len(state_history["running_workflow_id"]), 3)
self.assertEqual(len(state_history["begin_time"]), 3)
self.assertEqual(len(state_history["step"]), 2)
return count

await asyncio.gather(
*[monitor_routine(), runner.run_task(task, repeat_times=3, run_id_base=0)]
)
1 change: 1 addition & 0 deletions trinity/common/config.py
@@ -670,6 +670,7 @@ class ExplorerConfig:
# Experimental feature
over_rollout: OverRolloutConfig = field(default_factory=OverRolloutConfig)
dynamic_timeout: DynamicTimeoutConfig = field(default_factory=DynamicTimeoutConfig)
runner_state_report_interval: int = 0 # report runner state every N seconds; 0 disables state reporting


@dataclass
20 changes: 19 additions & 1 deletion trinity/common/models/model.py
@@ -4,7 +4,7 @@
import socket
from abc import ABC, abstractmethod
from functools import partial
from typing import List, Optional, Sequence, Tuple, Union
from typing import Dict, List, Optional, Sequence, Tuple, Union

import httpx
import numpy as np
@@ -103,7 +103,9 @@ def __init__(
self.enable_history = enable_history
self.history = []
self.status = RunningStatus.RUNNING
self.workflow_state: Dict = {}
self.request_count = 0
self.state_lock = asyncio.Lock()

async def prepare(self) -> None:
"""Prepare the model wrapper."""
@@ -361,6 +363,22 @@ def extract_experience_from_history(self, clear_history: bool = True) -> List[Ex
self.history.clear()
return exps

# Workflow state management methods
async def set_workflow_state(self, state: Dict) -> None:
"""Set the state of workflow using the model."""
async with self.state_lock:
self.workflow_state.update(state)

async def clean_workflow_state(self) -> None:
"""Clean the state of workflow using the model."""
async with self.state_lock:
self.workflow_state = {}

async def get_workflow_state(self) -> Dict:
"""Get the state of workflow using the model."""
async with self.state_lock:
return self.workflow_state.copy()


def convert_api_output_to_experience(
output,
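Reviewer note: a minimal usage sketch of the new ModelWrapper workflow-state API (not part of the diff; the MagicMock stand-in and engine_type="vllm" mirror TestWorkflowStateRecording above):

import asyncio
from unittest.mock import MagicMock

from trinity.common.models.model import ModelWrapper

async def demo():
    # Wrap a mock engine, as the workflow tests in this PR do.
    model_wrapper = ModelWrapper(MagicMock(), engine_type="vllm")
    # A workflow reports its progress to the wrapper...
    await model_wrapper.set_workflow_state({"step": 0})
    # ...a monitor reads a snapshot (a copy) without mutating the stored state...
    assert await model_wrapper.get_workflow_state() == {"step": 0}
    # ...and the state is cleared once the workflow finishes.
    await model_wrapper.clean_workflow_state()

asyncio.run(demo())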
74 changes: 74 additions & 0 deletions trinity/explorer/scheduler.py
@@ -61,6 +61,7 @@
self.timeout = config.explorer.max_timeout
self.namespace = ray.get_runtime_context().namespace
self.runner = self._create_runner()
self.state = {}

def _create_runner(self):
return (
@@ -79,6 +80,10 @@ def _create_runner(self):
async def prepare(self):
await self.runner.prepare.remote()

async def update_state(self) -> None:
"""Get the runner state."""
self.state = await self.runner.get_runner_state.remote()

async def run_with_retry(
self, task: TaskWrapper, repeat_times: int, run_id_base: int, timeout: float
) -> Tuple[Status, List, int, float]:
@@ -184,6 +189,7 @@ def __init__(
self.max_repeat_times = config.explorer.max_repeat_times_per_runner
self.default_batch_size = config.buffer.batch_size
self.running = False
self.runner_states = []

self.runner_num = len(rollout_model) * config.explorer.runner_per_model
self.runners: Dict[int, RunnerWrapper] = dict()
@@ -253,6 +259,23 @@ async def _scheduler_loop(self) -> None:
await asyncio.sleep(0.1)
self.logger.info("Scheduler loop stopped.")

async def _monitor_runner_state_loop(self) -> None:
interval = self.config.explorer.runner_state_report_interval
if interval <= 0:
self.logger.info("Runner state monitoring loop disabled.")
return

self.logger.info("Runner state monitoring loop started.")
while self.running:
try:
await asyncio.gather(*[runner.update_state() for runner in self.runners.values()])
except Exception:
self.logger.error(
f"Error in runner state monitoring loop:\n{traceback.format_exc()}"
)
await asyncio.sleep(0.1)
self.logger.info("Runner state monitoring loop stopped.")

async def _schedule_pending_tasks(self) -> None:
if not self.idle_runners:
return
@@ -331,6 +354,7 @@ async def start(self) -> None:
self.running = True
await asyncio.gather(*[self._create_runner(i) for i in range(self.runner_num)])
self.scheduler_task = asyncio.create_task(self._scheduler_loop())
self.monitor_task = asyncio.create_task(self._monitor_runner_state_loop())
ready_refs = [runner.runner.__ray_ready__.remote() for runner in self.runners.values()]
await asyncio.gather(*ready_refs)
self.logger.info(f"Starting Scheduler with {self.runner_num} runners")
@@ -354,6 +378,12 @@ async def stop(self) -> None:
await self.scheduler_task
except asyncio.CancelledError:
pass
if self.monitor_task:
self.monitor_task.cancel()
try:
await self.monitor_task
except asyncio.CancelledError:
pass
self.logger.info("Scheduler stopped")

def schedule(self, tasks: List[Task], batch_id: Union[int, str]) -> None:
@@ -525,3 +555,47 @@ async def wait_all(
)

raise TimeoutError(error_msg)

def get_key_state(self, key: str) -> Dict:
"""Get the scheduler state.

Args:
key (`str`): The key of the state to get.

Returns:
`Dict`: A dictionary mapping each runner id to its value for the given key.
"""
result = {}
for runner in self.runners.values():
runner_state = runner.state
if runner_state and key in runner_state:
result[runner.runner_id] = runner_state[key]
return result

def get_runner_state(self, runner_id: int) -> Dict:
"""Get the scheduler state.

Args:
runner_id (`int`): The id of the runner.

Returns:
`Dict`: The state of the runner.
"""
runner = self.runners.get(runner_id, None)
if runner:
return runner.state
else:
return {}

def get_all_state(self) -> Dict:
"""Get all runners' state.

Returns:
`Dict`: The state of all runners.
"""
result = {}
for runner in self.runners.values():
runner_state = runner.state
if runner_state:
result[runner.runner_id] = runner_state
return result
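
Reviewer note: the new query methods are exercised end-to-end in TestRunnerStateCollection above; condensed, the intended polling pattern looks roughly like the following (a sketch, not part of the diff, assuming a Scheduler instance named scheduler that has already been started and given tasks):

# All runners' cached states, keyed by runner_id (runners with an empty state are omitted).
states = scheduler.get_all_state()
# A single field across all runners, e.g. which workflow each runner is currently executing.
running_ids = scheduler.get_key_state("running_workflow_id")
# The cached state of one runner; an empty dict if the runner_id is unknown.
runner_0_state = scheduler.get_runner_state(0)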