diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index b7a50875b7..87c2c95941 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -314,7 +314,7 @@ Controls the rollout models and workflow execution. ```yaml explorer: name: explorer - runner_num: 32 + runner_per_model: 8 max_timeout: 900 max_retry_times: 2 env_vars: {} @@ -325,10 +325,12 @@ explorer: auxiliary_models: - model_path: /PATH/TO/MODEL tensor_parallel_size: 1 + eval_interval: 100 + eval_on_startup: True ``` - `name`: Name of the explorer. This name will be used as the Ray actor's name, so it must be unique. -- `runner_num`: Number of parallel workflow runners. +- `runner_per_model`: Number of parallel workflow runners per rollout model. - `max_timeout`: Maximum time (in seconds) for a workflow to complete. - `max_retry_times`: Maximum number of retries for a workflow. - `env_vars`: Environment variables to be set for every workflow runners. - `rollout_model.engine_num`: Number of inference engines. - `rollout_model.tensor_parallel_size`: Degree of tensor parallelism. - `auxiliary_models`: Additional models used for custom workflows. +- `eval_interval`: Interval (in steps) for evaluating the model. +- `eval_on_startup`: Whether to evaluate the model on startup. More precisely, evaluation runs at step 0 with the original model, so it is not triggered when restarting a previous run. +- `runner_num`: (*Deprecated*) Number of parallel workflow runners. --- diff --git a/examples/grpo_alfworld/alfworld.yaml b/examples/grpo_alfworld/alfworld.yaml index 281008ae46..f6079ad55e 100644 --- a/examples/grpo_alfworld/alfworld.yaml +++ b/examples/grpo_alfworld/alfworld.yaml @@ -34,6 +34,7 @@ buffer: path: 'sqlite:///alfworld.db' explorer: runner_num: 32 + max_timeout: 3600 rollout_model: engine_type: vllm_async engine_num: 2 @@ -44,10 +45,12 @@ explorer: seed: 42 gpu_memory_utilization: 0.7 enable_chunked_prefill: true + env_vars: + TMPDIR: /PATH/TO/ALFWORLD_TMP_DIR synchronizer: sync_method: 'nccl' - sync_interval: 8 - sync_timeout: 1200 + sync_interval: 5 + sync_timeout: 3600 trainer: trainer_type: 'verl' trainer_config_path: 'examples/grpo_alfworld/train_alfworld.yaml' diff --git a/tests/buffer/file_test.py b/tests/buffer/file_test.py index 490117da1d..fb38e43660 100644 --- a/tests/buffer/file_test.py +++ b/tests/buffer/file_test.py @@ -16,7 +16,7 @@ from trinity.common.constants import StorageType -class TestFileBuffer(unittest.TestCase): +class TestFileBuffer(unittest.IsolatedAsyncioTestCase): temp_output_path = "tmp/test_file_buffer/" @classmethod @@ -30,7 +30,7 @@ def tearDownClass(cls): if os.path.exists(cls.temp_output_path): os.system(f"rm -rf {cls.temp_output_path}") - def test_file_buffer(self): + async def test_file_buffer(self): meta = StorageConfig( name="test_buffer", path=os.path.join(self.temp_output_path, "buffer.jsonl"), @@ -46,8 +46,9 @@ def test_file_buffer(self): # test writer writer = JSONWriter(meta, None) + await writer.acquire() writer.write(data) - writer.release() + await writer.release() # test reader meta.path = self.temp_output_path @@ -119,23 +120,31 @@ def test_file_reader(self): # noqa: C901 break self.assertEqual(len(tasks), 40 - 24) - def test_file_writer(self): + async def test_file_writer(self): writer = get_buffer_writer( self.config.buffer.trainer_input.experience_buffer, self.config.buffer ) + await writer.acquire() writer.write( [ {"prompt": "hello world"},
{"prompt": "hi"}, ] ) + await writer.write_async( + [ + {"prompt": "My name is"}, + {"prompt": "What is your name?"}, + ] + ) + await writer.release() file_wrapper = ray.get_actor("json-test_buffer") self.assertIsNotNone(file_wrapper) file_path = default_storage_path( self.config.buffer.trainer_input.experience_buffer, self.config.buffer ) with open(file_path, "r") as f: - self.assertEqual(len(f.readlines()), 2) + self.assertEqual(len(f.readlines()), 4) def setUp(self): self.config = get_template_config() diff --git a/tests/buffer/queue_test.py b/tests/buffer/queue_test.py index 5819aeb462..32b59534a7 100644 --- a/tests/buffer/queue_test.py +++ b/tests/buffer/queue_test.py @@ -1,9 +1,10 @@ import os +import threading import time import torch -from tests.tools import RayUnittestBase +from tests.tools import RayUnittestBaseAysnc from trinity.buffer.reader.queue_reader import QueueReader from trinity.buffer.writer.queue_writer import QueueWriter from trinity.common.config import BufferConfig, StorageConfig @@ -13,8 +14,8 @@ BUFFER_FILE_PATH = os.path.join(os.path.dirname(__file__), "test_queue_buffer.jsonl") -class TestQueueBuffer(RayUnittestBase): - def test_queue_buffer(self): +class TestQueueBuffer(RayUnittestBaseAysnc): + async def test_queue_buffer(self): total_num = 8 put_batch_size = 2 read_batch_size = 4 @@ -32,7 +33,7 @@ def test_queue_buffer(self): ) writer = QueueWriter(meta, config) reader = QueueReader(meta, config) - self.assertEqual(writer.acquire(), 1) + self.assertEqual(await writer.acquire(), 1) exps = [ Experience( tokens=torch.tensor([float(j) for j in range(i + 1)]), @@ -43,7 +44,7 @@ def test_queue_buffer(self): for i in range(1, put_batch_size + 1) ] for _ in range(total_num // put_batch_size): - writer.write(exps) + await writer.write_async(exps) for _ in range(total_num // read_batch_size): exps = reader.read() self.assertEqual(len(exps), read_batch_size) @@ -62,7 +63,7 @@ def test_queue_buffer(self): ) exps = reader.read(batch_size=put_batch_size * 2) self.assertEqual(len(exps), put_batch_size * 2) - self.assertEqual(writer.release(), 0) + self.assertEqual(await writer.release(), 0) self.assertRaises(StopIteration, reader.read) with open(BUFFER_FILE_PATH, "r") as f: self.assertEqual(len(f.readlines()), total_num + put_batch_size * 2) @@ -71,6 +72,34 @@ def test_queue_buffer(self): et = time.time() self.assertTrue(et - st > 2) + # test queue capacity + meta = StorageConfig( + name="test_buffer_small", + algorithm_type="ppo", + storage_type=StorageType.QUEUE, + max_read_timeout=3, + capacity=4, + path=BUFFER_FILE_PATH, + ) + writer = QueueWriter(meta, config) + reader = QueueReader(meta, config) + writer.write([{"content": "hello"}]) + writer.write([{"content": "hi"}]) + writer.write([{"content": "hello"}]) + writer.write([{"content": "hi"}]) + + # should be blocked + def write_blocking_call(): + writer.write([{"content": "blocked"}]) + + thread = threading.Thread(target=write_blocking_call) + thread.start() + thread.join(timeout=2) + self.assertTrue(thread.is_alive(), "write() did not block as expected") + reader.read() + thread.join(timeout=1) + self.assertFalse(thread.is_alive()) + def setUp(self): if os.path.exists(BUFFER_FILE_PATH): os.remove(BUFFER_FILE_PATH) diff --git a/tests/buffer/sql_test.py b/tests/buffer/sql_test.py index e40a91b4c7..22b1c739a6 100644 --- a/tests/buffer/sql_test.py +++ b/tests/buffer/sql_test.py @@ -1,9 +1,9 @@ import os -import unittest import ray import torch +from tests.tools import RayUnittestBaseAysnc from 
trinity.buffer.reader.sql_reader import SQLReader from trinity.buffer.writer.sql_writer import SQLWriter from trinity.common.config import BufferConfig, StorageConfig @@ -13,8 +13,8 @@ db_path = os.path.join(os.path.dirname(__file__), "test.db") -class TestSQLBuffer(unittest.TestCase): - def test_create_sql_buffer(self) -> None: +class TestSQLBuffer(RayUnittestBaseAysnc): + async def test_create_sql_buffer(self) -> None: total_num = 8 put_batch_size = 2 read_batch_size = 4 @@ -42,9 +42,9 @@ def test_create_sql_buffer(self) -> None: ) for i in range(1, put_batch_size + 1) ] - self.assertEqual(sql_writer.acquire(), 1) + self.assertEqual(await sql_writer.acquire(), 1) for _ in range(total_num // put_batch_size): - sql_writer.write(exps) + await sql_writer.write_async(exps) for _ in range(total_num // read_batch_size): exps = sql_reader.read() self.assertEqual(len(exps), read_batch_size) @@ -66,5 +66,5 @@ def test_create_sql_buffer(self) -> None: self.assertEqual(len(exps), put_batch_size * 2) db_wrapper = ray.get_actor("sql-test_buffer") self.assertIsNotNone(db_wrapper) - self.assertEqual(sql_writer.release(), 0) + self.assertEqual(await sql_writer.release(), 0) self.assertRaises(StopIteration, sql_reader.read) diff --git a/tests/explorer/runner_pool_test.py b/tests/explorer/runner_pool_test.py deleted file mode 100644 index 1ba7731efc..0000000000 --- a/tests/explorer/runner_pool_test.py +++ /dev/null @@ -1,255 +0,0 @@ -import copy -import os -import time -import unittest -from typing import List, Tuple - -import ray -import torch - -from tests.tools import get_unittest_dataset_config -from trinity.buffer.reader.queue_reader import QueueReader -from trinity.common.config import InferenceModelConfig, StorageConfig, load_config -from trinity.common.constants import StorageType -from trinity.common.experience import Experience -from trinity.common.models.model import InferenceModel -from trinity.common.workflows import Task -from trinity.common.workflows.workflow import WORKFLOWS, Workflow -from trinity.explorer.runner_pool import RunnerPool - -config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_data", "template.yaml") - - -@WORKFLOWS.register_module("dummy_workflow") -class DummyWorkflow(Workflow): - def __init__(self, model, task, auxiliary_models): - super().__init__(model, task, auxiliary_models) - self.error_type = task.task_desc - self.seconds = None - if "timeout" in self.error_type: - self.seconds = int(self.error_type.split("_")[-1]) - - def run(self) -> List[Experience]: - if "timeout" in self.error_type: - time.sleep(self.seconds) - elif self.error_type == "exception": - raise ValueError("Exception occurred") - elif self.error_type == "exit": - exit(1) - elif self.error_type == "auxiliary_models": - assert self.auxiliary_models is not None and len(self.auxiliary_models) == 2 - return [Experience(tokens=torch.zeros(5), prompt_length=2, prompt_text=self.error_type)] - - -@ray.remote -class DummyModel(InferenceModel): - def sync_model(self, model_version, update_weight_args_list): - return True - - def get_model_version(self): - return 0 - - def init_process_group( - self, - master_address: str, - master_port: int, - rank_offset: int, - world_size: int, - group_name: str, - backend: str = "nccl", - timeout: int = 1200, - update_with_checkpoint: bool = True, - ) -> None: - pass - - -@ray.remote -class DummyAuxiliaryModel(InferenceModel): - def sync_model(self, model_version, update_weight_args_list): - return True - - def get_model_version(self): - return 0 - - def 
init_process_group( - self, - master_address: str, - master_port: int, - rank_offset: int, - world_size: int, - group_name: str, - backend: str = "nccl", - timeout: int = 1200, - update_with_checkpoint: bool = True, - ) -> None: - pass - - def has_api_server(self) -> bool: - return True - - def api_server_ready(self) -> Tuple[str, str]: - return "http://localhosts:12345", "placeholder" - - -class RunnerPoolTest(unittest.TestCase): - def setUp(self): - ray.init(ignore_reinit_error=True) - self.config = load_config(config_dir) - self.config.explorer.runner_num = 2 - self.config.explorer.max_retry_times = 0 - self.config.explorer.max_timeout = 5 - self.config.buffer.read_batch_size = 2 - self.config.buffer.pad_token_id = 0 - self.config.buffer.explorer_output = ( - self.config.buffer.trainer_input.experience_buffer - ) = StorageConfig( - name="test", - storage_type=StorageType.QUEUE, - algorithm_type="ppo", - path="", - ) - self.queue = QueueReader( - self.config.buffer.trainer_input.experience_buffer, self.config.buffer - ) - - def test_runner_pool(self): - pool = RunnerPool(self.config, [DummyModel.remote(), DummyModel.remote()]) - taskset_config = get_unittest_dataset_config("countdown") - tasks = [ - Task( - workflow=DummyWorkflow, - format_args=taskset_config.format, - rollout_args=taskset_config.rollout_args, - is_eval=False, - raw_task={ - taskset_config.format.prompt_key: "timeout_100", - }, - ), - Task( - workflow=DummyWorkflow, - format_args=taskset_config.format, - rollout_args=taskset_config.rollout_args, - is_eval=False, - raw_task={ - taskset_config.format.prompt_key: "exception", - }, - ), - Task( - workflow=DummyWorkflow, - format_args=taskset_config.format, - rollout_args=taskset_config.rollout_args, - is_eval=False, - raw_task={ - taskset_config.format.prompt_key: "timeout_2", - }, - ), - Task( - workflow=DummyWorkflow, - format_args=taskset_config.format, - rollout_args=taskset_config.rollout_args, - is_eval=False, - raw_task={ - taskset_config.format.prompt_key: "success", - }, - ), - Task( - workflow=DummyWorkflow, - format_args=taskset_config.format, - rollout_args=taskset_config.rollout_args, - is_eval=False, - raw_task={ - taskset_config.format.prompt_key: "timeout_101", - }, - ), - Task( - workflow=DummyWorkflow, - format_args=taskset_config.format, - rollout_args=taskset_config.rollout_args, - is_eval=False, - raw_task={ - taskset_config.format.prompt_key: "exit", - }, - ), - ] - - pool.run_tasks( - tasks=tasks, - ) - - # The excepted return order is: `exception` -> `timeout_2` -> `success` -> (`timeout_100`and `timeout_101`) -> `exit` - # 1. `exception` - st = time.time() - status = pool.get_next_unorder() - et = time.time() - self.assertTrue(et - st < 2) - print(f"First task use time: {et - st}") - self.assertEqual(len(status), 1) - self.assertFalse(status[0].ok) - # 2. `timeout_2 - st = time.time() - status = pool.get_next_unorder() - et = time.time() - self.assertTrue(et - st > 2) - self.assertEqual(len(status), 1) - self.assertTrue(status[0].ok) - # 3. `success` - st = time.time() - status = pool.get_next_unorder() - et = time.time() - self.assertTrue(et - st < 1) - self.assertEqual(len(status), 1) - self.assertTrue(status[0].ok) - # 4. 
`timeout_100`and `timeout_101` - st = time.time() - status = pool.get_next_unorder() - et = time.time() - self.assertTrue(et - st > 5) - self.assertEqual(len(status), 2) - self.assertFalse(status[0].ok) - self.assertFalse(status[1].ok) - - # 5.`exit` - status = pool.get_next_unorder() - self.assertEqual(len(status), 1) - self.assertFalse(status[0].ok) - - exps = self.queue.read() - self.assertEqual(len(exps), 2) # `timeout_2` and `success` - self.assertEqual(len(pool._idle_actors), self.config.explorer.runner_num) - - def test_runner_pool_with_auxiliary_models(self): - config = copy.deepcopy(self.config) - config.explorer.auxiliary_models = [ - InferenceModelConfig( - engine_num=1, - ), - InferenceModelConfig( - engine_num=1, - ), - ] - pool = RunnerPool( - config, - [DummyModel.remote(), DummyModel.remote()], - [[DummyAuxiliaryModel.remote()], [DummyAuxiliaryModel.remote()]], - ) - taskset_config = get_unittest_dataset_config("countdown") - tasks = [ - Task( - workflow=DummyWorkflow, - format_args=taskset_config.format, - rollout_args=taskset_config.rollout_args, - is_eval=False, - raw_task={ - taskset_config.format.prompt_key: "auxiliary_models", - }, - ), - ] - - pool.run_tasks( - tasks=tasks, - ) - - # `auxiliary_models` - status = pool.get_next_unorder() - self.assertEqual(len(status), 1) - self.assertTrue(status[0].ok) diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py new file mode 100644 index 0000000000..0b96c98805 --- /dev/null +++ b/tests/explorer/scheduler_test.py @@ -0,0 +1,373 @@ +import asyncio +import time +import unittest +from typing import List, Tuple + +import ray +import torch + +from tests.tools import get_template_config +from trinity.buffer.reader.queue_reader import QueueReader +from trinity.common.config import StorageConfig +from trinity.common.constants import StorageType +from trinity.common.experience import Experience +from trinity.common.models.model import InferenceModel +from trinity.common.workflows import Task +from trinity.common.workflows.workflow import WORKFLOWS, Workflow +from trinity.explorer.scheduler import Scheduler + + +@WORKFLOWS.register_module("dummy_workflow") +class DummyWorkflow(Workflow): + def __init__(self, model, task, auxiliary_models): + super().__init__(model, task, auxiliary_models) + self.error_type = task.raw_task.get("error_type", "") + self.seconds = None + if "timeout" in self.error_type: + parts = self.error_type.split("_") + if len(parts) > 1: + self.seconds = int(parts[-1]) + else: + self.seconds = 10 + + def run(self) -> List[Experience]: + if "timeout" in self.error_type: + time.sleep(self.seconds) + elif self.error_type == "exception": + raise ValueError("Exception occurred") + elif self.error_type == "exit": + exit(1) + elif self.error_type == "auxiliary_models": + assert self.auxiliary_models is not None and len(self.auxiliary_models) == 2 + + return [ + Experience( + tokens=torch.zeros(5), prompt_length=2, prompt_text=self.error_type or "success" + ) + ] + + +@ray.remote +class DummyModel(InferenceModel): + def sync_model(self, model_version, update_weight_args_list): + return True + + def get_model_version(self): + return 0 + + def init_process_group( + self, + master_address: str, + master_port: int, + rank_offset: int, + world_size: int, + group_name: str, + backend: str = "nccl", + timeout: int = 1200, + update_with_checkpoint: bool = True, + ) -> None: + pass + + +@ray.remote +class DummyAuxiliaryModel(InferenceModel): + def sync_model(self, model_version, update_weight_args_list): 
+ return True + + def get_model_version(self): + return 0 + + def init_process_group( + self, + master_address: str, + master_port: int, + rank_offset: int, + world_size: int, + group_name: str, + backend: str = "nccl", + timeout: int = 1200, + update_with_checkpoint: bool = True, + ) -> None: + pass + + def has_api_server(self) -> bool: + return True + + def api_server_ready(self) -> Tuple[str, str]: + return "http://localhosts:12345", "placeholder" + + +def generate_tasks( + total_num: int, timeout_num: int = 0, exception_num: int = 0, timeout_seconds: int = 10 +): + """Generate some tasks for testing + + Args: + total_num: number of normal tasks + timeout_num: number of timeout tasks + exception_num: number of exception tasks + timeout_seconds: the timeout for timeout tasks + """ + tasks = [Task(workflow=DummyWorkflow, raw_task={}) for _ in range(total_num)] + + tasks.extend( + [ + Task( + workflow=DummyWorkflow, + raw_task={"error_type": f"timeout_{timeout_seconds}"}, + ) + for _ in range(timeout_num) + ] + ) + + tasks.extend( + [ + Task( + workflow=DummyWorkflow, + raw_task={"error_type": "exception"}, + ) + for _ in range(exception_num) + ] + ) + + return tasks + + +class SchedulerTest(unittest.IsolatedAsyncioTestCase): + def setUp(self): + ray.init(ignore_reinit_error=True) + self.config = get_template_config() + self.config.explorer.max_retry_times = 1 + self.config.explorer.max_timeout = 5 + self.config.explorer.runner_per_model = 2 + self.config.buffer.read_batch_size = 2 + self.config.buffer.pad_token_id = 0 + self.config.buffer.explorer_output = ( + self.config.buffer.trainer_input.experience_buffer + ) = StorageConfig( + name="test", + storage_type=StorageType.QUEUE, + algorithm_type="ppo", + path="", + ) + self.queue = QueueReader( + self.config.buffer.trainer_input.experience_buffer, self.config.buffer + ) + + async def test_get_results(self): + scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) + await scheduler.start() + + tasks = generate_tasks(8) + scheduler.schedule(tasks, batch_id=0) + + results = await scheduler.get_results(batch_id=0, min_num=8, timeout=20) + self.assertEqual(len(results), 8) + + for result in results: + self.assertTrue(result.ok) + + for batch_id in range(1, 4): + tasks = generate_tasks(4) + scheduler.schedule(tasks, batch_id=batch_id) + + for batch_id in range(1, 4): + self.assertTrue(scheduler.has_step(batch_id)) + results = await scheduler.get_results(batch_id=batch_id, min_num=4, timeout=10) + self.assertEqual(len(results), 4) + self.assertFalse(scheduler.has_step(batch_id)) + + tasks = generate_tasks(3) + scheduler.schedule(tasks, batch_id=4) + self.assertTrue(scheduler.has_step(4)) + results = await scheduler.get_results(batch_id=4) + self.assertEqual(len(results), 3) + self.assertFalse(scheduler.has_step(4)) + + # test timeout + tasks = generate_tasks(2, timeout_num=2, timeout_seconds=10) + scheduler.schedule(tasks, batch_id=0) + + start_time = time.time() + results = await scheduler.get_results(batch_id=0, min_num=4, timeout=3) + end_time = time.time() + + self.assertLessEqual(end_time - start_time, 5) + self.assertEqual(len(results), 2) + + # test run tasks after timeout + tasks = generate_tasks(4) + scheduler.schedule(tasks, batch_id=0) + + # actor restart is slow, set a big timeout + results = await scheduler.get_results(batch_id=0, timeout=20) + self.assertEqual(len(results), 4) + + success_count = sum(1 for r in results if r.ok) + + self.assertEqual(success_count, sum(1 for r in results if r.ok)) + + # test 
exception tasks + tasks = generate_tasks(1, exception_num=3) + scheduler.schedule(tasks, batch_id=1) + results = await scheduler.get_results(batch_id=1, timeout=5) + self.assertEqual(len(results), 4) + + success_count = sum(1 for r in results if r.ok) + self.assertEqual(success_count, 1) + + # test clear_timeout_tasks + tasks = generate_tasks(3, timeout_num=1, timeout_seconds=3) + scheduler.schedule(tasks, batch_id=2) + results = await scheduler.get_results(batch_id=2, timeout=2, clear_timeout_tasks=False) + self.assertEqual(len(results), 3) + results = await scheduler.get_results(batch_id=2, timeout=2, clear_timeout_tasks=False) + self.assertEqual(len(results), 1) + + await scheduler.stop() + + async def test_wait_all(self): + """Test wait all""" + scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) + await scheduler.start() + + tasks1 = generate_tasks(4) + tasks2 = generate_tasks(3) + scheduler.schedule(tasks1, batch_id=0) + scheduler.schedule(tasks2, batch_id=1) + + start_time = time.time() + await scheduler.wait_all(timeout=10.0) + end_time = time.time() + + self.assertLess(end_time - start_time, 5.0) + + self.assertEqual(len(scheduler.pending_tasks), 0) + self.assertEqual(len(scheduler.running_tasks), 0) + + results0 = await scheduler.get_results(batch_id=0, min_num=4, timeout=1) + results1 = await scheduler.get_results(batch_id=1, min_num=3, timeout=1) + self.assertEqual(len(results0), 4) + self.assertEqual(len(results1), 3) + + # test timeout + tasks = generate_tasks(2, timeout_num=2, timeout_seconds=10) + scheduler.schedule(tasks, batch_id=0) + + start_time = time.time() + with self.assertRaises(TimeoutError): + await scheduler.wait_all(timeout=3.0) + end_time = time.time() + + self.assertGreaterEqual(end_time - start_time, 2.8) + self.assertLessEqual(end_time - start_time, 4.0) + + # test empty scenario + + start_time = time.time() + await scheduler.wait_all(timeout=5.0) + end_time = time.time() + + self.assertLess(end_time - start_time, 1.0) + await scheduler.stop() + + async def test_wait_all_timeout_with_multi_batch(self): + self.config.explorer.max_timeout = 5 + self.config.explorer.rollout_model.engine_num = 4 + self.config.explorer.runner_per_model = 1 + + scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) + await scheduler.start() + + tasks = generate_tasks(1, timeout_num=3, timeout_seconds=3) + scheduler.schedule(tasks, batch_id=0) + tasks = generate_tasks(2, timeout_num=2, timeout_seconds=3) + scheduler.schedule(tasks, batch_id=1) + tasks = generate_tasks(3, timeout_num=1, timeout_seconds=3) + scheduler.schedule(tasks, batch_id=2) + start_time = time.time() + await scheduler.wait_all() + end_time = time.time() + self.assertTrue( + end_time - start_time > 9, + f"wait time should be greater than 9, but got {end_time - start_time}", + ) + + await scheduler.stop() + + async def test_concurrent_operations(self): + scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) + await scheduler.start() + + async def schedule_tasks(batch_id, num_tasks): + tasks = generate_tasks(num_tasks) + scheduler.schedule(tasks, batch_id=batch_id) + return await scheduler.get_results(batch_id=batch_id, min_num=num_tasks, timeout=10) + + results = await asyncio.gather( + schedule_tasks(0, 3), + schedule_tasks(1, 4), + schedule_tasks(2, 2), + ) + + self.assertEqual(len(results[0]), 3) + self.assertEqual(len(results[1]), 4) + self.assertEqual(len(results[2]), 2) + + await scheduler.stop() + + async def 
test_scheduler_restart_after_stop(self): + scheduler = Scheduler(self.config, [DummyModel.remote()]) + + await scheduler.start() + tasks = generate_tasks(2) + scheduler.schedule(tasks, batch_id=0) + results = await scheduler.get_results(batch_id=0, min_num=2, timeout=10) + self.assertEqual(len(results), 2) + await scheduler.stop() + + await scheduler.start() + tasks = generate_tasks(3) + scheduler.schedule(tasks, batch_id=1) + results = await scheduler.get_results(batch_id=1, min_num=3, timeout=10) + self.assertEqual(len(results), 3) + await scheduler.stop() + + async def test_scheduler_all_methods(self): + scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) + await scheduler.start() + tasks = generate_tasks(8) + scheduler.schedule(tasks, batch_id=0) + self.assertTrue(scheduler.has_step(0)) + results = await scheduler.get_results(batch_id=0, min_num=8, timeout=20) + self.assertEqual(len(results), 8) + scheduler.schedule(tasks, batch_id=1) + scheduler.schedule(tasks[:4], batch_id=2) + self.assertFalse(scheduler.has_step(0)) + results = await scheduler.get_results(batch_id=0, min_num=8) + self.assertFalse(scheduler.has_step(0)) + self.assertEqual(len(results), 0) # batch_id 0 has no more tasks + self.assertFalse(scheduler.has_step(0)) + self.assertTrue(scheduler.has_step(1)) + self.assertTrue(scheduler.has_step(2)) + await scheduler.wait_all() + st = time.time() + results = await scheduler.get_results(batch_id=1) + et = time.time() + self.assertTrue(et - st < 1.0) + self.assertEqual(len(results), 8) + self.assertFalse(scheduler.has_step(1)) + self.assertTrue(scheduler.has_step(2)) + st = time.time() + results = await scheduler.get_results(batch_id=2) + et = time.time() + self.assertTrue(et - st < 1.0) + self.assertEqual(len(results), 4) + self.assertFalse(scheduler.has_step(2)) + await scheduler.stop() + + def tearDown(self): + try: + ray.shutdown() + except Exception: + pass diff --git a/tests/template/config.yaml b/tests/template/config.yaml index 98180fff48..aaca7ff0a8 100644 --- a/tests/template/config.yaml +++ b/tests/template/config.yaml @@ -37,7 +37,7 @@ buffer: default_reward_fn_type: '' explorer: eval_interval: 100 - runner_num: 4 + runner_per_model: 8 rollout_model: engine_type: vllm_async engine_num: 2 diff --git a/tests/tools.py b/tests/tools.py index 7be4a2c4ef..a9b2ca8349 100644 --- a/tests/tools.py +++ b/tests/tools.py @@ -182,3 +182,13 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): ray.shutdown(_exiting_interpreter=True) + + +class RayUnittestBaseAysnc(unittest.IsolatedAsyncioTestCase): + @classmethod + def setUpClass(cls): + ray.init(ignore_reinit_error=True, namespace="trinity_unittest") + + @classmethod + def tearDownClass(cls): + ray.shutdown(_exiting_interpreter=True) diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 4c87732926..cb125ac2b5 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -89,30 +89,32 @@ def test_trainer(self): trainer_type=self.config.trainer.trainer_type, step_num=4, ) - checkpoint_step_8, _ = get_checkpoint_dir_with_step_num( + # check save lastest checkpoint + checkpoint_step_8, step_num = get_checkpoint_dir_with_step_num( checkpoint_root_path=self.config.checkpoint_job_dir, trainer_type=self.config.trainer.trainer_type, - step_num=8, ) - self.assertTrue(os.path.exists(checkpoint_step_4)) - self.assertTrue(os.path.exists(checkpoint_step_8)) + self.assertTrue(len(os.listdir(os.path.join(checkpoint_step_4, "actor"))) > 0) + 
self.assertTrue(len(os.listdir(os.path.join(checkpoint_step_8, "actor"))) > 0) + self.assertEqual(step_num, 8) # TODO: Reinit will fail when using v1 engine, find a way to fix it ray.init(ignore_reinit_error=True) # test bench mode self.config.mode = "bench" self.config.synchronizer.sync_method = SyncMethod.CHECKPOINT - self.config.explorer.eval_on_latest_checkpoint = False + self.config.explorer.bench_on_latest_checkpoint = False self.config.check_and_update() bench(self.config) parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard")) - countdown_metrics = parser.metric_list("eval/countdown") - copy_countdown_metrics = parser.metric_list("eval/copy_countdown") - self.assertTrue(len(countdown_metrics) > 0) - self.assertTrue(len(copy_countdown_metrics) > 0) - countdown_metric_steps = parser.metric_steps(countdown_metrics[0]) - countdown_copy_metric_steps = parser.metric_steps(copy_countdown_metrics[0]) - self.assertEqual([0, 4, 8], countdown_metric_steps) - self.assertEqual([0, 4, 8], countdown_copy_metric_steps) + for prefix in ["eval", "bench"]: + countdown_metrics = parser.metric_list(f"{prefix}/countdown") + copy_countdown_metrics = parser.metric_list(f"{prefix}/copy_countdown") + self.assertTrue(len(countdown_metrics) > 0) + self.assertTrue(len(copy_countdown_metrics) > 0) + countdown_metric_steps = parser.metric_steps(countdown_metrics[0]) + countdown_copy_metric_steps = parser.metric_steps(copy_countdown_metrics[0]) + self.assertEqual([0, 4, 8], countdown_metric_steps) + self.assertEqual([0, 4, 8], countdown_copy_metric_steps) def tearDown(self): # remove dir only when the test passed @@ -329,7 +331,6 @@ def test_fully_async_mode(self): config.cluster.node_num = 1 explorer1_config.explorer.rollout_model.engine_num = 1 explorer1_config.explorer.rollout_model.tensor_parallel_size = 1 - explorer1_config.explorer.runner_num = 4 explorer1_config.buffer.explorer_output = StorageConfig( name="exp_buffer", storage_type=StorageType.QUEUE, @@ -353,7 +354,7 @@ def test_fully_async_mode(self): explorer_process_1 = multiprocessing.Process(target=run_explorer, args=(explorer1_config,)) explorer_process_1.start() - time.sleep(20) + time.sleep(5) explorer2_config.explorer.name = "explorer2" explorer2_config.check_and_update() explorer_process_2 = multiprocessing.Process(target=run_explorer, args=(explorer2_config,)) diff --git a/trinity/buffer/buffer_writer.py b/trinity/buffer/buffer_writer.py index 13079ffb76..3d3e939196 100644 --- a/trinity/buffer/buffer_writer.py +++ b/trinity/buffer/buffer_writer.py @@ -11,7 +11,11 @@ def write(self, data: List) -> None: """Write to buffer.""" @abstractmethod - def acquire(self) -> int: + async def write_async(self, data: List) -> None: + """Write to buffer asynchronously.""" + + @abstractmethod + async def acquire(self) -> int: """Acquire the buffer writer. Returns: @@ -19,7 +23,7 @@ def acquire(self) -> int: """ @abstractmethod - def release(self) -> int: + async def release(self) -> int: """Release the buffer writer. After release, the buffer writer can not be used again. 
Returns: diff --git a/trinity/buffer/queue.py b/trinity/buffer/queue.py index b49644e13c..534283c50b 100644 --- a/trinity/buffer/queue.py +++ b/trinity/buffer/queue.py @@ -28,7 +28,7 @@ class QueueActor: def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: self.logger = get_logger(__name__) self.config = config - self.capacity = getattr(config, "capacity", 10000) + self.capacity = storage_config.capacity self.queue = asyncio.Queue(self.capacity) st_config = deepcopy(storage_config) st_config.wrap_in_ray = False @@ -57,7 +57,7 @@ async def release(self) -> int: self.ref_count -= 1 if self.ref_count <= 0: await self.queue.put(self.FINISH_MESSAGE) - self.writer.release() + await self.writer.release() return self.ref_count def length(self) -> int: diff --git a/trinity/buffer/writer/file_writer.py b/trinity/buffer/writer/file_writer.py index 16ec96d0a9..93c10479ca 100644 --- a/trinity/buffer/writer/file_writer.py +++ b/trinity/buffer/writer/file_writer.py @@ -20,15 +20,21 @@ def write(self, data: List) -> None: else: self.writer.write(data) - def acquire(self) -> int: + async def write_async(self, data): if self.wrap_in_ray: - return ray.get(self.writer.acquire()) + await self.writer.write.remote(data) + else: + self.writer.write(data) + + async def acquire(self) -> int: + if self.wrap_in_ray: + return await self.writer.acquire.remote() else: return 0 - def release(self) -> int: + async def release(self) -> int: if self.wrap_in_ray: - return ray.get(self.writer.release.remote()) + return await self.writer.release.remote() else: self.writer.release() return 0 diff --git a/trinity/buffer/writer/queue_writer.py b/trinity/buffer/writer/queue_writer.py index 7b12fab4c1..9b13262b80 100644 --- a/trinity/buffer/writer/queue_writer.py +++ b/trinity/buffer/writer/queue_writer.py @@ -23,8 +23,11 @@ def __init__(self, meta: StorageConfig, config: BufferConfig): def write(self, data: List) -> None: ray.get(self.queue.put_batch.remote(data)) - def acquire(self) -> int: - return ray.get(self.queue.acquire.remote()) + async def write_async(self, data): + return await self.queue.put_batch.remote(data) - def release(self) -> int: - return ray.get(self.queue.release.remote()) + async def acquire(self) -> int: + return await self.queue.acquire.remote() + + async def release(self) -> int: + return await self.queue.release.remote() diff --git a/trinity/buffer/writer/sql_writer.py b/trinity/buffer/writer/sql_writer.py index 95344d4447..a951201b80 100644 --- a/trinity/buffer/writer/sql_writer.py +++ b/trinity/buffer/writer/sql_writer.py @@ -23,15 +23,21 @@ def write(self, data: list) -> None: else: self.db_wrapper.write(data) - def acquire(self) -> int: + async def write_async(self, data): if self.wrap_in_ray: - return ray.get(self.db_wrapper.acquire.remote()) + await self.db_wrapper.write.remote(data) + else: + self.db_wrapper.write(data) + + async def acquire(self) -> int: + if self.wrap_in_ray: + return await self.db_wrapper.acquire.remote() else: return 0 - def release(self) -> int: + async def release(self) -> int: if self.wrap_in_ray: - return ray.get(self.db_wrapper.release.remote()) + return await self.db_wrapper.release.remote() else: self.db_wrapper.release() return 0 diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index 1b3ba1f4bb..76830a125f 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -25,6 +25,7 @@ def bench(config: Config) -> None: """Evaluate model.""" + config.explorer.name = "benchmark" explorer = ( ray.remote(Explorer) .options( 
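The writer changes above converge on one protocol: `acquire`/`release` become awaitable reference-counting calls and `write_async` awaits the backing Ray actor instead of blocking on `ray.get`, while the synchronous `write` is kept for existing callers. A minimal sketch of a producer driving this protocol, modeled on the updated `tests/buffer/queue_test.py`; the `BufferConfig` construction, the queue name/path, and the Ray setup here are assumptions for illustration, not part of this patch:

```python
# Minimal sketch only: mirrors the acquire -> write/write_async -> release
# pattern exercised by the updated queue tests. Values are placeholders.
import asyncio

import ray

from trinity.buffer.writer.queue_writer import QueueWriter
from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.constants import StorageType


async def produce(items: list) -> None:
    meta = StorageConfig(
        name="demo_buffer",
        algorithm_type="ppo",
        storage_type=StorageType.QUEUE,
        path="demo_buffer.jsonl",
    )
    config = BufferConfig()  # assumption: the tests build this from a template config instead
    writer = QueueWriter(meta, config)

    ref_count = await writer.acquire()    # register this producer; the first producer sees 1
    writer.write(items)                   # synchronous write, wraps ray.get on the queue actor
    await writer.write_async(items)       # asynchronous write, awaits the actor call directly
    remaining = await writer.release()    # 0 means no producers remain and the queue is finished
    print(f"ref_count={ref_count}, remaining={remaining}")


if __name__ == "__main__":
    ray.init()  # a running Ray instance is required for the queue actor
    asyncio.run(produce([{"content": "hello"}, {"content": "hi"}]))
```

Keeping the synchronous `write` alongside `write_async` lets existing call sites keep working, while async callers avoid blocking the event loop on `ray.get`.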
diff --git a/trinity/common/config.py b/trinity/common/config.py index 29973d8342..70d147460f 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -305,12 +305,14 @@ class ExplorerConfig: name: str = EXPLORER_NAME # for workflow runner # number of workflow runners. - # For sync engine (vllm), it should be equal to `engine_num`. - # For async engine (vllm_async), it can be larger than `engine_num`, e.g. 16 * `engine_num` - runner_num: int = 1 - max_timeout: int = 900 # wait each task for 15 minutes + # For sync engine (vllm), it should be `1`. + # For async engine (vllm_async), it could be a large number. + runner_per_model: int = 8 # number of runners per each rollout model + max_timeout: int = 1800 # wait each task for 30 minutes max_retry_times: int = 2 # retry each task for 2 times if it fails or timeout - env_vars: dict = field(default_factory=dict) + env_vars: dict = field(default_factory=dict) # environment variables for workflow runner + + runner_num: Optional[int] = None # deprecated # for inference models # for rollout model @@ -320,7 +322,10 @@ class ExplorerConfig: # for evaluation eval_interval: int = 100 - eval_on_latest_checkpoint: bool = False + eval_on_startup: bool = True # evalulate at step 0 + + # for benchmark + bench_on_latest_checkpoint: bool = False # only benchmark the latest checkpoint @dataclass @@ -363,7 +368,7 @@ class SynchronizerConfig: # allow explorer to run `sync_offset` steps before sync sync_offset: int = 0 # waiting for `sync_timeout` seconds before timeout in `nccl` method - sync_timeout: int = 1800 + sync_timeout: int = 3600 # wait for the lastest checkpoint to be ready # TODO: to be used wait_for_checkpoint: bool = False diff --git a/trinity/explorer/__init__.py b/trinity/explorer/__init__.py index e7794c7cf6..8665a1b125 100644 --- a/trinity/explorer/__init__.py +++ b/trinity/explorer/__init__.py @@ -1,4 +1,3 @@ from trinity.explorer.explorer import Explorer -from trinity.explorer.runner_pool import RunnerPool -__all__ = ["Explorer", "RunnerPool"] +__all__ = ["Explorer"] diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index cb52a0c302..c22516be42 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -5,7 +5,8 @@ import asyncio import os import time -from collections import defaultdict +import traceback +from collections import deque from typing import List, Optional import torch @@ -24,10 +25,10 @@ get_checkpoint_dir_with_step_num, load_state_dict, ) -from trinity.explorer.runner_pool import RunnerPool +from trinity.explorer.scheduler import Scheduler from trinity.manager.manager import CacheManager from trinity.utils.log import get_logger -from trinity.utils.monitor import MONITOR +from trinity.utils.monitor import MONITOR, gather_metrics class Explorer: @@ -38,20 +39,21 @@ def __init__(self, config: Config): self.cache = CacheManager(config) explorer_meta = self.cache.load_explorer() self.explore_step_num = explorer_meta.get("latest_iteration", 0) + self.last_sync_step = self.explore_step_num if self.explore_step_num > 0 else -1 self.config = config self.algorithm_manager = AlgorithmManager(config) self.models, self.auxiliary_models = create_inference_models(config) + self.experience_buffer = None if self.config.mode != "bench": self.experience_buffer = get_buffer_writer( self.config.buffer.explorer_output, # type: ignore self.config.buffer, ) - self.experience_buffer.acquire() self.config.buffer.explorer_input.taskset.index = explorer_meta.get("latest_task_index", 0) self.taskset 
= get_buffer_reader( self.config.buffer.explorer_input.taskset, self.config.buffer ) - self.runner_pool = self._init_runner_pool() + self.scheduler = self._init_scheduler() self.monitor = MONITOR.get(self.config.monitor.monitor_type)( project=self.config.project, name=self.config.name, @@ -65,7 +67,7 @@ def __init__(self, config: Config): self.use_checkpoint_weights_update = ( self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT ) - self.eval_explore_step_num = None + self.pending_eval_tasks = deque() # For checkpoint weights update # Use explorer to periodically load the latest model weights and @@ -111,20 +113,14 @@ async def setup_weight_sync_group( ] await asyncio.gather(*refs) - def _init_runner_pool(self) -> RunnerPool: + def _init_scheduler(self) -> Scheduler: if self.config.explorer.rollout_model.engine_type != "vllm_async": # sync model requires the same number of runners as the number of models - self.config.explorer.runner_num = self.config.explorer.rollout_model.engine_num + self.config.explorer.runner_per_model = 1 self.logger.info( "Sync vLLM model requires the same number of runners as the number of models" ) - if self.config.explorer.runner_num < self.config.explorer.rollout_model.engine_num: - self.config.explorer.runner_num = self.config.explorer.rollout_model.engine_num - self.logger.info( - f"Number of Runners is less than number of models, set to {self.config.explorer.runner_num}" - ) - self.logger.info(f"Setup {self.config.explorer.runner_num} WorkflowRunners") - return RunnerPool(self.config, self.models, self.auxiliary_models) + return Scheduler(self.config, self.models, self.auxiliary_models) async def _update_model_weight(self, step_num: int, state_dict: dict) -> None: # TODO: update model weight @@ -141,7 +137,7 @@ async def _update_model_weight(self, step_num: int, state_dict: dict) -> None: ) self.state_dict.clear() - async def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> None: + async def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> int: # TODO: support more checkpoint types try: checkpoint_dir, checkpoint_step_num = get_checkpoint_dir_with_step_num( @@ -150,12 +146,14 @@ async def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> No step_num=step_num, ) if checkpoint_dir == self.old_checkpoint: - return + return checkpoint_step_num model_weights = load_state_dict(os.path.join(checkpoint_dir, "actor")) await self._update_model_weight(checkpoint_step_num, model_weights) self.old_checkpoint = checkpoint_dir + return checkpoint_step_num except Exception as e: self.logger.warning(f"Fail to load checkpoint: {e}") + return 0 async def _nccl_weights_update(self): assert self.state_dict_meta is not None @@ -175,6 +173,7 @@ async def _nccl_weights_update(self): await asyncio.gather( *[model.sync_model.remote(self.explore_step_num) for model in self.models] ) + self.status = RunningStatus.RUNNING async def ready_to_sync(self): async with self._ready_to_sync_condition: @@ -183,9 +182,17 @@ async def ready_to_sync(self): async def prepare(self) -> None: """Preparation before running.""" + futures = [asyncio.create_task(self.scheduler.start())] if self.use_checkpoint_weights_update: master_address, master_port = await self.models[0].get_available_address.remote() - await self.setup_weight_sync_group(master_address, master_port) + futures.append( + asyncio.create_task(self.setup_weight_sync_group(master_address, master_port)) + ) + asyncio.gather(*futures, return_exceptions=True) + if 
self.experience_buffer: + await self.experience_buffer.acquire() + if self.config.explorer.eval_on_startup and self.explore_step_num == 0: + self.eval() async def get_weight(self, name: str) -> torch.Tensor: """Get the weight of the loaded model (For checkpoint weights update).""" @@ -193,33 +200,34 @@ async def get_weight(self, name: str) -> torch.Tensor: async def explore(self) -> str: """ - The dreamming loop for explorer and trainer. - | <----------------------------------------- one period ----------------------------------------------> | - explorer | <-- step_1 --> | <-- step_2 --> | ... | <-- step_n --> | <-- eval --> | <-- [idle] --> | <-- sync --> | - trainer | <-- idle --> | <-- step_1 --> | <-- step_2 --> | ... | <-- step_n --> | <-- [idle] --> | <-- sync --> | + The timeline of the exploration process: + | <--------------------------------- one period -------------------------------------> | + explorer | <---------------- step_1 --------------> | | + | | <---------------- step_2 --------------> | | + | ... | + | | <---------------- step_n ---------------> | | + | | <---------------------- eval --------------------> | <-- sync --> | + |--------------------------------------------------------------------------------------| + trainer | <-- idle --> | <-- step_1 --> | <-- step_2 --> | ... | <-- step_n --> | <-- sync --> | """ - self.eval_explore_step_num = None while True: try: self.logger.info(f"Explore step {self.explore_step_num + 1} started.") - if ( - self.eval_explore_step_num is None - and self.explore_step_num % self.config.explorer.eval_interval == 0 - ): - self.eval_explore_step_num = self.explore_step_num - explore_contionue = self.explore_step() + explore_contionue = await self.explore_step() if not explore_contionue: + # TODO: support eval on last checkpoint break + if self.need_eval(): + self.eval() if self.need_sync(): - self.wait_for_workflow_done() await self.sync_weight() - except Exception as e: - self.logger.error(f"Error in Explorer: {e}") + except Exception: + self.logger.error(f"Error in Explorer: {traceback.format_exc()}") break self.logger.info("--------------------\n> Explorer finished.\n--------------------") return self.config.explorer.name - def explore_step(self) -> bool: + async def explore_step(self) -> bool: algo_config = self.algorithm_manager.get_current_algorithm_config(self.explore_step_num + 1) # skip warmup if algo_config.algorithm_type == "sft": @@ -229,15 +237,11 @@ def explore_step(self) -> bool: tasks = self.taskset.read() except StopIteration: self.logger.warning("No more tasks to explore. 
Stop exploring.") - self.cache.save_explorer( - current_step=self.explore_step_num, - current_task_index=self.explore_step_num * self.config.buffer.batch_size, - ) + await self.save_checkpoint(sync_weight=False) self.status = RunningStatus.STOPPED - self.wait_for_workflow_done() - self.experience_buffer.release() + await self.experience_buffer.release() return False - self.runner_pool.run_tasks(tasks) + self.scheduler.schedule(tasks, batch_id=self.explore_step_num + 1) self.explore_step_num += 1 return True @@ -248,60 +252,42 @@ def need_sync(self) -> bool: self.explore_step_num - self.config.synchronizer.sync_offset ) % self.config.synchronizer.sync_interval == 0 - def eval(self, eval_explore_step_num: int): + def need_eval(self) -> bool: + return self.explore_step_num % self.config.explorer.eval_interval == 0 + + def eval(self): """Evaluation on all evaluation data samples.""" if len(self.config.buffer.explorer_input.eval_tasksets) == 0: self.logger.warning("No evaluation data samples. Skip evaluation.") return - self.logger.info(f"Evaluation at step {eval_explore_step_num} started.") - all_st = time.time() - log_metrics = {} + self.logger.info(f"Evaluation at step {self.explore_step_num} started.") for eval_taskset_config in self.config.buffer.explorer_input.eval_tasksets: self.logger.info( - f"Evaluation on {eval_taskset_config.name} at step {eval_explore_step_num} started." + f"Evaluation on {eval_taskset_config.name} at step {self.explore_step_num} started." ) eval_taskset = get_buffer_reader(eval_taskset_config, self.config.buffer) - st = time.time() - all_metrics = defaultdict(list) - - def wait(): - status_list = self.runner_pool.get_next_unorder() - if not isinstance(status_list, list): - status_list = [status_list] - for status in status_list: - if not status.ok: - self.logger.error(f"Error when running task: {status.message}") - else: - for metric_name, metric_value in status.metric.items(): - all_metrics[metric_name].append(metric_value) - + eval_batch_id = f"{self.explore_step_num}/{eval_taskset.name}" + self.pending_eval_tasks.append((self.explore_step_num, eval_taskset.name)) while True: - if not self.runner_pool.has_free(): - wait() try: - self.runner_pool.run_tasks(eval_taskset.read()) + self.scheduler.schedule(eval_taskset.read(), batch_id=eval_batch_id) except StopIteration: break - while self.runner_pool.has_next(): - wait() - metrics = self.monitor.calculate_metrics(all_metrics, prefix=f"eval/{eval_taskset.name}") # type: ignore - log_metrics.update(metrics) - log_metrics[f"eval/{eval_taskset.name}/time"] = time.time() - st - log_metrics["eval/total_time"] = time.time() - all_st - self.monitor.log(log_metrics, step=eval_explore_step_num) # type: ignore - self.logger.info(f"Evaluation at step {eval_explore_step_num} finished.") async def benchmark(self) -> bool: """Benchmark the model checkpoints.""" # benchmark on the latest checkpoint - if self.config.explorer.eval_on_latest_checkpoint: - await self._checkpoint_weights_update() - self.eval(self.explore_step_num) + if self.config.explorer.bench_on_latest_checkpoint: + self.explore_step_num = await self._checkpoint_weights_update() + self.eval() + await self._log_eval_metrics(prefix="bench") return True # benchmark on base model - self.eval(0) - # benchmark on all checkoints + if self.config.explorer.eval_on_startup: + await self._log_eval_metrics(prefix="bench") + + # benchmark on all checkpoints all_ckp_steps = sorted( [ int(ckp.split("global_step_")[-1]) @@ -311,62 +297,79 @@ async def benchmark(self) -> bool: ] ) 
for step_num in all_ckp_steps: - await self._checkpoint_weights_update(step_num=step_num) - self.eval(step_num) + self.explore_step_num = await self._checkpoint_weights_update(step_num=step_num) + self.eval() + await self._log_eval_metrics(prefix="bench") return True - def wait_for_workflow_done(self) -> None: - """Wait for workflow to finish.""" - all_metrics = defaultdict(list) - # wait for all tasks of this step to finish - while self.runner_pool.has_next(): - status_list = self.runner_pool.get_next_unorder() - if not isinstance(status_list, list): - status_list = [status_list] - for status in status_list: - if not status.ok: - self.logger.error(f"Error when running task: {status.message}") - # submit another task to replace the failed task - try: - tasks = self.taskset.read(batch_size=1) - except StopIteration: - self.logger.warning("No more tasks in taskset. Stop retrying.") - return - self.runner_pool.run_tasks(tasks) - else: - for metric_name, metric_value in status.metric.items(): - all_metrics[metric_name].append(metric_value) - # eval - if self.eval_explore_step_num is not None: - self.eval(self.eval_explore_step_num) - self.eval_explore_step_num = None - # calculate metrics - log_metrics = self.monitor.calculate_metrics(all_metrics, prefix="rollout") # type: ignore - self.monitor.log(log_metrics, step=self.explore_step_num) - self.logger.info(f"Explore step {self.explore_step_num} finished.") + async def save_checkpoint(self, sync_weight: bool = False) -> None: + # wait for all tasks to complete + self.logger.info("Waiting for all tasks to complete") + await self.scheduler.wait_all() + self.logger.info(f"All tasks before step {self.explore_step_num} have completed.") + log_task = asyncio.create_task( + self._log_metrics(self.last_sync_step + 1, self.explore_step_num) + ) + + if sync_weight: + # sync weights + self.logger.info(f"Explorer sync_weights at step {self.explore_step_num} started.") + if self.use_checkpoint_weights_update: + await self._checkpoint_weights_update() + else: # nccl weights update + await self._nccl_weights_update() + self.last_sync_step = self.explore_step_num + self.logger.info(f"Explorer sync_weights at step {self.explore_step_num} finished") + + # overlay log and weight sync + await log_task - async def sync_weight(self) -> None: - """Synchronize model weights.""" - # call this method before training start to load the latest model weights - self.logger.info(f"Explorer sync weights at step {self.explore_step_num}.") - if self.use_checkpoint_weights_update: - await self._checkpoint_weights_update() - else: # nccl weights update - await self._nccl_weights_update() # save explore checkpoint self.cache.save_explorer( current_step=self.explore_step_num, current_task_index=self.explore_step_num * self.config.buffer.batch_size, ) - self.status = RunningStatus.RUNNING - self.logger.info(f"Explorer sync at step {self.explore_step_num} finished") + + async def sync_weight(self) -> None: + """Synchronize model weights.""" + # call this method before training start to load the latest model weights + await self.save_checkpoint(sync_weight=True) + + async def _log_metrics(self, start_step: int, end_step: int) -> None: + for step in range(start_step, end_step + 1): + self.logger.info(f"Log metrics of step {step}") + await self._log_explore_metrics(step=step) + await self._log_eval_metrics(step=step) + + async def _log_explore_metrics(self, step: int) -> None: + results = await self.scheduler.get_results(batch_id=step) + if results: + metric = 
gather_metrics([status.metric for status in results], "rollout") + self.monitor.log(metric, step=step) + + async def _log_eval_metrics(self, step: Optional[int] = None, prefix: str = "eval") -> None: + if not self.pending_eval_tasks: + return + step = step or self.explore_step_num + st = time.time() + metric = {} + while self.pending_eval_tasks: + eval_step, eval_task_name = self.pending_eval_tasks[0] + if eval_step != step: + return + self.pending_eval_tasks.popleft() + eval_results = await self.scheduler.get_results(f"{step}/{eval_task_name}") + metric.update( + gather_metrics( + [status.metric for status in eval_results], f"{prefix}/{eval_task_name}" + ) + ) + metric[f"{prefix}/total_time"] = time.time() - st + self.monitor.log(metric, step) async def running_status(self) -> RunningStatus: return self.status - def flush_log(self, step: int) -> None: - """Flush the log of the current step.""" - self.monitor.log({}, step=step, commit=True) - - def shutdown(self) -> None: + async def shutdown(self) -> None: self.monitor.close() + await self.scheduler.stop() diff --git a/trinity/explorer/runner_pool.py b/trinity/explorer/runner_pool.py deleted file mode 100644 index e5ef8bdd5d..0000000000 --- a/trinity/explorer/runner_pool.py +++ /dev/null @@ -1,303 +0,0 @@ -"""Runner pool for running tasks in parallel. Modified from ray.util.actor_pool.ActorPool.""" -import random -from typing import List, Optional, Tuple, Union - -import ray - -from trinity.common.config import Config -from trinity.common.models.model import InferenceModel -from trinity.common.workflows import Task -from trinity.explorer.workflow_runner import Status, WorkflowRunner -from trinity.utils.log import get_logger - - -class RunnerPool: - """A pool of WorkflowRunner. - - The RunnerPool will automatically handle the exceptions during the workflow - and retry when the workflow fails or timeout. The number of max retries is - set in `config.explorer.max_retry_times` and the max timeout is set in - `config.explorer.max_timeout`. 
-    """
-
-    def __init__(
-        self,
-        config: Config,
-        models: List[InferenceModel],
-        auxiliary_models: Optional[List[List[InferenceModel]]] = None,
-    ):
-        # actors to be used
-        self.logger = get_logger(__name__)
-        self.config = config
-        self.models = models
-        self.auxiliary_models = auxiliary_models or []
-        self.timeout = config.explorer.max_timeout
-        self.max_retry_times = config.explorer.max_retry_times
-
-        # get actor from future
-        self._future_to_actor = {}
-
-        # get future from index
-        self._index_to_future = {}
-
-        # next task to do
-        self._next_task_index = 0
-
-        # next task to return
-        self._next_return_index = 0
-
-        # next work depending when actors free
-        self._pending_submits = []
-
-        # create new actors
-        self.engine_status = [0] * config.explorer.rollout_model.engine_num
-        self.auxiliary_engine_status_list = [
-            [0] * cfg.engine_num for cfg in config.explorer.auxiliary_models
-        ]
-        self._idle_actors = list()
-        self.actor_to_engine_index = {}
-        self._namespace = ray.get_runtime_context().namespace
-        self._create_actors(config.explorer.runner_num)
-
-    def _create_actors(self, num: int = 1):
-        new_actors = []
-        for _ in range(num):
-            engine_index = self.engine_status.index(min(self.engine_status))
-            selected_auxiliary_models = [
-                models[engine_status.index(min(engine_status))]
-                for models, engine_status in zip(
-                    self.auxiliary_models, self.auxiliary_engine_status_list
-                )
-            ]
-            new_actor = (
-                ray.remote(WorkflowRunner)
-                .options(
-                    namespace=self._namespace,
-                    scheduling_strategy="SPREAD",
-                    runtime_env={"env_vars": self.config.explorer.env_vars},
-                )
-                .remote(
-                    self.config,
-                    self.models[engine_index],
-                    selected_auxiliary_models,
-                )
-            )
-            new_actors.append(new_actor)
-            self.engine_status[engine_index] += 1
-            self.actor_to_engine_index[new_actor] = engine_index
-        for actor in new_actors:
-            self._return_actor(actor)
-
-    def _kill_actors(self, actors):
-        if not isinstance(actors, list):
-            actors = [actors]
-
-        for actor in actors:
-            release_engine_index = self.actor_to_engine_index[actor]
-            self.engine_status[release_engine_index] -= 1
-            del self.actor_to_engine_index[actor]
-            ray.kill(actor)
-
-    def _run_task(self, task: Task, retry_times: int = 0) -> None:
-        """Run a task in the pool.
-
-        Arguments:
-            task: A task to run.
-            retry_times: The current retry times of the task.
-        """
-        if self._idle_actors:
-            actor = self._idle_actors.pop()
-            future = actor.run_task.remote(task)
-            future_key = tuple(future) if isinstance(future, list) else future
-            self._future_to_actor[future_key] = (task, actor, retry_times)
-            self._index_to_future[self._next_task_index] = future
-            self._next_task_index += 1
-        else:
-            self._pending_submits.append((task, retry_times))
-
-    def run_tasks(self, tasks: Union[List[Task], Task]) -> None:
-        """Schedule a list of tasks to run in the pool.
-
-        Arguments:
-            tasks: A list of tasks.
-        """
-        if isinstance(tasks, Task):
-            tasks = [tasks]
-        for task in tasks:
-            self._run_task(task, 0)
-
-    def has_next(self):
-        """Returns whether there are any pending results to return.
-
-        Returns:
-            True if there are any pending results not yet returned.
-        """
-        return bool(self._future_to_actor)
-
-    def _handle_single_future(self, future, is_timeout) -> Tuple[Status, Task, int]:
-        future_key = tuple(future) if isinstance(future, list) else future
-        t, a, r = self._future_to_actor.pop(future_key)
-
-        if is_timeout:
-            # when timeout, restart the actor
-            self.logger.warning(f"Workflow {t.task_desc} Timeout.")
-
-            # kill the actor and update engine status
-            self._kill_actors(a)
-
-            # start a new actor
-            self._create_actors(num=1)
-
-            return_status = Status(
-                False, metric={"time_per_task": self.timeout}, message="Workflow Timeout."
-            )
-        else:
-            self._return_actor(a)
-            try:
-                return_status = ray.get(future)
-            except Exception as e:
-                self.logger.error(f"Error when running task: {e}")
-                return_status = Status(
-                    False,
-                    metric={"time_per_task": self.timeout},
-                    message=f"Error when running task: {e}",
-                )
-        return return_status, t, r
-
-    def get_next_unorder(self) -> List[Status]:
-        """Returns the next pending result unorder.
-
-        Returns:
-            The return status of the next task.
-        """
-        if not self.has_next():
-            raise StopIteration("No more results to get")
-        is_timeout = False
-        res, _ = ray.wait(list(self._future_to_actor), num_returns=1, timeout=self.timeout)
-        if not res:
-            is_timeout = True
-            future_list = list(self._future_to_actor)
-        else:
-            future_list = res
-
-        return_status_list = list()
-        for future in future_list:
-            return_status, t, r = self._handle_single_future(future, is_timeout)
-
-            if not return_status.ok:
-                if r >= self.max_retry_times:
-                    return_status_list.append(
-                        Status(
-                            False,
-                            metric={"retry_times": r + 1},
-                            message=f"{return_status.message}\nWorkflow Retry Times Exceeded.",
-                        )
-                    )
-                else:
-                    self.logger.info(f"Retry Workflow {t.task_desc}.")
-                    self._run_task(t, r + 1)
-            else:
-                return_status_list.append(return_status)
-
-        return return_status_list if return_status_list else self.get_next_unorder()
-
-    # todo: this function may be discarded in the next version
-    def get_next(self) -> Status:
-        """Returns the next pending result in order.
-
-        This returns the next task result, blocking for up to
-        the specified timeout until it is available.
-
-        Returns:
-            The return status of the next task.
-        """
-        if not self.has_next():
-            raise StopIteration("No more results to get")
-        future = self._index_to_future[self._next_return_index]
-        is_timeout = False
-        res, _ = ray.wait([future], timeout=self.timeout)
-        if not res:
-            is_timeout = True
-        del self._index_to_future[self._next_return_index]
-        self._next_return_index += 1
-
-        future_key = tuple(future) if isinstance(future, list) else future
-        t, a, r = self._future_to_actor.pop(future_key)
-
-        if is_timeout:
-            # when timeout, restart the actor
-            self.logger.warning(f"Workflow {t.task_desc} Timeout.")
-            ray.kill(a)
-            # TODO: balance the model
-            self._return_actor(
-                ray.remote(WorkflowRunner)
-                .options(
-                    namespace=self._namespace,
-                    scheduling_strategy="SPREAD",
-                )
-                .remote(
-                    self.config,
-                    self.models[
-                        random.randint(0, self.config.explorer.rollout_model.engine_num - 1)
-                    ],
-                )
-            )
-            return_status = Status(
-                False, metric={"time_per_task": self.timeout}, message="Workflow Timeout."
-            )
-        else:
-            self._return_actor(a)
-            try:
-                return_status = ray.get(future)
-            except Exception as e:
-                self.logger.error(f"Error when running task: {e}")
-                return_status = Status(
-                    False,
-                    metric={"time_per_task": self.timeout},
-                    message=f"Error when running task: {e}",
-                )
-
-        if not return_status.ok:
-            if r >= self.max_retry_times:
-                return Status(
-                    False,
-                    metric={"retry_times": r + 1},
-                    message=f"{return_status.message}\nWorkflow Retry Times Exceeded.",
-                )
-            else:
-                self.logger.info(f"Retry Workflow {t.task_desc}.")
-                self._run_task(t, r + 1)
-                return self.get_next()
-        else:
-            return return_status
-
-    def _return_actor(self, actor):
-        try:
-            ray.get(actor.is_alive.remote())
-            self._idle_actors.append(actor)
-        except Exception:
-            self.logger.info("The actor is not alive, restart a new actor")
-            self._kill_actors(actor)
-            self._create_actors(num=1)
-
-        if self._pending_submits:
-            self._run_task(*self._pending_submits.pop(0))
-
-    def has_free(self):
-        """Returns whether there are any idle actors available.
-
-        Returns:
-            True if there are any idle actors and no pending submits.
-        """
-        return len(self._idle_actors) > 0 and len(self._pending_submits) == 0
-
-    def pop_idle(self):
-        """Removes an idle actor from the pool.
-
-        Returns:
-            An idle actor if one is available.
-            None if no actor was free to be removed.
-        """
-        if self.has_free():
-            return self._idle_actors.pop()
-        return None
diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py
new file mode 100644
index 0000000000..cc3bea2ca1
--- /dev/null
+++ b/trinity/explorer/scheduler.py
@@ -0,0 +1,393 @@
+"""Scheduler for rollout tasks."""
+
+import asyncio
+import re
+import time
+import traceback
+from collections import defaultdict, deque
+from typing import Dict, List, Optional, Tuple, Union
+
+import ray
+
+from trinity.common.config import Config
+from trinity.common.models import InferenceModel
+from trinity.common.workflows import Task
+from trinity.explorer.workflow_runner import Status, WorkflowRunner
+from trinity.utils.log import get_logger
+
+
+class RunnerWrapper:
+    """A wrapper for a WorkflowRunner"""
+
+    def __init__(
+        self,
+        runner_id: int,
+        rollout_model: InferenceModel,
+        auxiliary_models: List[InferenceModel],
+        config: Config,
+    ):
+        self.logger = get_logger(__name__)
+        self.runner_id = runner_id
+        self.rollout_model = rollout_model
+        self.auxiliary_models = auxiliary_models
+        self.config = config
+        self.retry_times = config.explorer.max_retry_times
+        self.timeout = config.explorer.max_timeout
+        self.namespace = ray.get_runtime_context().namespace
+        self.runner = self._create_runner()
+
+    def _create_runner(self):
+        return (
+            ray.remote(WorkflowRunner)
+            .options(
+                namespace=self.namespace,
+                scheduling_strategy="SPREAD",
+                runtime_env={
+                    "env_vars": self.config.explorer.env_vars,
+                },
+            )
+            .remote(self.config, self.rollout_model, self.auxiliary_models)
+        )
+
+    async def run_with_retry(self, task: Task) -> Tuple[Status, int]:
+        """
+        Returns:
+            `Status`: The return status of the task.
+            `int`: The runner_id of the current runner.
+        """
+        last_exception_msg = None
+        await self.runner.__ray_ready__.remote()
+        start_time = time.time()
+        status = Status(ok=False, metric=dict())
+        try:
+            for attempt in range(self.retry_times + 1):
+                try:
+                    status = await asyncio.wait_for(self.runner.run_task.remote(task), self.timeout)
+                    if status.ok:
+                        break
+                    else:
+                        self.logger.error(status.message)
+                except asyncio.TimeoutError:
+                    last_exception_msg = (
+                        f"Timeout when running task at runner {self.runner_id}: {task}"
+                    )
+                    self.logger.error(last_exception_msg)
+                    status = Status(ok=False, metric=dict(), message=last_exception_msg)
+                except Exception:
+                    last_exception_msg = traceback.format_exc()
+                    self.logger.warning(
+                        f"Task execution attempt {attempt + 1} failed:\n{last_exception_msg}"
+                    )
+                    status = Status(ok=False, metric=dict(), message=last_exception_msg)
+        finally:
+            end_time = time.time()
+            status.metric["task_run_time"] = end_time - start_time
+            return status, self.runner_id
+
+    def restart_runner(self):
+        old_runner = self.runner
+        self.runner = self._create_runner()
+        try:
+            ray.kill(old_runner)
+        except Exception:
+            pass
+
+
+def sort_batch_id(batch_id: Union[int, str]):
+    """Priority of batch_id"""
+    # TODO: avoid sort the batch_id every time
+    if isinstance(batch_id, int):
+        return (batch_id, 0)
+    else:
+        match = re.match(r"^(\d+)", batch_id)
+        if match:
+            num = int(match.group(1))
+            return (num, 1)
+        else:
+            return (float("inf"), 1)
+
+
+class Scheduler:
+    """Scheduler for rollout tasks."""
+
+    def __init__(
+        self,
+        config: Config,
+        rollout_model: List[InferenceModel],
+        auxiliary_models: Optional[List[List[InferenceModel]]] = None,
+    ):
+        self.logger = get_logger(__name__)
+        self.config = config
+        self.rollout_model = rollout_model
+        self.auxiliary_models = auxiliary_models or []
+        self.namespace = ray.get_runtime_context().namespace
+        self.default_timeout = config.explorer.max_timeout * (config.explorer.max_retry_times + 1)
+        self.max_retry_times = config.explorer.max_retry_times
+        self.running = False
+
+        self.runner_num = len(rollout_model) * config.explorer.runner_per_model
+        self.runners: Dict[int, RunnerWrapper] = dict()
+        self.idle_runners = set()  # runner_id
+        self.busy_runners = dict()  # runner_id -> (task, batch_id)
+
+        self.pending_tasks_heap = []
+        self.pending_tasks: Dict[Union[int, str], deque] = defaultdict(deque)  # batch_id -> tasks
+        self.running_tasks: Dict[Union[int, str], set[asyncio.Future]] = defaultdict(
+            set
+        )  # batch_id -> futures
+        self.completed_tasks: Dict[Union[int, str], deque[Status]] = defaultdict(
+            deque
+        )  # batch_id -> results
+
+        self.scheduler_task: Optional[asyncio.Task] = None
+        self.running = False
+
+        self.total_scheduled = 0
+        self.total_completed = 0
+
+    def _create_runner(
+        self,
+        runner_id: int,
+    ):
+        runner = RunnerWrapper(
+            runner_id=runner_id,
+            rollout_model=self.rollout_model[runner_id % len(self.rollout_model)],
+            auxiliary_models=[
+                self.auxiliary_models[j][runner_id % len(self.auxiliary_models[j])]
+                for j in range(len(self.auxiliary_models))
+            ],
+            config=self.config,
+        )
+        self.runners[runner_id] = runner
+        self.idle_runners.add(runner_id)
+
+    def _restart_runner(self, runner_id: int):
+        """Restart a runner."""
+        self.runners[runner_id].restart_runner()
+
+        if runner_id in self.busy_runners:
+            task, idx = self.busy_runners.pop(runner_id)
+            self.logger.warning(
+                f"Runner {runner_id} failed to run task at batch_id {idx}: {task.raw_task}"
+            )
+
+        self.idle_runners.add(runner_id)
+        self.logger.info(f"Runner {runner_id} restarted.")
+
+    async def _scheduler_loop(self) -> None:
+        self.logger.info("Scheduler loop started.")
+        while self.running:
+            try:
+                await self._schedule_pending_tasks()
+                await self._check_completed_tasks()
+                await asyncio.sleep(0.01)
+            except Exception:
+                self.logger.error(f"Error in scheduler loop:\n{traceback.format_exc()}")
+                await asyncio.sleep(0.1)
+        self.logger.info("Scheduler loop stopped.")
+
+    async def _schedule_pending_tasks(self) -> None:
+        if not self.idle_runners:
+            return
+
+        # TODO: Support more advanced scheduling strategies
+        for batch_id in sorted(self.pending_tasks.keys(), key=sort_batch_id):
+            task_queue = self.pending_tasks[batch_id]
+
+            while task_queue and self.idle_runners:
+                task = task_queue.pop()
+                runner_id = self.idle_runners.pop()
+                self.busy_runners[runner_id] = (task, batch_id)
+                self.running_tasks[batch_id].add(
+                    asyncio.create_task(self.runners[runner_id].run_with_retry(task))
+                )
+
+            if not task_queue:
+                del self.pending_tasks[batch_id]
+
+    async def _check_completed_tasks(self) -> None:
+        for batch_id in list(self.running_tasks.keys()):
+            futures = self.running_tasks[batch_id]
+
+            for future in list(futures):
+                if future.done():
+                    futures.remove(future)
+                    try:
+                        task_result, runner_id = await future
+                        self.completed_tasks[batch_id].appendleft(task_result)
+                        self.busy_runners.pop(runner_id)
+                        self.idle_runners.add(runner_id)
+
+                        self.logger.debug(
+                            f"Task completed (batch_id {batch_id}), success: {task_result.ok}"
+                        )
+
+                    except Exception as e:
+                        self.logger.error(f"Error getting task result: {e}")
+
+            if not futures:
+                del self.running_tasks[batch_id]
+
+    def _clear_timeout_tasks(self, batch_id: Union[int, str]) -> None:
+        if batch_id in self.pending_tasks:
+            self.logger.info(f"Clear timeout pending tasks at batch_id {batch_id}.")
+            del self.pending_tasks[batch_id]
+        if batch_id in self.running_tasks:
+            self.logger.info(f"Clear timeout running tasks at batch_id {batch_id}.")
+            for future in self.running_tasks[batch_id]:
+                future.cancel()
+            del self.running_tasks[batch_id]
+
+    async def start(self) -> None:
+        if self.running:
+            return
+        self.running = True
+        for i in range(self.runner_num):
+            self._create_runner(i)
+        self.scheduler_task = asyncio.create_task(self._scheduler_loop())
+        for _, runner in self.runners.items():
+            await runner.runner.__ray_ready__.remote()
+        self.logger.info(f"Starting Scheduler with {self.runner_num} runners")
+
+    async def stop(self) -> None:
+        if not self.running:
+            return
+
+        self.running = False
+        all_running_futures = []
+        for futures in self.running_tasks.values():
+            all_running_futures.extend(futures)
+
+        if all_running_futures:
+            self.logger.info(f"Waiting for {len(all_running_futures)} running tasks to complete...")
+            await asyncio.gather(*all_running_futures, return_exceptions=True)
+
+        if self.scheduler_task:
+            self.scheduler_task.cancel()
+            try:
+                await self.scheduler_task
+            except asyncio.CancelledError:
+                pass
+        self.logger.info("Scheduler stopped")
+
+    def schedule(self, tasks: List[Task], batch_id: Union[int, str]) -> None:
+        """Schedule the provided tasks.
+
+        Args:
+            tasks (`List[Task]`): The tasks to schedule.
+            batch_id (`Union[int, str]`): The id of the provided tasks. It should be an integer or a string
+                starting with an integer (e.g., 123, "123/my_task").
+        """
+        if not tasks:
+            return
+        for task in tasks:
+            self.pending_tasks[batch_id].appendleft(task)
+
+    async def get_results(
+        self,
+        batch_id: Union[int, str],
+        min_num: Optional[int] = None,
+        timeout: Optional[float] = None,
+        clear_timeout_tasks: bool = True,
+    ) -> List[Status]:
+        """Get the results of tasks at the specified batch_id.
+
+        Args:
+            batch_id (`Union[int, str]`): Only wait for tasks at this batch.
+            min_num (`int`): The minimum number of tasks to wait for. If `None`, wait for all tasks at `batch_id`.
+            timeout (`float`): The timeout for waiting for tasks to finish. If `None`, wait for the default timeout.
+            clear_timeout_tasks (`bool`): Whether to clear timeout tasks.
+        """
+        timeout = timeout or self.default_timeout
+        start_time = time.time()
+        if min_num is None:
+            min_num = 0
+            if batch_id in self.pending_tasks:
+                min_num += len(self.pending_tasks[batch_id])
+            if batch_id in self.running_tasks:
+                min_num += len(self.running_tasks[batch_id])
+            if batch_id in self.completed_tasks:
+                min_num += len(self.completed_tasks[batch_id])
+
+        self.logger.debug(f"Waiting for {min_num} tasks to complete...")
+
+        while time.time() - start_time < timeout:
+            completed_count = len(self.completed_tasks[batch_id])
+            if completed_count >= min_num:
+                break
+            await asyncio.sleep(0.1)
+
+        if time.time() - start_time > timeout:
+            self.logger.error(f"Timed out waiting for tasks to complete after {timeout} seconds")
+            if clear_timeout_tasks:
+                self._clear_timeout_tasks(batch_id=batch_id)
+                for runner_id in list(self.busy_runners.keys()):
+                    if self.busy_runners[runner_id][1] == batch_id:
+                        self._restart_runner(runner_id)
+
+        results = []
+        for _ in range(min_num):
+            if len(self.completed_tasks[batch_id]) > 0:
+                results.append(self.completed_tasks[batch_id].pop())
+
+        if not self.completed_tasks[batch_id]:
+            del self.completed_tasks[batch_id]
+
+        completed_count = len(results)
+        if completed_count < min_num:
+            self.logger.warning(
+                f"Timeout reached, only {completed_count}/{min_num} tasks completed"
+            )
+
+        return results
+
+    def has_step(self, batch_id: Union[int, str]) -> bool:
+        return (
+            batch_id in self.completed_tasks
+            or batch_id in self.pending_tasks
+            or batch_id in self.running_tasks
+        )
+
+    async def wait_all(
+        self, timeout: Optional[float] = None, clear_timeout_tasks: bool = True
+    ) -> None:
+        """Wait for all tasks to complete without popping results. If the timeout is reached, raise TimeoutError.
+
+        Args:
+            timeout (`float`): Timeout in seconds. Raise `TimeoutError` when no new tasks are completed within the timeout.
+            clear_timeout_tasks (`bool`): Whether to clear timeout tasks.
+        """
+        timeout = timeout or self.default_timeout
+        start_time = time.time()
+
+        self.logger.debug("Waiting for all tasks to complete...")
+        last_completed_count = 0
+        while time.time() - start_time < timeout:
+            has_pending = bool(self.pending_tasks)
+            has_running = bool(self.running_tasks)
+
+            if not has_pending and not has_running:
+                self.logger.debug("All tasks completed successfully")
+                return
+
+            completed_count = sum(len(tasks) for tasks in self.completed_tasks.values())
+            if completed_count != last_completed_count:
+                # flush timeout when new tasks are completed
+                start_time = time.time()
+                last_completed_count = completed_count
+
+            await asyncio.sleep(0.1)
+
+        pending_count = sum(len(tasks) for tasks in self.pending_tasks.values())
+        running_count = sum(len(futures) for futures in self.running_tasks.values())
+        error_msg = f"Timeout after {timeout} seconds. Still have {pending_count} pending tasks and {running_count} running tasks."
+        self.logger.error(error_msg)
+
+        if clear_timeout_tasks:
+            for batch_id in self.pending_tasks.keys() | self.running_tasks.keys():
+                self._clear_timeout_tasks(batch_id)
+            busy_runner_ids = list(self.busy_runners.keys())
+            for runner_id in busy_runner_ids:
+                self._restart_runner(runner_id)
+
+        raise TimeoutError(error_msg)
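For orientation, here is a minimal sketch of how the new `Scheduler` could be driven from async code. The `config`, `rollout_engines`, and `tasks` objects are placeholders for values constructed elsewhere (they are not part of this patch), and a running Ray cluster with a populated explorer config is assumed:

```python
# Hypothetical driver; `config`, `rollout_engines`, and `tasks` must be
# supplied by the caller (a Config, a List[InferenceModel], and a List[Task]).
from trinity.explorer.scheduler import Scheduler


async def explore_one_batch(config, rollout_engines, tasks, batch_id=0):
    scheduler = Scheduler(config, rollout_engines)
    await scheduler.start()  # spawns runner_per_model runners for each engine
    try:
        # Tasks are grouped by batch_id and scheduled in batch order.
        scheduler.schedule(tasks, batch_id=batch_id)
        # With min_num=None this waits for every task of the batch, bounded
        # by the default timeout of max_timeout * (max_retry_times + 1) seconds.
        statuses = await scheduler.get_results(batch_id=batch_id)
        return [s for s in statuses if s.ok]
    finally:
        await scheduler.stop()
```

Note that pending batches are ordered by `sort_batch_id`, so an integer id such as `7` is scheduled before a string id with the same numeric prefix such as `"7/eval"`.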
diff --git a/trinity/manager/config_manager.py b/trinity/manager/config_manager.py
index b9ba995985..b468382300 100644
--- a/trinity/manager/config_manager.py
+++ b/trinity/manager/config_manager.py
@@ -199,9 +199,11 @@ def _expert_buffer_part(self):

     def _expert_explorer_part(self):
         self.get_configs("sync_method", "sync_interval", "sync_timeout")
-        self.get_configs("runner_num", "max_timeout", "explorer_max_retry_times", "eval_interval")
+        self.get_configs(
+            "runner_per_model", "max_timeout", "explorer_max_retry_times", "eval_interval"
+        )

-        self.get_configs("eval_on_latest_checkpoint")
+        self.get_configs("bench_on_latest_checkpoint")

         with st.expander("Rollout Model Config", expanded=True):
             self.get_configs("engine_type", "engine_num", "tensor_parallel_size")
@@ -571,7 +573,7 @@ def _gen_buffer_config(self):

     def _gen_explorer_config(self):
         explorer_config = {
-            "runner_num": st.session_state["runner_num"],
+            "runner_per_model": st.session_state["runner_per_model"],
             "max_timeout": st.session_state["max_timeout"],
             "max_retry_times": st.session_state["explorer_max_retry_times"],
             "rollout_model": {
@@ -584,7 +586,7 @@ def _gen_explorer_config(self):
             },
             "auxiliary_models": [],
             "eval_interval": st.session_state["eval_interval"],
-            "eval_on_latest_checkpoint": st.session_state["eval_on_latest_checkpoint"],
+            "bench_on_latest_checkpoint": st.session_state["bench_on_latest_checkpoint"],
         }
         for i in range(st.session_state["_auxiliary_models_num"]):
             auxiliary_model_config = {
diff --git a/trinity/manager/config_registry/explorer_config_manager.py b/trinity/manager/config_registry/explorer_config_manager.py
index 12e8034a30..249c669f60 100644
--- a/trinity/manager/config_registry/explorer_config_manager.py
+++ b/trinity/manager/config_registry/explorer_config_manager.py
@@ -9,9 +9,9 @@ def explorer_visible() -> bool:
     return st.session_state["mode"] == "both"


-@CONFIG_GENERATORS.register_config(default_value=32, visible=explorer_visible)
-def set_runner_num(**kwargs):
-    st.number_input("Runner Num", min_value=1, **kwargs)
+@CONFIG_GENERATORS.register_config(default_value=8, visible=explorer_visible)
+def set_runner_per_model(**kwargs):
+    st.number_input("Runner per Model", min_value=1, **kwargs)


 @CONFIG_GENERATORS.register_config(default_value=900, visible=explorer_visible)
@@ -30,7 +30,7 @@ def set_eval_interval(**kwargs):


 @CONFIG_GENERATORS.register_config(default_value=True, visible=explorer_visible)
-def set_eval_on_latest_checkpoint(**kwargs):
+def set_bench_on_latest_checkpoint(**kwargs):
     st.checkbox("Eval on Latest Checkpoint", **kwargs)


diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py
index 3a9f51f677..1378449cf2 100644
--- a/trinity/trainer/trainer.py
+++ b/trinity/trainer/trainer.py
@@ -73,10 +73,6 @@ def sync_weight(self) -> None:
             f"Trainer synchronizing weights at step {self.engine.train_step_num} end."
         )

-    def flush_log(self, step: int) -> None:
-        """Flush the log of the current step."""
-        self.engine.monitor.log({}, step=step, commit=True)
-
     def shutdown(self) -> None:
         # if checkpoint not saved, save the last checkpoint
         step_num = self.engine.train_step_num
diff --git a/trinity/utils/monitor.py b/trinity/utils/monitor.py
index e83df10b8f..5896fc110d 100644
--- a/trinity/utils/monitor.py
+++ b/trinity/utils/monitor.py
@@ -2,7 +2,7 @@
 import os
 from abc import ABC, abstractmethod
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union

 import numpy as np
 import pandas as pd
@@ -16,6 +16,18 @@
 MONITOR = Registry("monitor")


+def gather_metrics(metric_list: List[Dict], prefix: str) -> Dict:
+    df = pd.DataFrame(metric_list)
+    numeric_df = df.select_dtypes(include=[np.number])
+    stats_df = numeric_df.agg(["mean", "max", "min"])
+    metric = {}
+    for col in stats_df.columns:
+        metric[f"{prefix}/{col}/mean"] = stats_df.loc["mean", col]
+        metric[f"{prefix}/{col}/max"] = stats_df.loc["max", col]
+        metric[f"{prefix}/{col}/min"] = stats_df.loc["min", col]
+    return metric
+
+
 class Monitor(ABC):
     """Monitor"""
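The new `gather_metrics` helper collapses a list of per-task metric dicts into `prefix/<key>/mean|max|min` entries, dropping non-numeric fields. A small illustration with made-up values (e.g. the `task_run_time` recorded by `RunnerWrapper.run_with_retry`):

```python
from trinity.utils.monitor import gather_metrics

# Illustrative per-task metrics, e.g. collected from Status.metric dicts.
metric_list = [
    {"task_run_time": 12.0, "reward": 0.5},
    {"task_run_time": 20.0, "reward": 1.0},
]

metrics = gather_metrics(metric_list, prefix="rollout")
# -> {"rollout/task_run_time/mean": 16.0, "rollout/task_run_time/max": 20.0,
#     "rollout/task_run_time/min": 12.0, "rollout/reward/mean": 0.75, ...}
```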