Commit f6554fe

revert metric modification

1 parent 47422fe

5 files changed: +31 -82 lines changed

tests/explorer/explorer_test.py

Lines changed: 1 addition & 18 deletions

@@ -24,10 +24,9 @@
 from trinity.buffer import get_buffer_reader
 from trinity.cli.launcher import explore, run_stage
 from trinity.common.config import ExperienceBufferConfig, InferenceModelConfig
-from trinity.common.constants import PLUGIN_DIRS_ENV_VAR, StorageType
+from trinity.common.constants import StorageType
 from trinity.explorer.explorer import Explorer
 from trinity.manager.state_manager import StateManager
-from trinity.utils.plugin_loader import load_plugins


 class BaseExplorerCase(RayUnittestBase):
@@ -46,22 +45,6 @@ def setUp(self):
         self.config.explorer.eval_interval = 4


-class TestExplorerCountdownMaxRepeatTimes(BaseExplorerCase):
-    def test_explorer(self):
-        self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
-        self.config.buffer.explorer_input.taskset.default_workflow_type = "custom_workflow"
-        self.config.algorithm.repeat_times = 4
-        self.config.explorer.max_repeat_times_per_runner = 3
-        self.config.name = f"explore-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}"
-        self.config.check_and_update()
-        os.environ[PLUGIN_DIRS_ENV_VAR] = os.path.join("tests", "utils", "plugins")
-        load_plugins()
-        explore(self.config)
-        parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
-        custom_metric_mean = parser.metric_values("rollout/custom_metric/mean")
-        self.assertEqual(custom_metric_mean, [0.75] * 8)
-
-
 class TestExplorerCountdownEval(BaseExplorerCase):
     def test_explorer(self):
         self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")

tests/utils/plugins/my_workflow.py

Lines changed: 0 additions & 10 deletions

@@ -1,7 +1,6 @@
 from typing import List

 from trinity.common.workflows import WORKFLOWS, Workflow
-from trinity.common.workflows.workflow import MathWorkflow


 @WORKFLOWS.register_module("my_workflow")
@@ -18,12 +17,3 @@ def set_repeat_times(self, repeat_times, run_id_base):

     def run(self) -> List:
         return ["Hello world", "Hi"]
-
-
-@WORKFLOWS.register_module("custom_workflow")
-class CustomWorkflow(MathWorkflow):
-    def run(self):
-        responses = super().run()
-        for i, response in enumerate(responses):
-            response.metrics["custom_metric"] = i
-        return responses

trinity/explorer/explorer.py

Lines changed: 6 additions & 3 deletions

@@ -26,7 +26,6 @@
 from trinity.common.models import create_inference_models
 from trinity.common.models.utils import get_checkpoint_dir_with_step_num
 from trinity.explorer.scheduler import Scheduler
-from trinity.explorer.workflow_runner import group_metrics
 from trinity.manager.state_manager import StateManager
 from trinity.manager.synchronizer import Synchronizer
 from trinity.utils.annotations import Experimental
@@ -363,7 +362,7 @@ async def _finish_explore_step(self, step: int, model_version: int) -> None:
         self.taskset.update(pipeline_metrics)
         metric.update(pipeline_metrics)
         if statuses:
-            metric.update(gather_metrics(group_metrics(statuses), "rollout"))
+            metric.update(gather_metrics([status.metric for status in statuses], "rollout"))
         self.monitor.log(metric, step=step)

     async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eval") -> None:
@@ -377,7 +376,11 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eval") -> None:
             return
         self.pending_eval_tasks.popleft()
         eval_results, _ = await self.scheduler.get_results(f"{step}/{eval_task_name}")
-        metric.update(gather_metrics(group_metrics(eval_results), f"{prefix}/{eval_task_name}"))
+        metric.update(
+            gather_metrics(
+                [status.metric for status in eval_results], f"{prefix}/{eval_task_name}"
+            )
+        )
         if self.eval_start_time is not None:
             metric.update({"time/eval": time.time() - self.eval_start_time})
             self.eval_start_time = None
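
With group_metrics gone, each Status carries one flat dict[str, float], so the explorer hands gather_metrics the list of per-status dicts directly. A minimal sketch of that call shape; gather_metrics_sketch is a hypothetical stand-in, since the real gather_metrics in trinity may aggregate differently:

from statistics import mean

def gather_metrics_sketch(metric_dicts, prefix):
    # Hypothetical stand-in for trinity's gather_metrics (assumed behavior:
    # mean of each key across tasks, namespaced under the given prefix).
    keys = set().union(*(m.keys() for m in metric_dicts)) if metric_dicts else set()
    return {f"{prefix}/{k}/mean": mean(m[k] for m in metric_dicts if k in m) for k in keys}

# Each Status.metric is now a flat dict[str, float], so the explorer can
# pass the per-task dicts straight through:
per_task = [{"time_per_task": 1.2}, {"time_per_task": 0.8}]
print(gather_metrics_sketch(per_task, "rollout"))  # {'rollout/time_per_task/mean': 1.0}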

trinity/explorer/scheduler.py

Lines changed: 13 additions & 11 deletions

@@ -381,21 +381,23 @@ async def get_results(
         statuses = []
         experiences = []
         completed_queue = self.completed_tasks.get(batch_id, deque())
-        completed_count = len(completed_queue)
+        for _ in range(min_num):
+            if completed_queue:
+                status, exps = completed_queue.pop()
+                statuses.append(status)
+                if isinstance(exps, list):
+                    experiences.extend(exps)
+                else:
+                    experiences.append(exps)
+
+        if batch_id in self.completed_tasks and not self.completed_tasks[batch_id]:
+            del self.completed_tasks[batch_id]
+
+        completed_count = len(statuses)
         if completed_count < min_num:
             self.logger.warning(
                 f"Timeout reached, only {completed_count}/{min_num} tasks completed"
             )
-            while completed_queue:
-                status, exps = completed_queue.pop()
-                statuses.append(status)
-                if isinstance(exps, list):
-                    experiences.extend(exps)
-                else:
-                    experiences.append(exps)
-
-            if batch_id in self.completed_tasks and not self.completed_tasks[batch_id]:
-                del self.completed_tasks[batch_id]

        return statuses, experiences
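
Behaviorally, get_results now drains up to min_num completed tasks on every call, rather than only draining after the timeout warning as before. A minimal sketch of the new control flow with the scheduler and Ray plumbing stripped away; drain and its arguments are illustrative names, not project API:

from collections import deque

def drain(completed_tasks: dict, batch_id: str, min_num: int):
    """Pop up to min_num (status, exps) pairs for one batch, flattening
    list-valued experiences, and drop the batch entry once it is empty."""
    statuses, experiences = [], []
    queue = completed_tasks.get(batch_id, deque())
    for _ in range(min_num):
        if not queue:
            break  # fewer than min_num finished; the caller logs a timeout warning
        status, exps = queue.pop()
        statuses.append(status)
        if isinstance(exps, list):
            experiences.extend(exps)
        else:
            experiences.append(exps)
    if batch_id in completed_tasks and not completed_tasks[batch_id]:
        del completed_tasks[batch_id]
    return statuses, experiences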

trinity/explorer/workflow_runner.py

Lines changed: 11 additions & 40 deletions

@@ -5,7 +5,7 @@
 import traceback
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple

 from trinity.buffer import get_buffer_reader
 from trinity.common.config import Config
@@ -21,32 +21,8 @@ class Status:
     """Status of the task running result."""

     ok: bool
-    metric: dict[str, Union[float, List[float]]]
+    metric: dict[str, float]
     message: Optional[str] = None
-    task_id: Union[int, str] = ""
-
-
-def group_metrics(statuses: List[Status]):
-    task2metrics = {}
-    for status in statuses:
-        task_id = status.task_id
-        metric = status.metric
-        if task_id not in task2metrics:
-            task2metrics[task_id] = metric
-        else:
-            for k, v in metric.items():
-                task2metrics[task_id][k] += v  # type: ignore
-
-    metric_list = []
-    for metrics in task2metrics.values():
-        agg_metrics = {}
-        for k, v in metrics.items():
-            if isinstance(v, list):
-                agg_metrics[k] = sum(v) / len(v)
-            else:
-                agg_metrics[k] = v
-        metric_list.append(agg_metrics)
-    return metric_list


 class WorkflowRunner:
@@ -167,28 +143,23 @@ async def run_task(
                     exp.metrics = {}
                 for k, v in exp.metrics.items():
                     metrics[k].append(v)
-
-            metric: dict[str, Union[float, List[float]]] = {"time_per_task": time.time() - st}
-            metric.update(metrics)
+            # We get the average of metrics into the state
+            metric = {}
+            metric["time_per_task"] = time.time() - st
+            if metrics:
+                for k, v in metrics.items():
+                    metric[k] = sum(v) / len(v)  # type: ignore

             if task.is_eval:
                 # If the task is an evaluation task, we do not record the experiences to the buffer
-                return Status(True, metric=metric, task_id=task.task_id), []
+                return Status(True, metric=metric), []
             else:
-                return Status(True, metric=metric, task_id=task.task_id), exps
+                return Status(True, metric=metric), exps

         except Exception as e:
             error_trace_back = traceback.format_exc()
             self.logger.error(f"WorkflowRunner run task error: {e}\nTraceback:\n{error_trace_back}")
-            return (
-                Status(
-                    False,
-                    metric={"time_per_task": time.time() - st},
-                    message=str(e),
-                    task_id=task.task_id,
-                ),
-                [],
-            )
+            return Status(False, metric={"time_per_task": time.time() - st}, message=str(e)), []


 class DebugWorkflowRunner(WorkflowRunner):
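
With group_metrics removed, aggregation now happens once inside run_task: metric values are collected per key across a task's experiences and averaged into the flat dict stored on the Status. A self-contained sketch of that reduction; summarize is an illustrative name, not project code:

import time
from collections import defaultdict

def summarize(exp_metrics: list, start: float) -> dict:
    # Collect each metric key across a task's experiences, then average,
    # mirroring the averaging run_task performs before building the Status.
    buckets = defaultdict(list)
    for m in exp_metrics:
        for k, v in m.items():
            buckets[k].append(v)
    metric = {"time_per_task": time.time() - start}
    for k, values in buckets.items():
        metric[k] = sum(values) / len(values)
    return metric

# e.g. two experiences from one task:
print(summarize([{"reward": 1.0}, {"reward": 0.0}], start=time.time()))
# -> {'time_per_task': ..., 'reward': 0.5}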
