Commit c36a09b

fix metric calculation to aggregate by task_id

1 parent c929ec8
File tree

3 files changed: +47 −29 lines

trinity/explorer/explorer.py

Lines changed: 3 additions & 6 deletions
@@ -26,6 +26,7 @@
 from trinity.common.models import create_inference_models
 from trinity.common.models.utils import get_checkpoint_dir_with_step_num
 from trinity.explorer.scheduler import Scheduler
+from trinity.explorer.workflow_runner import group_metrics
 from trinity.manager.state_manager import StateManager
 from trinity.manager.synchronizer import Synchronizer
 from trinity.utils.annotations import Experimental
@@ -362,7 +363,7 @@ async def _finish_explore_step(self, step: int, model_version: int) -> None:
         self.taskset.update(pipeline_metrics)
         metric.update(pipeline_metrics)
         if statuses:
-            metric.update(gather_metrics([status.metric for status in statuses], "rollout"))
+            metric.update(gather_metrics(group_metrics(statuses), "rollout"))
         self.monitor.log(metric, step=step)

     async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eval") -> None:
@@ -376,11 +377,7 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eval") -> None:
             return
         self.pending_eval_tasks.popleft()
         eval_results, _ = await self.scheduler.get_results(f"{step}/{eval_task_name}")
-        metric.update(
-            gather_metrics(
-                [status.metric for status in eval_results], f"{prefix}/{eval_task_name}"
-            )
-        )
+        metric.update(gather_metrics(group_metrics(eval_results), f"{prefix}/{eval_task_name}"))
         if self.eval_start_time is not None:
             metric.update({"time/eval": time.time() - self.eval_start_time})
             self.eval_start_time = None
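
The change above reroutes statuses through group_metrics before gather_metrics, so step-level aggregation runs over distinct tasks instead of individual rollouts. A minimal sketch of the effect, with hypothetical task ids and reward values (not from the commit):

# Two repeats of task "a" and one run of task "b" (assumed data):
statuses = [
    Status(True, metric={"reward": [0.0]}, task_id="a"),
    Status(True, metric={"reward": [1.0]}, task_id="a"),
    Status(True, metric={"reward": [1.0]}, task_id="b"),
]
# Old: [status.metric for status in statuses] -> three dicts, one per rollout.
# New: group_metrics(statuses) -> [{"reward": 0.5}, {"reward": 1.0}],
# one dict per task_id, so a heavily repeated task no longer dominates the mean.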

trinity/explorer/scheduler.py

Lines changed: 11 additions & 13 deletions
@@ -381,23 +381,21 @@ async def get_results(
         statuses = []
         experiences = []
         completed_queue = self.completed_tasks.get(batch_id, deque())
-        for _ in range(min_num):
-            if completed_queue:
-                status, exps = completed_queue.pop()
-                statuses.append(status)
-                if isinstance(exps, list):
-                    experiences.extend(exps)
-                else:
-                    experiences.append(exps)
-
-        if batch_id in self.completed_tasks and not self.completed_tasks[batch_id]:
-            del self.completed_tasks[batch_id]
-
-        completed_count = len(statuses)
+        completed_count = len(completed_queue)
         if completed_count < min_num:
             self.logger.warning(
                 f"Timeout reached, only {completed_count}/{min_num} tasks completed"
             )
+        while completed_queue:
+            status, exps = completed_queue.pop()
+            statuses.append(status)
+            if isinstance(exps, list):
+                experiences.extend(exps)
+            else:
+                experiences.append(exps)
+
+        if batch_id in self.completed_tasks and not self.completed_tasks[batch_id]:
+            del self.completed_tasks[batch_id]

         return statuses, experiences
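
Two fixes land in get_results: the completion count used in the timeout warning is now taken from the queue length before anything is consumed, and the queue is drained with a while loop instead of popping at most min_num entries, so surplus completed tasks are no longer left stranded in completed_tasks. A self-contained sketch of the drain pattern, with made-up queue contents:

from collections import deque

# Hypothetical (status, experiences) pairs; exps may be a list or a single item.
completed_queue = deque([("s1", ["e1"]), ("s2", "e2"), ("s3", ["e3", "e4"])])
statuses, experiences = [], []
while completed_queue:  # drain everything, not just min_num entries
    status, exps = completed_queue.pop()
    statuses.append(status)
    if isinstance(exps, list):
        experiences.extend(exps)
    else:
        experiences.append(exps)
print(statuses)     # ['s3', 's2', 's1']
print(experiences)  # ['e3', 'e4', 'e2', 'e1']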

trinity/explorer/workflow_runner.py

Lines changed: 33 additions & 10 deletions
@@ -5,7 +5,7 @@
 import traceback
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union

 from trinity.buffer import get_buffer_reader
 from trinity.common.config import Config
@@ -21,8 +21,26 @@ class Status:
     """Status of the task running result."""

     ok: bool
-    metric: dict[str, float]
+    metric: dict[str, Union[float, List[float]]]
     message: Optional[str] = None
+    task_id: Union[int, str] = ""
+
+
+def group_metrics(statuses: List[Status]):
+    task2metrics = {}
+    for status in statuses:
+        task_id = status.task_id
+        metric = status.metric
+        if task_id not in task2metrics:
+            task2metrics[task_id] = metric
+        else:
+            for k, v in metric.items():
+                task2metrics[task_id][k] += v  # type: ignore
+    metric_list = [
+        {k: sum(v) / len(v) if isinstance(v, list) else v for k, v in metrics.items()}
+        for metrics in task2metrics.values()
+    ]
+    return metric_list


 class WorkflowRunner:
@@ -144,22 +162,27 @@ async def run_task(
                 for k, v in exp.metrics.items():
                     metrics[k].append(v)
             # We get the average of metrics into the state
-            metric = {}
-            metric["time_per_task"] = time.time() - st
-            if metrics:
-                for k, v in metrics.items():
-                    metric[k] = sum(v) / len(v)  # type: ignore
+            metric: dict[str, Union[float, List[float]]] = {"time_per_task": time.time() - st}
+            metric.update(metrics)

             if task.is_eval:
                 # If the task is an evaluation task, we do not record the experiences to the buffer
-                return Status(True, metric=metric), []
+                return Status(True, metric=metric, task_id=task.task_id), []
             else:
-                return Status(True, metric=metric), exps
+                return Status(True, metric=metric, task_id=task.task_id), exps

         except Exception as e:
             error_trace_back = traceback.format_exc()
             self.logger.error(f"WorkflowRunner run task error: {e}\nTraceback:\n{error_trace_back}")
-            return Status(False, metric={"time_per_task": time.time() - st}, message=str(e)), []
+            return (
+                Status(
+                    False,
+                    metric={"time_per_task": time.time() - st},
+                    message=str(e),
+                    task_id=task.task_id,
+                ),
+                [],
+            )


 class DebugWorkflowRunner(WorkflowRunner):
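
The new group_metrics helper can be exercised standalone. In the sketch below, Status is re-declared minimally and group_metrics is condensed from the diff above so the snippet runs outside the repo; the reward and time values are assumed for illustration:

from dataclasses import dataclass
from typing import List, Optional, Union

@dataclass
class Status:  # minimal stand-in for the dataclass above
    ok: bool
    metric: dict
    message: Optional[str] = None
    task_id: Union[int, str] = ""

def group_metrics(statuses: List[Status]):  # condensed from the diff above
    task2metrics: dict = {}
    for status in statuses:
        if status.task_id not in task2metrics:
            task2metrics[status.task_id] = status.metric
        else:
            for k, v in status.metric.items():
                task2metrics[status.task_id][k] += v  # lists concatenate, floats sum
    # Average list-valued metrics within each task; pass scalars through.
    return [
        {k: sum(v) / len(v) if isinstance(v, list) else v for k, v in m.items()}
        for m in task2metrics.values()
    ]

statuses = [
    Status(True, metric={"reward": [1.0], "time_per_task": 2.0}, task_id=1),
    Status(True, metric={"reward": [0.0], "time_per_task": 4.0}, task_id=1),
    Status(True, metric={"reward": [1.0], "time_per_task": 3.0}, task_id=2),
]
print(group_metrics(statuses))
# [{'reward': 0.5, 'time_per_task': 6.0}, {'reward': 1.0, 'time_per_task': 3.0}]

Note that list-valued metrics (the per-experience values run_task now passes through unaveraged) are averaged within each task, while scalars such as time_per_task are summed across repeats of the same task_id.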
