diff --git a/docs/sphinx_doc/source/tutorial/develop_workflow.md b/docs/sphinx_doc/source/tutorial/develop_workflow.md index 0d5c632421..2f580426ab 100644 --- a/docs/sphinx_doc/source/tutorial/develop_workflow.md +++ b/docs/sphinx_doc/source/tutorial/develop_workflow.md @@ -513,7 +513,7 @@ Here, `` is the path to a YAML configuration file, which shoul Once started, the model will keep running and wait for debug instructions; it will not exit automatically. You can then run the following command in another terminal to debug your workflow: ```bash -trinity debug --config --module workflow --output_file --plugin_dir +trinity debug --config --module workflow --output-file --plugin-dir ``` - ``: Path to the YAML configuration file, usually the same as used for starting the inference model. diff --git a/docs/sphinx_doc/source/tutorial/faq.md b/docs/sphinx_doc/source/tutorial/faq.md index 2cd2a8aba7..ce960acffa 100644 --- a/docs/sphinx_doc/source/tutorial/faq.md +++ b/docs/sphinx_doc/source/tutorial/faq.md @@ -94,7 +94,7 @@ ray start --head **A:** The following parameters may be helpful: -- For trainer, adjust `actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu` when `actor_rollout_ref.actor.use_dynamic_bsz=false`; adjust `actor_rollout_ref.actor.ppo_max_token_len_per_gpu` and `actor_rollout_ref.actor.ulysses_sequence_parallel_size` when `actor_rollout_ref.actor.use_dynamic_bsz=true`. Setting `actor_rollout_ref.actor.entropy_from_logits_with_chunking=true` may also help. +- For trainer, adjust `trainer.max_token_len_per_gpu` when `trainer.use_dynamic_bsz=false`; adjust `trainer.ppo_max_token_len_per_gpu` and `trainer.ulysses_sequence_parallel_size` when `trainer.use_dynamic_bsz=true`. Setting `trainer.trainer_config.actor_rollout_ref.actor.entropy_from_logits_with_chunking=true` may also help. - For explorer, adjust `explorer.rollout_model.tensor_parallel_size`. @@ -113,7 +113,7 @@ To debug a new workflow, use Trinity-RFT's debug mode with the following steps: 1. Launch the inference model via `trinity debug --config --module inference_model` -2. Debug the workflow in another terminal via `trinity debug --config --module workflow --output_file --plugin_dir ` +2. Debug the workflow in another terminal via `trinity debug --config --module workflow --output-file --plugin-dir ` Please refer to {ref}`Workflow Development Guide ` section for details. diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index 9eb9cd90fb..505e5683fe 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -367,6 +367,12 @@ explorer: tensor_parallel_size: 1 eval_interval: 100 eval_on_startup: True + over_rollout: + ratio: 0.0 + wait_after_min: 30.0 + dynamic_timeout: + enable: false + ratio: 3.0 ``` - `name`: Name of the explorer. This name will be used as the Ray actor's name, so it must be unique. @@ -381,6 +387,12 @@ explorer: - `auxiliary_models`: Additional models used for custom workflows. - `eval_interval`: Interval (in steps) for evaluating the model. - `eval_on_startup`: Whether to evaluate the model on startup. More precisely, at step 0 with the original model, so it will not be triggered when restarting. +- `over_rollout`: [Experimental] Configurations for over-rollout mechanism, which allows the explorer to proceed with fewer tasks than the full batch size. It effectively increases throughput in scenarios where some tasks take significantly longer to complete than others. Only applicable when dynamic synchronization (`synchronizer.sync_style` is not `fixed`) is used. + - `ratio`: Explorer will only wait for `(1 - ratio) * batch_size` of tasks at each step. Default is `0.0`, meaning waiting for all tasks. + - `wait_after_min`: After reaching the minimum task threshold, wait for this many seconds before proceeding. Default is `30.0` seconds. +- `dynamic_timeout`: [Experimental] Configurations for dynamic timeout mechanism, which adjusts the timeout for each task based on the average time taken for successful tasks. + - `enable`: Whether to enable dynamic timeout. Default is `false`. + - `ratio`: The timeout for each task is dynamically set to `average_time_per_success_task * ratio`. Default is `3.0`. --- @@ -394,6 +406,7 @@ synchronizer: sync_interval: 10 sync_offset: 0 sync_timeout: 1200 + sync_style: 'fixed' ``` - `sync_method`: Method of synchronization. Options: @@ -402,6 +415,9 @@ synchronizer: - `sync_interval`: Interval (in steps) of model weight synchronization between trainer and explorer. - `sync_offset`: Offset (in steps) of model weight synchronization between trainer and explorer. The explorer can run `sync_offset` steps before the trainer starts training. - `sync_timeout`: Timeout duration for synchronization. +- `sync_style`: Style of synchronization. Options: + - `fixed`: The explorer and trainer synchronize weights every `sync_interval` steps. + - `dynamic_by_explorer`: The explorer notifies the trainer to synchronize weights after completing `sync_interval` steps, regardless of how many steps the trainer has completed at this point. --- diff --git a/docs/sphinx_doc/source_zh/tutorial/develop_workflow.md b/docs/sphinx_doc/source_zh/tutorial/develop_workflow.md index 366834ea53..17ab4cc22a 100644 --- a/docs/sphinx_doc/source_zh/tutorial/develop_workflow.md +++ b/docs/sphinx_doc/source_zh/tutorial/develop_workflow.md @@ -509,7 +509,7 @@ trinity debug --config --module inference_model 模型启动后会持续运行并等待调试指令,不会自动退出。此时,你可在另一个终端执行如下命令进行 Workflow 调试: ```bash -trinity debug --config --module workflow --output_file --plugin_dir +trinity debug --config --module workflow --output-file --plugin-dir ``` - `config_file_path`:YAML 配置文件路径,通常与启动推理模型时使用的配置文件相同。 diff --git a/docs/sphinx_doc/source_zh/tutorial/faq.md b/docs/sphinx_doc/source_zh/tutorial/faq.md index 900dafd961..facd99539c 100644 --- a/docs/sphinx_doc/source_zh/tutorial/faq.md +++ b/docs/sphinx_doc/source_zh/tutorial/faq.md @@ -93,7 +93,7 @@ ray start --head **A:** 以下参数可能有所帮助: -- 对于 trainer:当 `actor_rollout_ref.actor.use_dynamic_bsz=false` 时,调整 `actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu`;当 `actor_rollout_ref.actor.use_dynamic_bsz=true` 时,调整 `actor_rollout_ref.actor.ppo_max_token_len_per_gpu` 和 `actor_rollout_ref.actor.ulysses_sequence_parallel_size`。设置 `actor_rollout_ref.actor.entropy_from_logits_with_chunking=true` 也可能有帮助。 +- 对于 trainer:当 `trainer.use_dynamic_bsz=false` 时,调整 `trainer.max_token_len_per_gpu`;当 `trainer.use_dynamic_bsz=true` 时,调整 `trainer.ppo_max_token_len_per_gpu` 和 `trainer.ulysses_sequence_parallel_size`。设置 `trainer.trainer_config.actor_rollout_ref.actor.entropy_from_logits_with_chunking=true` 也可能有帮助。 - 对于 explorer:调整 `explorer.rollout_model.tensor_parallel_size`。 ## 第三部分:调试方法 @@ -113,7 +113,7 @@ trinity run --config grpo_gsm8k/gsm8k.yaml 2>&1 | tee debug.log 1. 启动推理模型: `trinity debug --config --module inference_model` -2. 在另一个终端中进行工作流的调试:`trinity debug --config --module workflow --output_file --plugin_dir ` +2. 在另一个终端中进行工作流的调试:`trinity debug --config --module workflow --output-file --plugin-dir ` 更多详细信息,请参阅{ref}`工作流开发指南 `章节。 diff --git a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md index c1e1847254..be9e56265c 100644 --- a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md @@ -378,6 +378,12 @@ explorer: - `auxiliary_models`: 用于自定义工作流的额外模型。 - `eval_interval`: 模型评估的间隔(以步为单位)。 - `eval_on_startup`: 是否在启动时评估模型。更准确地说,是在第 0 步使用原始模型评估,因此重启时不会触发。 +- `over_rollout`: [实验性] 超量 rollout 机制的配置,允许 explorer 在每个步骤中使用少于完整批次大小的任务继续进行。这在某些任务显著耗时较长的场景中能有效地提高吞吐量。仅当使用动态同步(`synchronizer.sync_style` 不是 `fixed`)时适用。 + - `ratio`: explorer 在每个步骤中仅等待 `(1 - ratio) * batch_size` 的任务。默认为 `0.0`,表示等待所有任务。 + - `wait_after_min`: 达到最小任务阈值后,等待此秒数后再继续。 +- `dynamic_timeout`: [实验性] 动态超时机制的配置,根据成功任务的平均耗时调整每个任务的超时时间。 + - `enable`: 是否启用动态超时。默认为 `false`。 + - `ratio`: 每个任务的超时时间动态设置为 `average_time_per_success_task * ratio`。默认为 `3.0`。 --- @@ -391,6 +397,7 @@ synchronizer: sync_interval: 10 sync_offset: 0 sync_timeout: 1200 + sync_style: 'fixed' ``` - `sync_method`: 同步方法。选项: @@ -399,6 +406,9 @@ synchronizer: - `sync_interval`: trainer 和 explorer 之间模型权重同步的间隔(步)。 - `sync_offset`: trainer 和 explorer 之间模型权重同步的偏移量(步)。explorer 可在 trainer 开始训练前运行 `sync_offset` 步。 - `sync_timeout`: 同步超时时间。 +- `sync_style`: 同步风格。选项: + - `fixed`: explorer 和 trainer 每隔 `sync_interval` 步同步一次权重。 + - `dynamic_by_explorer`: explorer 在完成 `sync_interval` 步后通知 trainer 同步权重,而不管此时 trainer 已完成多少步。 --- diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index 306cc8e447..9495b86f7a 100644 --- a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -5,10 +5,11 @@ import ray import torch +from parameterized import parameterized from tests.tools import get_template_config from trinity.common.config import ExperienceBufferConfig -from trinity.common.constants import StorageType +from trinity.common.constants import StorageType, SyncStyle from trinity.common.experience import EID, Experience from trinity.common.models.model import InferenceModel from trinity.common.workflows import Task @@ -46,17 +47,23 @@ def run(self) -> List[Experience]: elif self.error_type == "auxiliary_models": assert self.auxiliary_models is not None and len(self.auxiliary_models) == 2 - return [ - Experience( - tokens=torch.zeros(5), - prompt_length=2, - prompt_text=self.error_type or "success", - eid=EID(run=i + self.run_id_base, step=step), - info={"repeat_times": self.repeat_times}, - ) - for step in range(self.step_num) - for i in range(self.repeat_times) - ] + exps = [] + for i in range(self.repeat_times): + run_level_metrics = {"run_metrics": float(i + self.run_id_base)} + run_level_exps = [] + for step in range(self.step_num): + run_level_exps.append( + Experience( + tokens=torch.zeros(5), + prompt_length=2, + prompt_text=self.error_type or "success", + eid=EID(run=i + self.run_id_base, step=step), + info={"repeat_times": self.repeat_times}, + ) + ) + run_level_exps[-1].metrics = run_level_metrics + exps.extend(run_level_exps) + return exps @WORKFLOWS.register_module("dummy_nonrepeat_workflow") @@ -67,22 +74,29 @@ def __init__(self, *, task, model, auxiliary_models): super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) self.reset_flag = False self.step_num = task.workflow_args.get("step_num", 1) + self.metrics = task.workflow_args.get("metrics", [0]) def reset(self, task: Task): self.task = task self.reset_flag = True + self.step_num = task.workflow_args.get("step_num", 1) + self.metrics = task.workflow_args.get("metrics", [0]) def run(self) -> List[Experience]: - return [ + exps = [ Experience( eid=EID(run=self.run_id_base, step=step), tokens=torch.zeros(5), prompt_length=2, prompt_text="success", info={"reset_flag": self.reset_flag}, + metrics={ + "run_metrics": self.metrics[step % len(self.metrics)], + }, ) for step in range(self.step_num) ] + return exps @WORKFLOWS.register_module("dummy_async_workflow") @@ -99,16 +113,22 @@ def set_repeat_times(self, repeat_times, run_id_base): self.run_id_base = run_id_base async def run_async(self): - return [ - Experience( - eid=EID(run=i + self.run_id_base, step=step), - tokens=torch.zeros(5), - prompt_length=2, - prompt_text="success", - ) - for step in range(self.step_num) - for i in range(self.repeat_times) - ] + exps = [] + for i in range(self.repeat_times): + run_level_metrics = {"run_metrics": float(i + self.run_id_base)} + run_level_exps = [] + for step in range(self.step_num): + run_level_exps.append( + Experience( + eid=EID(run=i + self.run_id_base, step=step), + tokens=torch.zeros(5), + prompt_length=2, + prompt_text="success", + ) + ) + run_level_exps[-1].metrics = run_level_metrics + exps.extend(run_level_exps) + return exps def run(self): raise RuntimeError("This method should not be called") @@ -490,7 +510,7 @@ async def test_split_tasks(self): tasks = generate_tasks(4, repeat_times=8) # ceil(8 / 2) == 4 scheduler.schedule(tasks, batch_id=1) statuses, exps = await scheduler.get_results(batch_id=1) - self.assertEqual(len(statuses), 4 * 4) + self.assertEqual(len(statuses), 4) self.assertEqual(len(exps), 4 * 8) exp_list.extend(exps) _, exps = await scheduler.get_results(batch_id=1, min_num=1, timeout=1) @@ -499,7 +519,7 @@ async def test_split_tasks(self): tasks = generate_tasks(4, repeat_times=5) # ceil(5 / 2) == 3 scheduler.schedule(tasks, batch_id=2) statuses, exps = await scheduler.get_results(batch_id=2) - self.assertEqual(len(statuses), 4 * 3) + self.assertEqual(len(statuses), 4) self.assertEqual(len(exps), 4 * 5) exp_list.extend(exps) _, exps = await scheduler.get_results(batch_id=2, min_num=1, timeout=1) @@ -508,7 +528,7 @@ async def test_split_tasks(self): tasks = generate_tasks(3, repeat_times=1) # ceil(1 / 2) == 1 scheduler.schedule(tasks, batch_id=3) statuses, exps = await scheduler.get_results(batch_id=3) - self.assertEqual(len(statuses), 3 * 1) + self.assertEqual(len(statuses), 3) self.assertEqual(len(exps), 3 * 1) exp_list.extend(exps) _, exps = await scheduler.get_results(batch_id=3, min_num=1, timeout=1) @@ -535,7 +555,7 @@ async def test_multi_step_execution(self): for i in range(1, n_steps + 1): scheduler.schedule(tasks, batch_id=i) statuses, exps = await scheduler.get_results(batch_id=i) - self.assertEqual(len(statuses), 2 * 4) + self.assertEqual(len(statuses), 2) self.assertEqual(len(exps), 2 * 4) await scheduler.stop() @@ -553,7 +573,7 @@ async def test_non_repeatable_workflow(self): for i in range(1, batch_num + 1): scheduler.schedule(tasks, batch_id=i) statuses, exps = await scheduler.get_results(batch_id=i) - self.assertEqual(len(statuses), task_num * repeat_times / 2) + self.assertEqual(len(statuses), task_num) self.assertEqual(len(exps), task_num * repeat_times) exp_list.extend(exps) @@ -594,7 +614,7 @@ async def test_async_workflow(self): for i in range(1, batch_num + 1): scheduler.schedule(tasks, batch_id=i) statuses, exps = await scheduler.get_results(batch_id=i) - self.assertEqual(len(statuses), task_num * repeat_times / 2) + self.assertEqual(len(statuses), task_num) self.assertEqual(len(exps), task_num * repeat_times * step_num) exp_list.extend(exps) @@ -624,7 +644,7 @@ async def test_stepwise_experience_eid(self): for i in range(1, batch_num + 1): scheduler.schedule(tasks, batch_id=i) statuses, exps = await scheduler.get_results(batch_id=i) - self.assertEqual(len(statuses), task_num * repeat_times / 2) + self.assertEqual(len(statuses), task_num) self.assertEqual(len(exps), task_num * repeat_times * step_num) exp_list.extend(exps) @@ -644,7 +664,7 @@ async def test_stepwise_experience_eid(self): for i in range(1, batch_num + 1): scheduler.schedule(tasks, batch_id=i) statuses, exps = await scheduler.get_results(batch_id=i) - self.assertEqual(len(statuses), task_num * repeat_times / 2) + self.assertEqual(len(statuses), task_num) self.assertEqual(len(exps), task_num * repeat_times * step_num) exp_list.extend(exps) @@ -656,6 +676,104 @@ async def test_stepwise_experience_eid(self): unique_ids = [exp.eid.uid for exp in exp_list] self.assertEqual(len(unique_ids), len(set(unique_ids))) + @parameterized.expand( + [ + (2,), + (None,), + ] + ) + async def test_metric_calculation_with_repeatable_workflow(self, max_repeat_times_per_runner): + self.config.explorer.max_repeat_times_per_runner = max_repeat_times_per_runner + self.config.check_and_update() + scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) + await scheduler.start() + tasks = [] + tasks.extend(generate_tasks(total_num=1, step_num=1, repeat_times=4, repeatable=True)) + tasks.extend(generate_tasks(total_num=1, step_num=4, repeat_times=8, repeatable=True)) + scheduler.schedule(tasks, batch_id=0) + statuses, exps = await scheduler.get_results(batch_id=0) + self.assertEqual(len(statuses), 2) + self.assertEqual(len(exps), 1 * 4 * 1 + 1 * 8 * 4) + self.assertAlmostEqual(statuses[0].metrics[0]["run_metrics"], 1.5) # (0+1+2+3)/4 + self.assertAlmostEqual(statuses[1].metrics[0]["run_metrics"], 3.5) # (0+1+2+3+4+5+6+7)/8 + + @parameterized.expand( + [ + (2,), + (None,), + ] + ) + async def test_metric_calculation_with_non_repeatable_workflow( + self, max_repeat_times_per_runner + ): + self.config.explorer.max_repeat_times_per_runner = max_repeat_times_per_runner + self.config.check_and_update() + scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) + await scheduler.start() + tasks = [] + tasks.extend(generate_tasks(total_num=1, step_num=3, repeat_times=4, repeatable=False)) + tasks[-1].workflow_args["metrics"] = [1.0, 2.0, 3.0] + tasks.extend(generate_tasks(total_num=1, step_num=8, repeat_times=5, repeatable=False)) + tasks[-1].workflow_args["metrics"] = [2 * i for i in range(8)] + scheduler.schedule(tasks, batch_id=0) + statuses, exps = await scheduler.get_results(batch_id=0) + self.assertEqual(len(statuses), 2) + self.assertEqual(len(exps), 1 * 4 * 3 + 1 * 5 * 8) + self.assertAlmostEqual(statuses[0].metrics[0]["run_metrics"], 2.0) # (1+2+3)/3 + self.assertAlmostEqual(statuses[1].metrics[0]["run_metrics"], 7.0) # (0+2+4+6+8+10+12+14)/8 + + async def test_over_rollout_min_wait(self): + self.config.explorer.over_rollout.ratio = 0.5 + self.config.explorer.over_rollout.wait_after_min = 3 + self.config.explorer.max_repeat_times_per_runner = None + self.config.buffer.batch_size = 4 + self.config.synchronizer.sync_style = SyncStyle.DYNAMIC_BY_EXPLORER + self.config.check_and_update() + scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) + await scheduler.start() + tasks = [] + tasks.extend(generate_tasks(0, timeout_num=2, repeat_times=1, timeout_seconds=1)) + tasks.extend(generate_tasks(0, timeout_num=1, repeat_times=1, timeout_seconds=3)) + tasks.extend(generate_tasks(0, timeout_num=1, repeat_times=1, timeout_seconds=6)) + scheduler.schedule(tasks, batch_id=0) + statuses, exps = await scheduler.get_results(batch_id=0, min_num=2) + self.assertEqual(len(statuses), 3) + self.assertEqual(len(exps), 3 * 1) + + async def test_dynamic_timeout(self): + self.config.explorer.dynamic_timeout.enable = True + self.config.explorer.dynamic_timeout.ratio = 3.0 + self.config.buffer.batch_size = 4 + self.config.explorer.max_timeout = 20 + self.config.explorer.max_retry_times = 0 # no retry here + scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) + await scheduler.start() + tasks = [] + # generate 4 tasks that will run 1 second + tasks.extend(generate_tasks(0, timeout_num=4, repeat_times=1, timeout_seconds=1)) + scheduler.schedule(tasks, batch_id=0) # first step will not use dynamic timeout + statuses, exps = await scheduler.get_results(batch_id=0) + self.assertEqual(len(statuses), 4) + # dynamic timeout will be set to 3.0 * 1.0 = 3.0 seconds for next step + tasks = [] + tasks.extend(generate_tasks(0, timeout_num=4, repeat_times=1, timeout_seconds=4)) + st = time.time() + scheduler.schedule(tasks, batch_id=1) + statuses, exps = await scheduler.get_results(batch_id=1) + et = time.time() + self.assertTrue( + et - st < 4 + ) # should wait about 1 * 3.0 seconds, here we set 4 seconds timeout + self.assertEqual(len(exps), 0) + self.assertEqual(len(statuses), 4) + # tasks take 2 seconds, which is within the dynamic timeout 3.0 * 1.0 = 3.0 seconds + tasks = [] + tasks.extend(generate_tasks(0, timeout_num=4, repeat_times=1, timeout_seconds=2)) + scheduler.schedule(tasks, batch_id=2) + statuses, exps = await scheduler.get_results(batch_id=2) + self.assertEqual(len(statuses), 4) + self.assertEqual(len(exps), 4) + def tearDown(self): try: ray.shutdown() diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index c99b7d1dda..641ef6e239 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -987,3 +987,48 @@ def test_trainer(self): def tearDown(self): shutil.rmtree(self.config.checkpoint_job_dir) + + +class TestOverRollout(BaseTrainerCase): + def test_trainer(self): + self.config.algorithm.repeat_times = 4 + self.config.buffer.batch_size = 4 + self.config.buffer.total_steps = 2 + self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k") + self.config.name = f"explore-over-rollout-{datetime.now().strftime('%Y%m%d%H%M%S')}" + self.config.explorer.over_rollout.ratio = 0.5 # set over rollout rate to 50%, which means only wait for 2 (4 * 50%) tasks in each steps + self.config.explorer.over_rollout.wait_after_min = 0 + self.config.algorithm.algorithm_type = "grpo" + self.config.algorithm.advantage_fn = "grpo" + self.config.algorithm.advantage_fn_args = { + "epsilon": 1e-6, + } + self.config.synchronizer.sync_style = SyncStyle.DYNAMIC_BY_EXPLORER + self.config.synchronizer.sync_interval = 1 + self.config.check_and_update() + both(self.config) + parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard")) + rollout_metrics = parser.metric_list("rollout") + self.assertTrue(len(rollout_metrics) > 0) + eval_metrics = parser.metric_list("eval") + self.assertTrue(len(eval_metrics) == 0) + self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 2) + self.assertTrue(parser.metric_exist("experience_pipeline/experience_count")) + experience_counts = parser.metric_values("experience_pipeline/experience_count") + self.assertTrue(len(experience_counts) == 2) + for count in experience_counts: + self.assertTrue( + count >= 2 * 4 + ) # at least process 2 tasks in each step, repeat_times is 4 + pg_loss = parser.metric_values("actor/pg_loss") + self.assertTrue(len(pg_loss) >= 1) # trainer only has at least 1 step + exp_save_path = self.config.buffer.trainer_input.experience_buffer.path + with open(exp_save_path, "r", encoding="utf-8") as f: + lines = f.readlines() + self.assertTrue( + len(lines) >= 2 * 4 * 2 + ) # at least contain total_steps * repeat_times * batch_size * min_waited_tasks + + def tearDown(self): + # remove dir only when the test passed + shutil.rmtree(self.config.checkpoint_job_dir) diff --git a/trinity/common/config.py b/trinity/common/config.py index c722959b96..fc7cea2c00 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -138,6 +138,24 @@ class ReplayBufferConfig: priority_fn_args: Dict = field(default_factory=lambda: {"decay": 2.0}) +@dataclass +class OverRolloutConfig: + """Config for over-rollout in explorer.""" + + ratio: float = 0.0 # explorer will only wait for (1 - over_rollout.ratio) * batch_size of tasks at each step + wait_after_min: float = 30.0 # wait 30 s after reaching minimum task threshold + # more settings will be added in the future + # e.g., postpone tasks into the next step if not finished in time + + +@dataclass +class DynamicTimeoutConfig: + """Config for dynamic timeout in explorer.""" + + enable: bool = False + ratio: float = 3.0 # the timeout for each step will be min(max_timeout, average_time_per_task * dynamic_timeout.ratio) + + @dataclass class StorageConfig: """Storage config for both taskset and experience buffer. @@ -599,7 +617,7 @@ class ExplorerConfig: # for workflow runner # number of workflow runners. runner_per_model: int = 8 # number of runners per each rollout model - max_timeout: int = 1800 # wait each task for 30 minutes + max_timeout: int = 1800 # wait each task for 30 minutes at most max_retry_times: int = 2 # retry each task for 2 times if it fails or timeout env_vars: dict = field(default_factory=dict) # environment variables for workflow runner max_repeat_times_per_runner: Optional[ @@ -629,6 +647,9 @@ class ExplorerConfig: service_status_check_interval: int = 60 # keep at least 1 model in running status min_running_model_num: int = 1 + # Experimental feature + over_rollout: OverRolloutConfig = field(default_factory=OverRolloutConfig) + dynamic_timeout: DynamicTimeoutConfig = field(default_factory=DynamicTimeoutConfig) @dataclass @@ -1198,6 +1219,15 @@ def check_and_update(self) -> Config: # noqa: C901 for args in rollout_args + length_args: set_if_none(aux_model, args, getattr(self.model, args)) + if self.explorer.over_rollout.ratio > 0.0: + if not (0.0 <= self.explorer.over_rollout.ratio < 1.0): + raise ValueError("over_rollout_ratio should be in [0.0, 1.0)") + if self.synchronizer.sync_style == SyncStyle.FIXED: + raise ValueError( + "over_rollout_ratio is not compatible with fixed sync_style, please set " + "`synchronizer.sync_style` to `dynamic_by_explorer` or `dynamic_by_trainer`." + ) + # for lora configs if self.model.lora_configs is not None: self.explorer.rollout_model.enable_lora = True diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index eeb90cec12..b59a1c0833 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -376,6 +376,7 @@ async def convert_messages_to_experience( logprobs=logprobs[prompt_length - 1 :], prompt_length=prompt_length, action_mask=action_mask[prompt_length:], # Exclude the prompt tokens + messages=messages, ) async def shutdown(self): diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 038c1dd5f9..76a4767704 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +import math import os import time import traceback @@ -47,6 +48,7 @@ def __init__(self, config: Config): explorer_state = self.state.load_explorer() self.explore_step_num = explorer_state.get("latest_iteration", 0) self.last_sync_step = self.explore_step_num if self.explore_step_num > 0 else -1 + self.last_monitored_step = self.explore_step_num if self.explore_step_num > 0 else -1 self.synchronizer = Synchronizer.get_actor(config) self.config = config self.models, self.auxiliary_models = create_inference_models(config) @@ -62,10 +64,15 @@ def __init__(self, config: Config): role=self.config.explorer.name, config=config, ) - self.batch_size = config.buffer.batch_size - self.update_interval = ( - self.config.synchronizer.sync_interval * self.config.buffer.batch_size - ) + if config.explorer.over_rollout.ratio > 0.0: + self.min_wait_num = math.ceil( + config.buffer.batch_size * (1 - config.explorer.over_rollout.ratio) + ) + self.logger.info( + f"Over rollout is enabled. Explorer will only wait for {self.min_wait_num} tasks in each step." + ) + else: + self.min_wait_num = None self.use_nccl_sync = self.config.synchronizer.sync_method == SyncMethod.NCCL self.pending_eval_tasks = deque() @@ -110,8 +117,10 @@ async def setup_weight_sync_group( await asyncio.gather(*refs) async def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> int: + self.logger.info(f"Start to update model weights from checkpoint at step {step_num}.") step_num = await self.synchronizer.set_model_state_dict_with_step_num.remote(step_num) await asyncio.gather(*[model.sync_model.remote(step_num) for model in self.models]) + self.logger.info(f"Model weights updated to checkpoint at step {step_num}.") return step_num # type: ignore async def _pull_latest_weights(self): @@ -309,6 +318,8 @@ async def benchmark(self) -> bool: ] ) for step_num in all_ckp_steps: + if step_num <= self.explore_step_num: + continue self.explore_step_num = await self._checkpoint_weights_update(step_num=step_num) await self.eval() await self._finish_eval_step(prefix="bench") @@ -317,8 +328,9 @@ async def benchmark(self) -> bool: async def save_checkpoint(self, sync_weight: bool = False) -> None: if self.scheduler: await self._finish_steps( - self.last_sync_step + 1, self.explore_step_num, self.model_version + self.last_monitored_step + 1, self.explore_step_num, self.model_version ) + self.last_monitored_step = self.explore_step_num if sync_weight: # sync weights @@ -357,12 +369,14 @@ async def _finish_steps(self, start_step: int, end_step: int, model_version: int async def _finish_explore_step(self, step: int, model_version: int) -> None: metric = {"rollout/model_version": model_version} with Timer(metric, "time/wait_explore_step"): - statuses, exps = await self.scheduler.get_results(batch_id=step) + statuses, exps = await self.scheduler.get_results( + batch_id=step, min_num=self.min_wait_num + ) pipeline_metrics = await self.experience_pipeline.process.remote(exps) self.taskset.update(pipeline_metrics) metric.update(pipeline_metrics) if statuses: - metric.update(gather_metrics([status.metric for status in statuses], "rollout")) + metric.update(gather_metrics([status.metrics[0] for status in statuses], "rollout")) self.monitor.log(metric, step=step) async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eval") -> None: @@ -375,10 +389,10 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva if eval_step != step: return self.pending_eval_tasks.popleft() - eval_results, _ = await self.scheduler.get_results(f"{step}/{eval_task_name}") + eval_results, _ = await self.scheduler.get_results(batch_id=f"{step}/{eval_task_name}") metric.update( gather_metrics( - [status.metric for status in eval_results], f"{prefix}/{eval_task_name}" + [status.metrics[0] for status in eval_results], f"{prefix}/{eval_task_name}" ) ) if self.eval_start_time is not None: diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py index ae17649c86..ae2fdfe4a6 100644 --- a/trinity/explorer/scheduler.py +++ b/trinity/explorer/scheduler.py @@ -5,7 +5,7 @@ import time import traceback from collections import defaultdict, deque -from dataclasses import dataclass, replace +from dataclasses import dataclass, field, replace from typing import Dict, List, Optional, Tuple, Union import ray @@ -20,12 +20,26 @@ @dataclass class TaskWrapper: - """A wrapper for a task.""" + """A wrapper for a task. + Each task can run multiple times (repeat_times) on same or different runners. + """ task: Task batch_id: Union[int, str] - run_id_base: int = 0 - repeat_times: int = 1 + sub_task_num: int = 1 + results: List[Tuple[Status, List[Experience]]] = field(default_factory=list) + + +def calculate_task_level_metrics(metrics: List[Dict]) -> Dict[str, float]: + """Calculate task level metrics from experiences.""" + if not metrics: + return {} + aggregated_metrics: Dict[str, List[float]] = defaultdict(list) + for m in metrics: + for key, value in m.items(): + if isinstance(value, (int, float)): + aggregated_metrics[key].append(value) + return {key: sum(values) / len(values) for key, values in aggregated_metrics.items() if values} class RunnerWrapper: @@ -65,24 +79,44 @@ def _create_runner(self): async def prepare(self): await self.runner.prepare.remote() - async def run_with_retry(self, task: TaskWrapper) -> Tuple[Status, List, int]: + async def run_with_retry( + self, task: TaskWrapper, repeat_times: int, run_id_base: int, timeout: float + ) -> Tuple[Status, List, int, float]: """ + Args: + task (`TaskWrapper`): The task to run. + repeat_times (`int`): The number of times to repeat the task. + run_id_base (`int`): The base run id for this task runs. + timeout (`float`): The timeout for each task run. + Returns: `Status`: The return status of the task. `List`: The experiences generated by the task. `int`: The runner_id of current runner. + `float`: The time taken to run the task. """ last_exception_msg = None await self.runner.__ray_ready__.remote() start_time = time.time() - status = Status(ok=False, metric=dict()) + status = Status(ok=False, metrics=list()) exps = [] + task2run = replace( + task.task, + rollout_args=replace( + task.task.rollout_args, + n=repeat_times, + ), + ) try: for attempt in range(self.retry_times + 1): try: status, exps = await asyncio.wait_for( - self.runner.run_task.remote(task.task, task.repeat_times, task.run_id_base), - self.timeout, + self.runner.run_task.remote( + task=task2run, + repeat_times=repeat_times, + run_id_base=run_id_base, + ), + timeout=timeout, ) if status.ok: break @@ -91,17 +125,17 @@ async def run_with_retry(self, task: TaskWrapper) -> Tuple[Status, List, int]: except asyncio.TimeoutError: last_exception_msg = f"Timeout when running task of batch {task.batch_id} at runner {self.runner_id} at attempt {attempt + 1}: {task.task}" self.logger.error(last_exception_msg) - status = Status(ok=False, metric=dict(), message=last_exception_msg) + status = Status(ok=False, metrics=list(), message=last_exception_msg) except Exception: last_exception_msg = traceback.format_exc() self.logger.warning( f"Task execution attempt {attempt + 1} failed:\n{last_exception_msg}" ) - status = Status(ok=False, metric=dict(), message=last_exception_msg) + status = Status(ok=False, metrics=list(), message=last_exception_msg) finally: end_time = time.time() - status.metric["task_run_time"] = end_time - start_time - return status, exps, self.runner_id + status.metrics.append({"time/task_execution": end_time - start_time}) + return status, exps, self.runner_id, end_time - start_time async def restart_runner(self): old_runner = self.runner @@ -128,7 +162,11 @@ def sort_batch_id(batch_id: Union[int, str]): class Scheduler: - """Scheduler for rollout tasks.""" + """Scheduler for rollout tasks. + + Supports scheduling tasks to multiple runners, retrying failed tasks, + and collecting results at different levels. + """ def __init__( self, @@ -144,17 +182,23 @@ def __init__( self.default_timeout = config.explorer.max_timeout * (config.explorer.max_retry_times + 1) self.max_retry_times = config.explorer.max_retry_times self.max_repeat_times = config.explorer.max_repeat_times_per_runner + self.default_batch_size = config.buffer.batch_size self.running = False self.runner_num = len(rollout_model) * config.explorer.runner_per_model self.runners: Dict[int, RunnerWrapper] = dict() - self.idle_runners = set() # runner_id + self.idle_runners = set() # runner_id of idle runners self.busy_runners = dict() # runner_id -> task - self.pending_tasks: Dict[Union[int, str], deque] = defaultdict(deque) # batch_id -> tasks + self.pending_tasks: Dict[Union[int, str], deque] = defaultdict( + deque + ) # batch_id -> (task, repeat_times, run_id_base) self.running_tasks: Dict[Union[int, str], set[asyncio.Future]] = defaultdict( set ) # batch_id -> futures + self.task_num_map: Dict[Union[int, str], int] = defaultdict( + int + ) # batch_id -> tasks scheduled under this batch_id self.running_task_map: Dict[asyncio.Future, TaskWrapper] = dict() # future -> task self.completed_tasks: Dict[ Union[int, str], deque[Tuple[Status, List[Experience]]] @@ -165,8 +209,8 @@ def __init__( self.scheduler_task: Optional[asyncio.Task] = None self.running = False - self.total_scheduled = 0 - self.total_completed = 0 + self.total_running_time = 0.0 + self.total_completed_tasks = 0 async def _create_runner( self, @@ -187,7 +231,7 @@ async def _create_runner( async def _restart_runner(self, runner_id: int): """Restart a runner.""" - self.runners[runner_id].restart_runner() + await self.runners[runner_id].restart_runner() if runner_id in self.busy_runners: task = self.busy_runners.pop(runner_id) @@ -218,10 +262,17 @@ async def _schedule_pending_tasks(self) -> None: task_queue = self.pending_tasks[batch_id] while task_queue and self.idle_runners: - task = task_queue.pop() + task, repeat_times, run_id_base = task_queue.pop() runner_id = self.idle_runners.pop() self.busy_runners[runner_id] = task - future = asyncio.create_task(self.runners[runner_id].run_with_retry(task)) + future = asyncio.create_task( + self.runners[runner_id].run_with_retry( + task, + repeat_times=repeat_times, + run_id_base=run_id_base, + timeout=self.dynamic_timeout(), + ) + ) self.running_task_map[future] = task future.add_done_callback(self.task_done_callback) self.running_tasks[batch_id].add(future) @@ -237,11 +288,26 @@ def task_done_callback(self, async_task: asyncio.Task): self.logger.error(f"Task {task.task.task_id} failed: {async_task.exception()}") return else: - status, exps, runner_id = async_task.result() - self.completed_tasks[task.batch_id].appendleft((status, exps)) + status, exps, runner_id, run_time = async_task.result() + self.total_running_time += run_time + self.total_completed_tasks += 1 + task.results.append((status, exps)) self.busy_runners.pop(runner_id) self.idle_runners.add(runner_id) - self.logger.debug(f"Task completed (batch_id {task.batch_id}), success: {status.ok}") + if len(task.results) == task.sub_task_num: + task_experiences = [] + task_metrics = [] + all_success = True + for s, exp in task.results: + task_metrics.extend(s.metrics) + task_experiences.extend(exp) + if not s.ok: + all_success = False + task_status = Status( + ok=all_success, metrics=[calculate_task_level_metrics(task_metrics)] + ) + self.completed_tasks[task.batch_id].appendleft((task_status, task_experiences)) + self.logger.debug(f"Task completed (batch_id {task.batch_id}).") if task.batch_id in self.running_tasks: self.running_tasks[task.batch_id].remove(async_task) @@ -257,6 +323,7 @@ def _clear_timeout_tasks(self, batch_id: Union[int, str]) -> None: for future in self.running_tasks[batch_id]: future.cancel() del self.running_tasks[batch_id] + self.task_num_map.pop(batch_id, None) async def start(self) -> None: if self.running: @@ -299,41 +366,39 @@ def schedule(self, tasks: List[Task], batch_id: Union[int, str]) -> None: """ if not tasks: return + self.task_num_map[batch_id] += len(tasks) self._split_and_submit_tasks(tasks, batch_id=batch_id) def _split_and_submit_tasks(self, tasks: List[Task], batch_id: Union[int, str]) -> None: for i, task in enumerate(tasks): assert task.repeat_times is not None, "Task repeat_times should not be None" + task_wrapper = TaskWrapper( + task=replace(task, batch_id=batch_id, task_id=i), + batch_id=batch_id, + ) if self.max_repeat_times is None: - self.pending_tasks[batch_id].appendleft( - TaskWrapper( - task=replace(task, batch_id=batch_id, task_id=i), - batch_id=batch_id, - run_id_base=0, - repeat_times=task.repeat_times, - ) - ) + task_wrapper.sub_task_num = 1 + self.pending_tasks[batch_id].appendleft((task_wrapper, task.repeat_times, 0)) continue - rest_repeat_times = task.repeat_times - run_id_base = 0 - while rest_repeat_times > 0: - repeat_times = min(self.max_repeat_times, rest_repeat_times) - task_wrapper = TaskWrapper( - task=replace( - task, - batch_id=batch_id, - task_id=i, - rollout_args=replace( - task.rollout_args, n=repeat_times - ), # deprecated: use TaskWrapper.repeat_times - ), - batch_id=batch_id, - run_id_base=run_id_base, - repeat_times=repeat_times, - ) - run_id_base += repeat_times - rest_repeat_times -= repeat_times - self.pending_tasks[batch_id].appendleft(task_wrapper) + sub_tasks = [] + for run_id_base in range(0, task.repeat_times, self.max_repeat_times): + repeat_times = min(self.max_repeat_times, task.repeat_times - run_id_base) + sub_tasks.append((task_wrapper, repeat_times, run_id_base)) + task_wrapper.sub_task_num = len(sub_tasks) + self.pending_tasks[batch_id].extendleft(sub_tasks) + + def dynamic_timeout(self, timeout: Optional[float] = None) -> float: + """Calculate dynamic timeout based on historical data.""" + max_timeout = timeout or self.default_timeout + if not self.config.explorer.dynamic_timeout.enable: + return max_timeout + if self.total_completed_tasks < self.default_batch_size: + return max_timeout + avg_time_per_task = self.total_running_time / self.total_completed_tasks + return min( + max_timeout, + avg_time_per_task * self.config.explorer.dynamic_timeout.ratio, + ) async def get_results( self, @@ -352,43 +417,52 @@ async def get_results( """ timeout = timeout or self.default_timeout start_time = time.time() + scheduled_num = self.task_num_map.get(batch_id, 0) if min_num is None: - min_num = sum( - len(tasks) # type: ignore [misc] - for tasks in ( - self.pending_tasks.get(batch_id, []), - self.running_tasks.get(batch_id, []), - self.completed_tasks.get(batch_id, []), - ) + min_num = scheduled_num + elif min_num > scheduled_num: + self.logger.warning( + f"Requested min_num {min_num} is greater than scheduled tasks {scheduled_num} at batch_id {batch_id}. Adjusting min_num to {scheduled_num}." ) + min_num = scheduled_num self.logger.debug(f"Waiting for {min_num} tasks to complete...") - + min_threshold_reached_time = None while time.time() - start_time <= timeout: completed_count = len(self.completed_tasks.get(batch_id, [])) if completed_count >= min_num: - break + min_threshold_reached_time = min_threshold_reached_time or time.time() + if (completed_count >= scheduled_num) or ( + time.time() - min_threshold_reached_time + >= self.config.explorer.over_rollout.wait_after_min + ): + break await asyncio.sleep(0.1) if time.time() - start_time > timeout: - self.logger.error(f"Timed out waiting for tasks to complete after {timeout} seconds") + self.logger.error( + f"Timed out waiting for tasks at batch {batch_id} to complete after {timeout} seconds" + ) if clear_timeout_tasks: self._clear_timeout_tasks(batch_id=batch_id) + runners_to_restart = [] for runner_id, task in list(self.busy_runners.items()): if task.batch_id == batch_id: - await self._restart_runner(runner_id) + runners_to_restart.append(runner_id) + asyncio.gather( + *[self._restart_runner(runner_id) for runner_id in runners_to_restart] + ) statuses = [] experiences = [] completed_queue = self.completed_tasks.get(batch_id, deque()) - for _ in range(min_num): - if completed_queue: - status, exps = completed_queue.pop() - statuses.append(status) - if isinstance(exps, list): - experiences.extend(exps) - else: - experiences.append(exps) + while completed_queue: + status, exps = completed_queue.pop() + statuses.append(status) + if isinstance(exps, list): + experiences.extend(exps) + else: + experiences.append(exps) if batch_id in self.completed_tasks and not self.completed_tasks[batch_id]: del self.completed_tasks[batch_id] @@ -446,8 +520,8 @@ async def wait_all( if clear_timeout_tasks: for batch_id in self.pending_tasks.keys() | self.running_tasks.keys(): self._clear_timeout_tasks(batch_id) - busy_runner_ids = list(self.busy_runners.keys()) - for runner_id in busy_runner_ids: - await self._restart_runner(runner_id) + asyncio.gather( + *[self._restart_runner(runner_id) for runner_id in self.busy_runners.keys()] + ) raise TimeoutError(error_msg) diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py index 5c6a3933d4..85af23aa1b 100644 --- a/trinity/explorer/workflow_runner.py +++ b/trinity/explorer/workflow_runner.py @@ -5,7 +5,7 @@ import traceback from collections import defaultdict from dataclasses import dataclass -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple from trinity.buffer import get_buffer_reader from trinity.common.config import Config @@ -21,10 +21,30 @@ class Status: """Status of the task running result.""" ok: bool - metric: dict[str, float] + metrics: List[Dict[str, float]] + # A list of metric dictionaries, where each dictionary is from a single run. message: Optional[str] = None +def calculate_run_level_metrics(experiences: List[Experience]) -> Dict[str, float]: + """Calculate metrics from experiences. + + For non-repeatable workflows, this function will average the metrics from experiences + generated by each run, which is equivalent to calculating run level metrics. + + For repeatable workflows, please do not use this function. + """ + run_level_metrics: Dict[str, List[float]] = defaultdict(list) + for exp in experiences: + if exp.metrics: + for k, v in exp.metrics.items(): + run_level_metrics[k].append(v) + averaged_metrics: Dict[str, float] = {} + for key, values in run_level_metrics.items(): + averaged_metrics[key] = sum(values) / len(values) + return averaged_metrics + + class WorkflowRunner: """A Ray remote actor to run the workflow and generate experiences.""" @@ -96,22 +116,35 @@ async def _run_workflow(self, workflow_instance: Workflow) -> List[Experience]: exps = workflow_instance.run() return exps - async def _run_task(self, task: Task, repeat_times: int, run_id_base: int) -> List[Experience]: + async def _run_task( + self, task: Task, repeat_times: int, run_id_base: int + ) -> Tuple[List[Experience], List[Dict]]: """Init workflow from the task and run it.""" self._create_workflow_instance(task) if self.workflow_instance.repeatable: self.workflow_instance.set_repeat_times(repeat_times, run_id_base) + st = time.time() exps = await self._run_workflow(self.workflow_instance) + task_execution_time = time.time() - st + # repeatable workflow cannot calculate run level metrics, we use experience level metrics directly + run_metrics = [exp.metrics for exp in exps if exp.metrics] + for metric in run_metrics: + metric["time/task_execution"] = task_execution_time else: exps = [] + run_metrics = [] for i in range(repeat_times): + st = time.time() new_exps = await self._run_workflow(self.workflow_instance) + run_metric = calculate_run_level_metrics(new_exps) + run_metric["time/task_execution"] = time.time() - st + run_metrics.append(run_metric) for exp in new_exps: exp.eid.run = run_id_base + i exps.extend(new_exps) if i < repeat_times - 1: self._create_workflow_instance(task) - return exps + return exps, run_metrics async def run_task( self, @@ -123,12 +156,11 @@ async def run_task( # TODO: avoid sending the experiences back to the scheduler to reduce the communication overhead try: st = time.time() - exps = await self._run_task(task, repeat_times, run_id_base) + exps, metrics = await self._run_task(task, repeat_times, run_id_base) assert exps is not None and len(exps) > 0, "An empty experience is generated" - metrics: dict[str, List[float]] = defaultdict(list) model_version = await self.model_wrapper.model_version_async # set eid for each experience - for i, exp in enumerate(exps): + for exp in exps: exp.eid.batch = task.batch_id # keep exp.eid.task if it has been set before (e.g., in workflow) if exp.eid.task == "": # "" is the default value @@ -141,25 +173,20 @@ async def run_task( if not hasattr(exp, "metrics") or exp.metrics is None: exp.metrics = {} - for k, v in exp.metrics.items(): - metrics[k].append(v) - # We get the average of metrics into the state - metric = {} - metric["time_per_task"] = time.time() - st - if metrics: - for k, v in metrics.items(): - metric[k] = sum(v) / len(v) # type: ignore if task.is_eval: # If the task is an evaluation task, we do not record the experiences to the buffer - return Status(True, metric=metric), [] + return Status(True, metrics=metrics), [] else: - return Status(True, metric=metric), exps + return Status(True, metrics=metrics), exps except Exception as e: error_trace_back = traceback.format_exc() self.logger.error(f"WorkflowRunner run task error: {e}\nTraceback:\n{error_trace_back}") - return Status(False, metric={"time_per_task": time.time() - st}, message=str(e)), [] + return ( + Status(False, metrics=[{"time/task_execution": time.time() - st}], message=str(e)), + [], + ) class DebugWorkflowRunner(WorkflowRunner): @@ -186,7 +213,7 @@ async def debug(self) -> None: with VizTracer(output_file=self.output_file): status, exps = await self.run_task(task, task.repeat_times, 0) if status.ok: - print(f"Task {task.task_id} completed successfully with metrics:\n{status.metric}") + print(f"Task {task.task_id} completed successfully with metrics:\n{status.metrics}") for exp in exps: print(f"Generated experience:\n{exp}") else: