From c929ec8feefefc5f81bffc4f3a9c720f86fdedfe Mon Sep 17 00:00:00 2001
From: chenyushuo <297086016@qq.com>
Date: Mon, 10 Nov 2025 09:59:51 +0800
Subject: [PATCH 1/6] fix `default_sampling_params` and `simple_workflow`

---
 trinity/common/models/vllm_model.py  |  2 --
 trinity/common/workflows/workflow.py | 18 ++++++++++--------
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py
index 0133586f7c..9bb8b23348 100644
--- a/trinity/common/models/vllm_model.py
+++ b/trinity/common/models/vllm_model.py
@@ -62,8 +62,6 @@ def __init__(
         )
         self.default_sampling_params = vllm.SamplingParams(
             n=1,
-            temperature=0.0,
-            max_tokens=config.max_response_tokens,
             min_tokens=config.min_response_tokens,
             truncate_prompt_tokens=config.max_prompt_tokens,
             skip_special_tokens=True,
diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py
index 8a493e161f..91716e1688 100644
--- a/trinity/common/workflows/workflow.py
+++ b/trinity/common/workflows/workflow.py
@@ -190,13 +190,7 @@ def process_messages_to_experience(self, messages, reward, info={}) -> Experienc
         return experience


-@WORKFLOWS.register_module("simple_workflow")
-class SimpleWorkflow(Workflow):
-    """A workflow for simple single-round task."""
-
-    can_reset: bool = True
-    can_repeat: bool = True
-
+class BaseSimpleWorkflow(Workflow):
     def __init__(
         self,
         *,
@@ -246,6 +240,14 @@ def format_messages(self):
             messages.append({"role": "assistant", "content": self.reply_prefix})
         return messages

+
+@WORKFLOWS.register_module("simple_workflow")
+class SimpleWorkflow(BaseSimpleWorkflow):
+    """A workflow for simple single-round task."""
+
+    can_reset: bool = True
+    can_repeat: bool = True
+
     def run(self) -> List[Experience]:
         # TODO: Optimize the generate function
         messages = self.format_messages()
@@ -272,7 +274,7 @@ def run(self) -> List[Experience]:


 @WORKFLOWS.register_module("async_simple_workflow")
-class AsyncSimpleWorkflow(Workflow):
+class AsyncSimpleWorkflow(BaseSimpleWorkflow):
     is_async: bool = True

     async def run_async(self) -> List[Experience]:

From c36a09b0ca108dcbd75c855958a7aeea396cd717 Mon Sep 17 00:00:00 2001
From: chenyushuo <297086016@qq.com>
Date: Mon, 10 Nov 2025 14:31:08 +0800
Subject: [PATCH 2/6] fix metric calculation to aggregate by task_id

---
 trinity/explorer/explorer.py        |  9 ++----
 trinity/explorer/scheduler.py       | 24 ++++++++--------
 trinity/explorer/workflow_runner.py | 43 ++++++++++++++++++++++-------
 3 files changed, 47 insertions(+), 29 deletions(-)

diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index 038c1dd5f9..1371e0d6f9 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -26,6 +26,7 @@
 from trinity.common.models import create_inference_models
 from trinity.common.models.utils import get_checkpoint_dir_with_step_num
 from trinity.explorer.scheduler import Scheduler
+from trinity.explorer.workflow_runner import group_metrics
 from trinity.manager.state_manager import StateManager
 from trinity.manager.synchronizer import Synchronizer
 from trinity.utils.annotations import Experimental
@@ -362,7 +363,7 @@ async def _finish_explore_step(self, step: int, model_version: int) -> None:
             self.taskset.update(pipeline_metrics)
             metric.update(pipeline_metrics)
         if statuses:
-            metric.update(gather_metrics([status.metric for status in statuses], "rollout"))
+            metric.update(gather_metrics(group_metrics(statuses), "rollout"))
         self.monitor.log(metric, step=step)

     async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eval") -> None:
@@ -376,11 +377,7 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva
             return
         self.pending_eval_tasks.popleft()
         eval_results, _ = await self.scheduler.get_results(f"{step}/{eval_task_name}")
-        metric.update(
-            gather_metrics(
-                [status.metric for status in eval_results], f"{prefix}/{eval_task_name}"
-            )
-        )
+        metric.update(gather_metrics(group_metrics(eval_results), f"{prefix}/{eval_task_name}"))
         if self.eval_start_time is not None:
             metric.update({"time/eval": time.time() - self.eval_start_time})
             self.eval_start_time = None
diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py
index ae17649c86..04cf6b59ba 100644
--- a/trinity/explorer/scheduler.py
+++ b/trinity/explorer/scheduler.py
@@ -381,23 +381,21 @@ async def get_results(
         statuses = []
         experiences = []
         completed_queue = self.completed_tasks.get(batch_id, deque())
-        for _ in range(min_num):
-            if completed_queue:
-                status, exps = completed_queue.pop()
-                statuses.append(status)
-                if isinstance(exps, list):
-                    experiences.extend(exps)
-                else:
-                    experiences.append(exps)
-
-        if batch_id in self.completed_tasks and not self.completed_tasks[batch_id]:
-            del self.completed_tasks[batch_id]
-
-        completed_count = len(statuses)
+        completed_count = len(completed_queue)
         if completed_count < min_num:
             self.logger.warning(
                 f"Timeout reached, only {completed_count}/{min_num} tasks completed"
             )
+        while completed_queue:
+            status, exps = completed_queue.pop()
+            statuses.append(status)
+            if isinstance(exps, list):
+                experiences.extend(exps)
+            else:
+                experiences.append(exps)
+
+        if batch_id in self.completed_tasks and not self.completed_tasks[batch_id]:
+            del self.completed_tasks[batch_id]

         return statuses, experiences

diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py
index 5c6a3933d4..7312299682 100644
--- a/trinity/explorer/workflow_runner.py
+++ b/trinity/explorer/workflow_runner.py
@@ -5,7 +5,7 @@
 import traceback
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union

 from trinity.buffer import get_buffer_reader
 from trinity.common.config import Config
@@ -21,8 +21,26 @@ class Status:
     """Status of the task running result."""

     ok: bool
-    metric: dict[str, float]
+    metric: dict[str, Union[float, List[float]]]
     message: Optional[str] = None
+    task_id: Union[int, str] = ""
+
+
+def group_metrics(statuses: List[Status]):
+    task2metrics = {}
+    for status in statuses:
+        task_id = status.task_id
+        metric = status.metric
+        if task_id not in task2metrics:
+            task2metrics[task_id] = metric
+        else:
+            for k, v in metric.items():
+                task2metrics[task_id][k] += v  # type: ignore
+    metric_list = [
+        {k: sum(v) / len(v) if isinstance(v, list) else v for k, v in metrics.items()}
+        for metrics in task2metrics.values()
+    ]
+    return metric_list


 class WorkflowRunner:
@@ -144,22 +162,27 @@ async def run_task(
                 for k, v in exp.metrics.items():
                     metrics[k].append(v)
             # We get the average of metrics into the state
-            metric = {}
-            metric["time_per_task"] = time.time() - st
-            if metrics:
-                for k, v in metrics.items():
-                    metric[k] = sum(v) / len(v)  # type: ignore
+            metric: dict[str, Union[float, List[float]]] = {"time_per_task": time.time() - st}
+            metric.update(metrics)
             if task.is_eval:
                 # If the task is an evaluation task, we do not record the experiences to the buffer
-                return Status(True, metric=metric), []
+                return Status(True, metric=metric, task_id=task.task_id), []
             else:
-                return Status(True, metric=metric), exps
+                return Status(True, metric=metric, task_id=task.task_id), exps
         except Exception as e:
             error_trace_back = traceback.format_exc()
             self.logger.error(
                 f"WorkflowRunner run task error: {e}\nTraceback:\n{error_trace_back}"
             )
-            return Status(False, metric={"time_per_task": time.time() - st}, message=str(e)), []
+            return (
+                Status(
+                    False,
+                    metric={"time_per_task": time.time() - st},
+                    message=str(e),
+                    task_id=task.task_id,
+                ),
+                [],
+            )


 class DebugWorkflowRunner(WorkflowRunner):
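Note on the change above: group_metrics first merges the metric lists of every Status that shares a task_id and only then averages, so a task whose repeats are split across several runner calls is counted once. The following standalone Python sketch is not part of the patch; FakeStatus and the sample numbers are made up for illustration:

from dataclasses import dataclass, field
from typing import Dict, List, Union

@dataclass
class FakeStatus:  # simplified stand-in for the Status dataclass above
    task_id: str
    metric: Dict[str, Union[float, List[float]]] = field(default_factory=dict)

def group_metrics(statuses):
    # Merge the metric lists of statuses that belong to the same task ...
    task2metrics = {}
    for status in statuses:
        if status.task_id not in task2metrics:
            task2metrics[status.task_id] = status.metric
        else:
            for k, v in status.metric.items():
                task2metrics[status.task_id][k] += v
    # ... then average each task's list into a single value.
    return [
        {k: sum(v) / len(v) if isinstance(v, list) else v for k, v in m.items()}
        for m in task2metrics.values()
    ]

# Four repeats of one task, split 3 + 1 across two runner calls:
statuses = [
    FakeStatus("task-0", {"reward": [1.0, 0.0, 1.0]}),
    FakeStatus("task-0", {"reward": [0.0]}),
]
print(group_metrics(statuses))  # -> [{'reward': 0.5}]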
From 79665e04d9538fb097408a7eb63b45a569963587 Mon Sep 17 00:00:00 2001
From: chenyushuo <297086016@qq.com>
Date: Mon, 10 Nov 2025 15:55:25 +0800
Subject: [PATCH 3/6] apply suggestions

---
 trinity/explorer/workflow_runner.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py
index 7312299682..4bba3995be 100644
--- a/trinity/explorer/workflow_runner.py
+++ b/trinity/explorer/workflow_runner.py
@@ -36,10 +36,16 @@ def group_metrics(statuses: List[Status]):
         else:
             for k, v in metric.items():
                 task2metrics[task_id][k] += v  # type: ignore
-    metric_list = [
-        {k: sum(v) / len(v) if isinstance(v, list) else v for k, v in metrics.items()}
-        for metrics in task2metrics.values()
-    ]
+
+    metric_list = []
+    for metrics in task2metrics.values():
+        agg_metrics = {}
+        for k, v in metrics.items():
+            if isinstance(v, list):
+                agg_metrics[k] = sum(v) / len(v)
+            else:
+                agg_metrics[k] = v
+        metric_list.append(agg_metrics)
     return metric_list


From 1e49cfb14789f9c014ab894e4a35995349308eee Mon Sep 17 00:00:00 2001
From: chenyushuo <297086016@qq.com>
Date: Tue, 11 Nov 2025 10:45:50 +0800
Subject: [PATCH 4/6] add unittest for task aggregated metric calculation

---
 tests/explorer/explorer_test.py    | 19 ++++++++++++++++++-
 tests/utils/plugins/my_workflow.py | 10 ++++++++++
 2 files changed, 28 insertions(+), 1 deletion(-)

diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py
index 91222e9884..5035fec4bb 100644
--- a/tests/explorer/explorer_test.py
+++ b/tests/explorer/explorer_test.py
@@ -24,9 +24,10 @@
 from trinity.buffer import get_buffer_reader
 from trinity.cli.launcher import explore, run_stage
 from trinity.common.config import ExperienceBufferConfig, InferenceModelConfig
-from trinity.common.constants import StorageType
+from trinity.common.constants import PLUGIN_DIRS_ENV_VAR, StorageType
 from trinity.explorer.explorer import Explorer
 from trinity.manager.state_manager import StateManager
+from trinity.utils.plugin_loader import load_plugins


 class BaseExplorerCase(RayUnittestBase):
@@ -45,6 +46,22 @@ def setUp(self):
         self.config.explorer.eval_interval = 4


+class TestExplorerCountdownMaxRepeatTimes(BaseExplorerCase):
+    def test_explorer(self):
+        self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
+        self.config.buffer.explorer_input.taskset.default_workflow_type = "custom_workflow"
+        self.config.algorithm.repeat_times = 4
+        self.config.explorer.max_repeat_times_per_runner = 3
+        self.config.name = f"explore-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}"
+        self.config.check_and_update()
+        os.environ[PLUGIN_DIRS_ENV_VAR] = os.path.join("tests", "utils", "plugins")
+        load_plugins()
+        explore(self.config)
+        parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
+        custom_metric_mean = parser.metric_values("rollout/custom_metric/mean")
+        self.assertEqual(custom_metric_mean, [0.75] * 8)
+
+
 class TestExplorerCountdownEval(BaseExplorerCase):
     def test_explorer(self):
         self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
diff --git a/tests/utils/plugins/my_workflow.py b/tests/utils/plugins/my_workflow.py
index 471b2371cf..510b4ee529 100644
--- a/tests/utils/plugins/my_workflow.py
+++ b/tests/utils/plugins/my_workflow.py
@@ -1,6 +1,7 @@
 from typing import List

 from trinity.common.workflows import WORKFLOWS, Workflow
+from trinity.common.workflows.workflow import MathWorkflow


 @WORKFLOWS.register_module("my_workflow")
@@ -17,3 +18,12 @@ def set_repeat_times(self, repeat_times, run_id_base):

     def run(self) -> List:
         return ["Hello world", "Hi"]
+
+
+@WORKFLOWS.register_module("custom_workflow")
+class CustomWorkflow(MathWorkflow):
+    def run(self):
+        responses = super().run()
+        for i, response in enumerate(responses):
+            response.metrics["custom_metric"] = i
+        return responses
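Context on the expected value in the new test: with repeat_times = 4 and max_repeat_times_per_runner = 3, the four repeats of each task should be split 3 + 1 across two runner calls, so CustomWorkflow assigns custom_metric values of 0, 1, 2 and 0 via enumerate(). Grouping by task_id concatenates these lists before averaging. The check below is illustrative only (the list length of 8 in the assertion follows from the unit-test step configuration, not from this arithmetic):

# Why the test expects a per-step mean of 0.75.
values = [0, 1, 2] + [0]  # custom_metric from the 3-repeat and 1-repeat runner calls
assert sum(values) / len(values) == 0.75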
From 47422fed1e2a492fa2afee32baa0316aabf4396a Mon Sep 17 00:00:00 2001
From: chenyushuo <297086016@qq.com>
Date: Tue, 11 Nov 2025 11:38:09 +0800
Subject: [PATCH 5/6] bug fix for unittest

---
 trinity/common/models/vllm_model.py | 4 +++-
 trinity/explorer/workflow_runner.py | 2 +-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py
index 9bb8b23348..eeb90cec12 100644
--- a/trinity/common/models/vllm_model.py
+++ b/trinity/common/models/vllm_model.py
@@ -62,12 +62,14 @@ def __init__(
         )
         self.default_sampling_params = vllm.SamplingParams(
             n=1,
+            temperature=config.temperature,
+            max_tokens=config.max_response_tokens,
             min_tokens=config.min_response_tokens,
             truncate_prompt_tokens=config.max_prompt_tokens,
             skip_special_tokens=True,
             include_stop_str_in_output=False,
             output_kind=RequestOutputKind.FINAL_ONLY,
-            logprobs=0,
+            logprobs=config.logprobs,
             ignore_eos=config.ignore_eos,
         )
         self.enable_thinking = config.enable_thinking
diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py
index 4bba3995be..574596028c 100644
--- a/trinity/explorer/workflow_runner.py
+++ b/trinity/explorer/workflow_runner.py
@@ -167,7 +167,7 @@ async def run_task(
                     exp.metrics = {}
                 for k, v in exp.metrics.items():
                     metrics[k].append(v)
-            # We get the average of metrics into the state
+
             metric: dict[str, Union[float, List[float]]] = {"time_per_task": time.time() - st}
             metric.update(metrics)
             if task.is_eval:
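The vllm_model.py change above makes the default sampling parameters follow the rollout config instead of hard-coded greedy decoding. A minimal sketch of building the same defaults outside the class is shown below; the DummyConfig values are illustrative, and RequestOutputKind is assumed to be importable from vllm.sampling_params as in recent vLLM releases:

import vllm
from vllm.sampling_params import RequestOutputKind

class DummyConfig:  # illustrative stand-in for the model config
    temperature = 1.0
    max_response_tokens = 1024
    min_response_tokens = 1
    max_prompt_tokens = 2048
    logprobs = 0
    ignore_eos = False

config = DummyConfig()
default_sampling_params = vllm.SamplingParams(
    n=1,
    temperature=config.temperature,         # was hard-coded to 0.0 before this series
    max_tokens=config.max_response_tokens,  # restored after patch 1 dropped it
    min_tokens=config.min_response_tokens,
    truncate_prompt_tokens=config.max_prompt_tokens,
    skip_special_tokens=True,
    include_stop_str_in_output=False,
    output_kind=RequestOutputKind.FINAL_ONLY,
    logprobs=config.logprobs,
    ignore_eos=config.ignore_eos,
)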
From f6554fe6532b9552b1d290c616cd6b627394b5f4 Mon Sep 17 00:00:00 2001
From: chenyushuo <297086016@qq.com>
Date: Wed, 12 Nov 2025 19:41:13 +0800
Subject: [PATCH 6/6] revert metric modification

---
 tests/explorer/explorer_test.py     | 19 +----------
 tests/utils/plugins/my_workflow.py  | 10 ------
 trinity/explorer/explorer.py        |  9 +++--
 trinity/explorer/scheduler.py       | 24 +++++++-------
 trinity/explorer/workflow_runner.py | 51 +++++++----------------------
 5 files changed, 31 insertions(+), 82 deletions(-)

diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py
index 5035fec4bb..91222e9884 100644
--- a/tests/explorer/explorer_test.py
+++ b/tests/explorer/explorer_test.py
@@ -24,10 +24,9 @@
 from trinity.buffer import get_buffer_reader
 from trinity.cli.launcher import explore, run_stage
 from trinity.common.config import ExperienceBufferConfig, InferenceModelConfig
-from trinity.common.constants import PLUGIN_DIRS_ENV_VAR, StorageType
+from trinity.common.constants import StorageType
 from trinity.explorer.explorer import Explorer
 from trinity.manager.state_manager import StateManager
-from trinity.utils.plugin_loader import load_plugins


 class BaseExplorerCase(RayUnittestBase):
@@ -46,22 +45,6 @@ def setUp(self):
         self.config.explorer.eval_interval = 4


-class TestExplorerCountdownMaxRepeatTimes(BaseExplorerCase):
-    def test_explorer(self):
-        self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
-        self.config.buffer.explorer_input.taskset.default_workflow_type = "custom_workflow"
-        self.config.algorithm.repeat_times = 4
-        self.config.explorer.max_repeat_times_per_runner = 3
-        self.config.name = f"explore-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}"
-        self.config.check_and_update()
-        os.environ[PLUGIN_DIRS_ENV_VAR] = os.path.join("tests", "utils", "plugins")
-        load_plugins()
-        explore(self.config)
-        parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
-        custom_metric_mean = parser.metric_values("rollout/custom_metric/mean")
-        self.assertEqual(custom_metric_mean, [0.75] * 8)
-
-
 class TestExplorerCountdownEval(BaseExplorerCase):
     def test_explorer(self):
         self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
diff --git a/tests/utils/plugins/my_workflow.py b/tests/utils/plugins/my_workflow.py
index 510b4ee529..471b2371cf 100644
--- a/tests/utils/plugins/my_workflow.py
+++ b/tests/utils/plugins/my_workflow.py
@@ -1,7 +1,6 @@
 from typing import List

 from trinity.common.workflows import WORKFLOWS, Workflow
-from trinity.common.workflows.workflow import MathWorkflow


 @WORKFLOWS.register_module("my_workflow")
@@ -18,12 +17,3 @@ def set_repeat_times(self, repeat_times, run_id_base):

     def run(self) -> List:
         return ["Hello world", "Hi"]
-
-
-@WORKFLOWS.register_module("custom_workflow")
-class CustomWorkflow(MathWorkflow):
-    def run(self):
-        responses = super().run()
-        for i, response in enumerate(responses):
-            response.metrics["custom_metric"] = i
-        return responses
diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index 1371e0d6f9..038c1dd5f9 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -26,7 +26,6 @@
 from trinity.common.models import create_inference_models
 from trinity.common.models.utils import get_checkpoint_dir_with_step_num
 from trinity.explorer.scheduler import Scheduler
-from trinity.explorer.workflow_runner import group_metrics
 from trinity.manager.state_manager import StateManager
 from trinity.manager.synchronizer import Synchronizer
 from trinity.utils.annotations import Experimental
@@ -363,7 +362,7 @@ async def _finish_explore_step(self, step: int, model_version: int) -> None:
             self.taskset.update(pipeline_metrics)
             metric.update(pipeline_metrics)
         if statuses:
-            metric.update(gather_metrics(group_metrics(statuses), "rollout"))
+            metric.update(gather_metrics([status.metric for status in statuses], "rollout"))
         self.monitor.log(metric, step=step)

     async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eval") -> None:
@@ -377,7 +376,11 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva
             return
         self.pending_eval_tasks.popleft()
         eval_results, _ = await self.scheduler.get_results(f"{step}/{eval_task_name}")
-        metric.update(gather_metrics(group_metrics(eval_results), f"{prefix}/{eval_task_name}"))
+        metric.update(
+            gather_metrics(
+                [status.metric for status in eval_results], f"{prefix}/{eval_task_name}"
+            )
+        )
         if self.eval_start_time is not None:
             metric.update({"time/eval": time.time() - self.eval_start_time})
             self.eval_start_time = None
diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py
index 04cf6b59ba..ae17649c86 100644
--- a/trinity/explorer/scheduler.py
+++ b/trinity/explorer/scheduler.py
@@ -381,21 +381,23 @@ async def get_results(
         statuses = []
         experiences = []
         completed_queue = self.completed_tasks.get(batch_id, deque())
-        completed_count = len(completed_queue)
+        for _ in range(min_num):
+            if completed_queue:
+                status, exps = completed_queue.pop()
+                statuses.append(status)
+                if isinstance(exps, list):
+                    experiences.extend(exps)
+                else:
+                    experiences.append(exps)
+
+        if batch_id in self.completed_tasks and not self.completed_tasks[batch_id]:
+            del self.completed_tasks[batch_id]
+
+        completed_count = len(statuses)
         if completed_count < min_num:
             self.logger.warning(
                 f"Timeout reached, only {completed_count}/{min_num} tasks completed"
             )
-        while completed_queue:
-            status, exps = completed_queue.pop()
-            statuses.append(status)
-            if isinstance(exps, list):
-                experiences.extend(exps)
-            else:
-                experiences.append(exps)
-
-        if batch_id in self.completed_tasks and not self.completed_tasks[batch_id]:
-            del self.completed_tasks[batch_id]

         return statuses, experiences

diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py
index 574596028c..5c6a3933d4 100644
--- a/trinity/explorer/workflow_runner.py
+++ b/trinity/explorer/workflow_runner.py
@@ -5,7 +5,7 @@
 import traceback
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple

 from trinity.buffer import get_buffer_reader
 from trinity.common.config import Config
@@ -21,32 +21,8 @@ class Status:
     """Status of the task running result."""

     ok: bool
-    metric: dict[str, Union[float, List[float]]]
+    metric: dict[str, float]
     message: Optional[str] = None
-    task_id: Union[int, str] = ""
-
-
-def group_metrics(statuses: List[Status]):
-    task2metrics = {}
-    for status in statuses:
-        task_id = status.task_id
-        metric = status.metric
-        if task_id not in task2metrics:
-            task2metrics[task_id] = metric
-        else:
-            for k, v in metric.items():
-                task2metrics[task_id][k] += v  # type: ignore
-
-    metric_list = []
-    for metrics in task2metrics.values():
-        agg_metrics = {}
-        for k, v in metrics.items():
-            if isinstance(v, list):
-                agg_metrics[k] = sum(v) / len(v)
-            else:
-                agg_metrics[k] = v
-        metric_list.append(agg_metrics)
-    return metric_list


 class WorkflowRunner:
@@ -167,28 +143,23 @@ async def run_task(
                     exp.metrics = {}
                 for k, v in exp.metrics.items():
                     metrics[k].append(v)
-
-            metric: dict[str, Union[float, List[float]]] = {"time_per_task": time.time() - st}
-            metric.update(metrics)
+            # We get the average of metrics into the state
+            metric = {}
+            metric["time_per_task"] = time.time() - st
+            if metrics:
+                for k, v in metrics.items():
+                    metric[k] = sum(v) / len(v)  # type: ignore
             if task.is_eval:
                 # If the task is an evaluation task, we do not record the experiences to the buffer
-                return Status(True, metric=metric, task_id=task.task_id), []
+                return Status(True, metric=metric), []
             else:
-                return Status(True, metric=metric, task_id=task.task_id), exps
+                return Status(True, metric=metric), exps
         except Exception as e:
             error_trace_back = traceback.format_exc()
             self.logger.error(
                 f"WorkflowRunner run task error: {e}\nTraceback:\n{error_trace_back}"
             )
-            return (
-                Status(
-                    False,
-                    metric={"time_per_task": time.time() - st},
-                    message=str(e),
-                    task_id=task.task_id,
-                ),
-                [],
-            )
+            return Status(False, metric={"time_per_task": time.time() - st}, message=str(e)), []


 class DebugWorkflowRunner(WorkflowRunner):