Merged
Changes from 3 commits
2 changes: 0 additions & 2 deletions trinity/common/models/vllm_model.py
@@ -62,8 +62,6 @@ def __init__(
)
self.default_sampling_params = vllm.SamplingParams(
n=1,
temperature=0.0,
max_tokens=config.max_response_tokens,
min_tokens=config.min_response_tokens,
truncate_prompt_tokens=config.max_prompt_tokens,
skip_special_tokens=True,
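Dropping the hard-coded temperature=0.0 means generation is no longer forced to greedy decoding by default; the effective temperature is now whatever the caller's sampling parameters say (vLLM's own default is 1.0). A minimal sketch of the difference, using plain vLLM SamplingParams:

import vllm

# Before: temperature pinned to 0.0 in the defaults, i.e. greedy decoding.
greedy = vllm.SamplingParams(n=1, temperature=0.0)

# After: temperature left unset, so vLLM's default (1.0) applies unless the
# caller overrides it per request.
sampled = vllm.SamplingParams(n=1)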
18 changes: 10 additions & 8 deletions trinity/common/workflows/workflow.py
@@ -190,13 +190,7 @@ def process_messages_to_experience(self, messages, reward, info={}) -> Experience:
return experience


@WORKFLOWS.register_module("simple_workflow")
class SimpleWorkflow(Workflow):
"""A workflow for simple single-round task."""

can_reset: bool = True
can_repeat: bool = True

class BaseSimpleWorkflow(Workflow):
def __init__(
self,
*,
@@ -246,6 +240,14 @@ def format_messages(self):
messages.append({"role": "assistant", "content": self.reply_prefix})
return messages


@WORKFLOWS.register_module("simple_workflow")
class SimpleWorkflow(BaseSimpleWorkflow):
"""A workflow for simple single-round task."""

can_reset: bool = True
can_repeat: bool = True

def run(self) -> List[Experience]:
# TODO: Optimize the generate function
messages = self.format_messages()
@@ -272,7 +274,7 @@ def run(self) -> List[Experience]:


@WORKFLOWS.register_module("async_simple_workflow")
class AsyncSimpleWorkflow(Workflow):
class AsyncSimpleWorkflow(BaseSimpleWorkflow):
is_async: bool = True

async def run_async(self) -> List[Experience]:
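The net effect of this refactor: the shared constructor and format_messages move into BaseSimpleWorkflow, and the sync and async variants both inherit from it instead of duplicating that setup. A rough sketch of the resulting hierarchy (method bodies elided; attributes as shown in the diff):

class BaseSimpleWorkflow(Workflow):
    # shared __init__ and prompt construction
    def format_messages(self): ...

@WORKFLOWS.register_module("simple_workflow")
class SimpleWorkflow(BaseSimpleWorkflow):
    """A workflow for simple single-round task."""
    can_reset: bool = True
    can_repeat: bool = True

    def run(self) -> List[Experience]: ...

@WORKFLOWS.register_module("async_simple_workflow")
class AsyncSimpleWorkflow(BaseSimpleWorkflow):
    is_async: bool = True

    async def run_async(self) -> List[Experience]: ...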
9 changes: 3 additions & 6 deletions trinity/explorer/explorer.py
@@ -26,6 +26,7 @@
from trinity.common.models import create_inference_models
from trinity.common.models.utils import get_checkpoint_dir_with_step_num
from trinity.explorer.scheduler import Scheduler
from trinity.explorer.workflow_runner import group_metrics
from trinity.manager.state_manager import StateManager
from trinity.manager.synchronizer import Synchronizer
from trinity.utils.annotations import Experimental
@@ -362,7 +363,7 @@ async def _finish_explore_step(self, step: int, model_version: int) -> None:
self.taskset.update(pipeline_metrics)
metric.update(pipeline_metrics)
if statuses:
metric.update(gather_metrics([status.metric for status in statuses], "rollout"))
metric.update(gather_metrics(group_metrics(statuses), "rollout"))
self.monitor.log(metric, step=step)

async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eval") -> None:
@@ -376,11 +377,7 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eval") -> None:
return
self.pending_eval_tasks.popleft()
eval_results, _ = await self.scheduler.get_results(f"{step}/{eval_task_name}")
metric.update(
gather_metrics(
[status.metric for status in eval_results], f"{prefix}/{eval_task_name}"
)
)
metric.update(gather_metrics(group_metrics(eval_results), f"{prefix}/{eval_task_name}"))
if self.eval_start_time is not None:
metric.update({"time/eval": time.time() - self.eval_start_time})
self.eval_start_time = None
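Both call sites now route statuses through group_metrics before gather_metrics, so repeated rollouts of the same task are first merged into one per-task metric dict and only then aggregated across tasks. The new path, schematically (statuses as returned by the scheduler):

statuses, _ = await self.scheduler.get_results(...)   # one Status per rollout
per_task = group_metrics(statuses)                    # merge rollouts sharing a task_id
metric.update(gather_metrics(per_task, "rollout"))    # cross-task aggregation, as before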
24 changes: 11 additions & 13 deletions trinity/explorer/scheduler.py
@@ -381,23 +381,21 @@ async def get_results(
statuses = []
experiences = []
completed_queue = self.completed_tasks.get(batch_id, deque())
for _ in range(min_num):
if completed_queue:
status, exps = completed_queue.pop()
statuses.append(status)
if isinstance(exps, list):
experiences.extend(exps)
else:
experiences.append(exps)

if batch_id in self.completed_tasks and not self.completed_tasks[batch_id]:
del self.completed_tasks[batch_id]

completed_count = len(statuses)
completed_count = len(completed_queue)
if completed_count < min_num:
self.logger.warning(
f"Timeout reached, only {completed_count}/{min_num} tasks completed"
)
while completed_queue:
status, exps = completed_queue.pop()
statuses.append(status)
if isinstance(exps, list):
experiences.extend(exps)
else:
experiences.append(exps)

if batch_id in self.completed_tasks and not self.completed_tasks[batch_id]:
del self.completed_tasks[batch_id]

return statuses, experiences

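The old loop popped at most min_num results, so any extra completions sitting in the queue were stranded until a later call; the new version counts completions first (for the timeout warning), then drains everything. A toy illustration of the two retrieval semantics (plain deques, not the real Scheduler):

from collections import deque

min_num = 2
q = deque([1, 2, 3])                      # three results already completed

old = [q.pop() for _ in range(min_num)]   # old: exactly min_num popped -> [3, 2]
leftover = list(q)                        # [1] stays behind

q = deque([1, 2, 3])
if len(q) < min_num:                      # new: count first, warn on shortfall
    print("Timeout reached ...")
new = []
while q:                                  # then drain everything -> [3, 2, 1]
    new.append(q.pop())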
49 changes: 39 additions & 10 deletions trinity/explorer/workflow_runner.py
@@ -5,7 +5,7 @@
import traceback
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union

from trinity.buffer import get_buffer_reader
from trinity.common.config import Config
@@ -21,8 +21,32 @@ class Status:
"""Status of the task running result."""

ok: bool
metric: dict[str, float]
metric: dict[str, Union[float, List[float]]]
message: Optional[str] = None
task_id: Union[int, str] = ""


def group_metrics(statuses: List[Status]):
task2metrics = {}
for status in statuses:
task_id = status.task_id
metric = status.metric
if task_id not in task2metrics:
task2metrics[task_id] = metric
else:
for k, v in metric.items():
task2metrics[task_id][k] += v # type: ignore

metric_list = []
for metrics in task2metrics.values():
agg_metrics = {}
for k, v in metrics.items():
if isinstance(v, list):
agg_metrics[k] = sum(v) / len(v)
else:
agg_metrics[k] = v
metric_list.append(agg_metrics)
return metric_list


class WorkflowRunner:
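Two details of group_metrics worth noting: list-valued metrics from repeated rollouts are concatenated and then averaged, while bare floats (such as time_per_task) are summed across repeats and passed through unaveraged; the merge also writes into the first status's metric dict in place. A quick example with made-up values:

s1 = Status(ok=True, metric={"reward": [1.0, 0.0], "time_per_task": 2.0}, task_id="t1")
s2 = Status(ok=True, metric={"reward": [1.0], "time_per_task": 3.0}, task_id="t1")
s3 = Status(ok=True, metric={"reward": [0.5], "time_per_task": 1.0}, task_id="t2")

group_metrics([s1, s2, s3])
# -> [{"reward": 0.6666..., "time_per_task": 5.0},  # "t1": [1.0, 0.0, 1.0] averaged, floats summed
#     {"reward": 0.5, "time_per_task": 1.0}]        # "t2": single rollout passes through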
@@ -144,22 +168,27 @@ async def run_task(
for k, v in exp.metrics.items():
metrics[k].append(v)
# We get the average of metrics into the state
metric = {}
metric["time_per_task"] = time.time() - st
if metrics:
for k, v in metrics.items():
metric[k] = sum(v) / len(v) # type: ignore
metric: dict[str, Union[float, List[float]]] = {"time_per_task": time.time() - st}
metric.update(metrics)

if task.is_eval:
# If the task is an evaluation task, we do not record the experiences to the buffer
return Status(True, metric=metric), []
return Status(True, metric=metric, task_id=task.task_id), []
else:
return Status(True, metric=metric), exps
return Status(True, metric=metric, task_id=task.task_id), exps

except Exception as e:
error_trace_back = traceback.format_exc()
self.logger.error(f"WorkflowRunner run task error: {e}\nTraceback:\n{error_trace_back}")
return Status(False, metric={"time_per_task": time.time() - st}, message=str(e)), []
return (
Status(
False,
metric={"time_per_task": time.time() - st},
message=str(e),
task_id=task.task_id,
),
[],
)


class DebugWorkflowRunner(WorkflowRunner):
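With these changes, run_task no longer averages metrics inside the runner: per-experience values stay as raw lists in Status.metric (hence the Union[float, List[float]] type), and every Status now carries its task_id so the explorer can group repeats of the same task afterwards. A sketch of a successful status under this scheme (the metric names and id are hypothetical):

status = Status(
    ok=True,
    metric={"time_per_task": 4.2, "reward": [0.0, 1.0, 1.0]},  # raw per-experience values
    task_id="taskset/123",  # hypothetical task id
)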