|
1 | 1 | import asyncio |
2 | 2 | import time |
3 | 3 | import unittest |
| 4 | +from collections import defaultdict |
4 | 5 | from typing import List, Optional |
5 | 6 |
|
6 | 7 | import ray |
|
11 | 12 | from trinity.common.config import ExperienceBufferConfig |
12 | 13 | from trinity.common.constants import StorageType, SyncStyle |
13 | 14 | from trinity.common.experience import EID, Experience |
14 | | -from trinity.common.models.model import InferenceModel |
| 15 | +from trinity.common.models.model import InferenceModel, ModelWrapper |
15 | 16 | from trinity.common.workflows import Task |
16 | 17 | from trinity.common.workflows.workflow import WORKFLOWS, Workflow |
17 | 18 | from trinity.explorer.scheduler import Scheduler |
@@ -134,6 +135,41 @@ def run(self): |
134 | 135 | raise RuntimeError("This method should not be called") |
135 | 136 |
|
136 | 137 |
|
@WORKFLOWS.register_module("dummy_workflow_with_state")
class DummyWorkflowWithState(Workflow):
    """Async dummy workflow that reports its progress through the model wrapper.

    For every repeat it builds ``step_num`` placeholder experiences, then
    publishes ``{"repeat_cnt": <repeat index>}`` via
    ``self.model.set_workflow_state`` and sleeps for one second before moving
    on to the next repeat.
    """

    can_repeat: bool = True
    is_async: bool = True

    def __init__(self, *, task, model: ModelWrapper, auxiliary_models):
        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
        # Number of experiences emitted per run; falls back to 1 when the
        # task's workflow_args does not provide "step_num".
        self.step_num = task.workflow_args.get("step_num", 1)

    def set_repeat_times(self, repeat_times, run_id_base):
        """Store how many runs to execute and the offset added to each run id."""
        self.run_id_base = run_id_base
        self.repeat_times = repeat_times

    async def run_async(self) -> List[Experience]:
        """Execute all repeats and return the experiences of every run, in order."""
        collected = []
        for repeat_idx in range(self.repeat_times):
            run_id = repeat_idx + self.run_id_base
            # One metrics dict per run; every step experience of the run
            # references this same dict object.
            shared_metrics = {"run_metrics": float(run_id)}
            batch = []
            for step_idx in range(self.step_num):
                exp = Experience(
                    eid=EID(run=run_id, step=step_idx),
                    tokens=torch.zeros(5),
                    prompt_length=2,
                    prompt_text="success",
                )
                exp.metrics = shared_metrics
                batch.append(exp)
            self.logger.info(f"Setting workflow state to repeat_cnt={repeat_idx}")
            await self.model.set_workflow_state({"repeat_cnt": repeat_idx})
            # Pause between repeats (one second per run).
            await asyncio.sleep(1)
            collected += batch
        return collected
| 171 | + |
| 172 | + |
137 | 173 | @ray.remote |
138 | 174 | class DummyModel(InferenceModel): |
139 | 175 | def sync_model(self, model_version, update_weight_args_list): |
@@ -779,3 +815,58 @@ def tearDown(self): |
779 | 815 | ray.shutdown() |
780 | 816 | except Exception: |
781 | 817 | pass |
| 818 | + |
| 819 | + |
class TestRunnerStateCollection(unittest.IsolatedAsyncioTestCase):
    """Check that the scheduler exposes per-runner workflow state while tasks run."""

    def tearDown(self):
        # Best-effort release of the Ray runtime started by the test, so it is
        # not left running when the test passes or when an assertion fails.
        try:
            ray.shutdown()
        except Exception:
            pass

    async def test_runner_state_collection(self):
        """Poll scheduler state during a batch and verify its shape and history."""
        ray.init(ignore_reinit_error=True)
        config = get_template_config()
        config.explorer.runner_per_model = 2
        config.explorer.runner_state_report_interval = 0.5
        config.explorer.max_repeat_times_per_runner = 2
        config.check_and_update()
        scheduler = Scheduler(config, [DummyModel.remote(), DummyModel.remote()])
        # 2 models x 2 runners per model = 4 runners inside the scheduler
        await scheduler.start()

        tasks = [
            Task(
                workflow=DummyWorkflowWithState,  # type: ignore[type-abstract]
                workflow_args={"step_num": 2},
                repeat_times=4,
                raw_task={},
            )
            for _ in range(4)
        ]
        scheduler.schedule(tasks, batch_id=0)

        # Keys every reported runner state is expected to contain.
        expected_keys = (
            "workflow_id",
            "model_version",
            "begin_time",
            "terminate_time",
            "repeat_cnt",
        )

        async def monitor_routine():
            # Distinct values observed for each state key of runner 0.
            runner_0_state_history = defaultdict(set)
            await asyncio.sleep(0.5)  # wait for first report
            for _ in range(16):
                await asyncio.sleep(0.3)
                states = scheduler.get_all_state()
                self.assertEqual(len(states), 4)
                for state in states.values():
                    for key in expected_keys:
                        self.assertIn(key, state)
                ids = scheduler.get_key_state("workflow_id")
                self.assertEqual(len(ids), 4)
                # Each runner must report a different workflow id.
                self.assertEqual(len(set(ids.values())), 4)
                runner_0_state = scheduler.get_runner_state(0)
                for key, value in runner_0_state.items():
                    runner_0_state_history[key].add(value)
            self.assertEqual(len(runner_0_state_history["repeat_cnt"]), 2)  # max_repeat_times is 2
            self.assertEqual(len(runner_0_state_history["model_version"]), 1)
            self.assertEqual(
                len(runner_0_state_history["workflow_id"]), 2
            )  # split into 2 sub tasks
            self.assertEqual(len(runner_0_state_history["begin_time"]), 2)

        # Run the monitor concurrently with result collection for the batch.
        await asyncio.gather(
            monitor_routine(),
            scheduler.get_results(batch_id=0),
        )
0 commit comments