Skip to content
Merged
126 changes: 96 additions & 30 deletions tests/explorer/scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import ray
import torch
from parameterized import parameterized

from tests.tools import get_template_config
from trinity.common.config import ExperienceBufferConfig
Expand Down Expand Up @@ -46,17 +47,23 @@ def run(self) -> List[Experience]:
elif self.error_type == "auxiliary_models":
assert self.auxiliary_models is not None and len(self.auxiliary_models) == 2

return [
Experience(
tokens=torch.zeros(5),
prompt_length=2,
prompt_text=self.error_type or "success",
eid=EID(run=i + self.run_id_base, step=step),
info={"repeat_times": self.repeat_times},
)
for step in range(self.step_num)
for i in range(self.repeat_times)
]
exps = []
for i in range(self.repeat_times):
run_level_metrics = {"run_metrics": float(i + self.run_id_base)}
run_level_exps = []
for step in range(self.step_num):
run_level_exps.append(
Experience(
tokens=torch.zeros(5),
prompt_length=2,
prompt_text=self.error_type or "success",
eid=EID(run=i + self.run_id_base, step=step),
info={"repeat_times": self.repeat_times},
)
)
run_level_exps[-1].metrics = run_level_metrics
exps.extend(run_level_exps)
return exps


@WORKFLOWS.register_module("dummy_nonrepeat_workflow")
Expand All @@ -67,22 +74,29 @@ def __init__(self, *, task, model, auxiliary_models):
super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
self.reset_flag = False
self.step_num = task.workflow_args.get("step_num", 1)
self.metrics = task.workflow_args.get("metrics", [0])

def reset(self, task: Task):
    """Point the workflow at a new task and re-read its workflow arguments.

    Args:
        task: Task whose ``workflow_args`` mapping may provide ``step_num``
            (defaults to ``1``) and ``metrics`` (defaults to ``[0]``).
    """
    self.task = task
    # Record that this instance has been through reset(); run() reports the
    # flag in each experience's ``info``.
    self.reset_flag = True
    workflow_args = task.workflow_args
    self.step_num = workflow_args.get("step_num", 1)
    self.metrics = workflow_args.get("metrics", [0])

def run(self) -> List[Experience]:
    """Build one dummy experience per step for a single run.

    Each experience carries the current ``reset_flag`` in ``info`` and a
    ``run_metrics`` value taken from ``self.metrics``; the metrics list is
    cycled when it is shorter than ``step_num``.

    Returns:
        A list of ``step_num`` experiences, all sharing ``run=self.run_id_base``.
    """
    return [
        Experience(
            eid=EID(run=self.run_id_base, step=step),
            tokens=torch.zeros(5),
            prompt_length=2,
            prompt_text="success",
            info={"reset_flag": self.reset_flag},
            metrics={
                # Wrap around so any non-empty metrics list covers every step.
                "run_metrics": self.metrics[step % len(self.metrics)],
            },
        )
        for step in range(self.step_num)
    ]


@WORKFLOWS.register_module("dummy_async_workflow")
Expand All @@ -99,16 +113,22 @@ def set_repeat_times(self, repeat_times, run_id_base):
self.run_id_base = run_id_base

async def run_async(self):
    """Build dummy experiences for every repeat and step of this workflow.

    For each of the ``repeat_times`` runs, ``step_num`` experiences are
    created with ``run = i + run_id_base``. The last experience of each run
    is given ``{"run_metrics": float(run id)}`` as its metrics.

    Returns:
        A flat list of experiences, grouped run by run.
    """
    exps = []
    for i in range(self.repeat_times):
        run_id = i + self.run_id_base
        run_level_exps = [
            Experience(
                eid=EID(run=run_id, step=step),
                tokens=torch.zeros(5),
                prompt_length=2,
                prompt_text="success",
            )
            for step in range(self.step_num)
        ]
        # NOTE(review): assumes the run-level metrics are attached once per
        # run, after all steps were created — confirm against the intended
        # nesting. The emptiness guard avoids an IndexError when step_num is 0.
        if run_level_exps:
            run_level_exps[-1].metrics = {"run_metrics": float(run_id)}
        exps.extend(run_level_exps)
    return exps

def run(self):
    """Reject synchronous execution.

    Raises:
        RuntimeError: Always; the message states that this method should not
            be called.
    """
    raise RuntimeError("This method should not be called")
Expand Down Expand Up @@ -490,7 +510,7 @@ async def test_split_tasks(self):
tasks = generate_tasks(4, repeat_times=8) # ceil(8 / 2) == 4
scheduler.schedule(tasks, batch_id=1)
statuses, exps = await scheduler.get_results(batch_id=1)
self.assertEqual(len(statuses), 4 * 4)
self.assertEqual(len(statuses), 4)
self.assertEqual(len(exps), 4 * 8)
exp_list.extend(exps)
_, exps = await scheduler.get_results(batch_id=1, min_num=1, timeout=1)
Expand All @@ -499,7 +519,7 @@ async def test_split_tasks(self):
tasks = generate_tasks(4, repeat_times=5) # ceil(5 / 2) == 3
scheduler.schedule(tasks, batch_id=2)
statuses, exps = await scheduler.get_results(batch_id=2)
self.assertEqual(len(statuses), 4 * 3)
self.assertEqual(len(statuses), 4)
self.assertEqual(len(exps), 4 * 5)
exp_list.extend(exps)
_, exps = await scheduler.get_results(batch_id=2, min_num=1, timeout=1)
Expand All @@ -508,7 +528,7 @@ async def test_split_tasks(self):
tasks = generate_tasks(3, repeat_times=1) # ceil(1 / 2) == 1
scheduler.schedule(tasks, batch_id=3)
statuses, exps = await scheduler.get_results(batch_id=3)
self.assertEqual(len(statuses), 3 * 1)
self.assertEqual(len(statuses), 3)
self.assertEqual(len(exps), 3 * 1)
exp_list.extend(exps)
_, exps = await scheduler.get_results(batch_id=3, min_num=1, timeout=1)
Expand All @@ -535,7 +555,7 @@ async def test_multi_step_execution(self):
for i in range(1, n_steps + 1):
scheduler.schedule(tasks, batch_id=i)
statuses, exps = await scheduler.get_results(batch_id=i)
self.assertEqual(len(statuses), 2 * 4)
self.assertEqual(len(statuses), 2)
self.assertEqual(len(exps), 2 * 4)

await scheduler.stop()
Expand All @@ -553,7 +573,7 @@ async def test_non_repeatable_workflow(self):
for i in range(1, batch_num + 1):
scheduler.schedule(tasks, batch_id=i)
statuses, exps = await scheduler.get_results(batch_id=i)
self.assertEqual(len(statuses), task_num * repeat_times / 2)
self.assertEqual(len(statuses), task_num)
self.assertEqual(len(exps), task_num * repeat_times)
exp_list.extend(exps)

Expand Down Expand Up @@ -594,7 +614,7 @@ async def test_async_workflow(self):
for i in range(1, batch_num + 1):
scheduler.schedule(tasks, batch_id=i)
statuses, exps = await scheduler.get_results(batch_id=i)
self.assertEqual(len(statuses), task_num * repeat_times / 2)
self.assertEqual(len(statuses), task_num)
self.assertEqual(len(exps), task_num * repeat_times * step_num)
exp_list.extend(exps)

Expand Down Expand Up @@ -624,7 +644,7 @@ async def test_stepwise_experience_eid(self):
for i in range(1, batch_num + 1):
scheduler.schedule(tasks, batch_id=i)
statuses, exps = await scheduler.get_results(batch_id=i)
self.assertEqual(len(statuses), task_num * repeat_times / 2)
self.assertEqual(len(statuses), task_num)
self.assertEqual(len(exps), task_num * repeat_times * step_num)
exp_list.extend(exps)

Expand All @@ -644,7 +664,7 @@ async def test_stepwise_experience_eid(self):
for i in range(1, batch_num + 1):
scheduler.schedule(tasks, batch_id=i)
statuses, exps = await scheduler.get_results(batch_id=i)
self.assertEqual(len(statuses), task_num * repeat_times / 2)
self.assertEqual(len(statuses), task_num)
self.assertEqual(len(exps), task_num * repeat_times * step_num)
exp_list.extend(exps)

Expand All @@ -656,6 +676,52 @@ async def test_stepwise_experience_eid(self):
unique_ids = [exp.eid.uid for exp in exp_list]
self.assertEqual(len(unique_ids), len(set(unique_ids)))

@parameterized.expand(
    [
        (2,),
        (None,),
    ]
)
async def test_metric_calculation_with_repeatable_workflow(self, max_repeat_times_per_runner):
    """Check the per-task ``run_metrics`` reported for repeatable tasks.

    Two tasks are scheduled in one batch: 4 repeats x 1 step and
    8 repeats x 4 steps. One status is expected per task, and the
    ``run_metrics`` entry of each status is compared with the mean of the
    per-run values noted next to the assertions.

    Args:
        max_repeat_times_per_runner: Value written to
            ``config.explorer.max_repeat_times_per_runner`` (``2`` or ``None``).
    """
    self.config.explorer.max_repeat_times_per_runner = max_repeat_times_per_runner
    self.config.check_and_update()
    scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()])
    await scheduler.start()
    try:
        tasks = []
        tasks.extend(generate_tasks(total_num=1, step_num=1, repeat_times=4, repeatable=True))
        tasks.extend(generate_tasks(total_num=1, step_num=4, repeat_times=8, repeatable=True))
        scheduler.schedule(tasks, batch_id=0)
        statuses, exps = await scheduler.get_results(batch_id=0)
        self.assertEqual(len(statuses), 2)
        self.assertEqual(len(exps), 1 * 4 * 1 + 1 * 8 * 4)
        self.assertAlmostEqual(statuses[0].metrics[0]["run_metrics"], 1.5)  # (0+1+2+3)/4
        self.assertAlmostEqual(
            statuses[1].metrics[0]["run_metrics"], 3.5
        )  # (0+1+2+3+4+5+6+7)/8
    finally:
        # Stop the scheduler even when an assertion fails; other tests in this
        # file stop it explicitly after collecting results.
        await scheduler.stop()

@parameterized.expand(
    [
        (2,),
        (None,),
    ]
)
async def test_metric_calculation_with_non_repeatable_workflow(
    self, max_repeat_times_per_runner
):
    """Check the per-task ``run_metrics`` reported for non-repeatable tasks.

    Two tasks are scheduled in one batch: 4 repeats x 3 steps with step
    metrics ``[1.0, 2.0, 3.0]`` and 5 repeats x 8 steps with step metrics
    ``0, 2, ..., 14``. One status is expected per task, and the
    ``run_metrics`` entry of each status is compared with the mean of the
    step values noted next to the assertions.

    Args:
        max_repeat_times_per_runner: Value written to
            ``config.explorer.max_repeat_times_per_runner`` (``2`` or ``None``).
    """
    self.config.explorer.max_repeat_times_per_runner = max_repeat_times_per_runner
    self.config.check_and_update()
    scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()])
    await scheduler.start()
    try:
        tasks = []
        tasks.extend(generate_tasks(total_num=1, step_num=3, repeat_times=4, repeatable=False))
        tasks[-1].workflow_args["metrics"] = [1.0, 2.0, 3.0]
        tasks.extend(generate_tasks(total_num=1, step_num=8, repeat_times=5, repeatable=False))
        tasks[-1].workflow_args["metrics"] = [2 * i for i in range(8)]
        scheduler.schedule(tasks, batch_id=0)
        statuses, exps = await scheduler.get_results(batch_id=0)
        self.assertEqual(len(statuses), 2)
        self.assertEqual(len(exps), 1 * 4 * 3 + 1 * 5 * 8)
        self.assertAlmostEqual(statuses[0].metrics[0]["run_metrics"], 2.0)  # (1+2+3)/3
        self.assertAlmostEqual(
            statuses[1].metrics[0]["run_metrics"], 7.0
        )  # (0+2+4+6+8+10+12+14)/8
    finally:
        # Stop the scheduler even when an assertion fails; other tests in this
        # file stop it explicitly after collecting results.
        await scheduler.stop()

def tearDown(self):
try:
ray.shutdown()
Expand Down
38 changes: 38 additions & 0 deletions tests/trainer/trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -987,3 +987,41 @@ def test_trainer(self):

def tearDown(self):
    """Remove the directory tree at ``self.config.checkpoint_job_dir``."""
    shutil.rmtree(self.config.checkpoint_job_dir)


class TestOverRollout(BaseTrainerCase):
    """Trainer/explorer run with over-rollout enabled on the explorer."""

    def test_trainer(self):
        """Run ``both`` with a 50% over-rollout rate and check the logged results.

        Verifies that rollout metrics were logged up to step 2, that no eval
        metrics exist, that each explorer step processed ``2 * 4``
        experiences, that the trainer logged a single ``actor/pg_loss``
        value, and that the experience buffer file holds the expected number
        of lines.
        """
        self.config.algorithm.repeat_times = 4
        self.config.buffer.batch_size = 4
        self.config.buffer.total_steps = 2
        self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
        self.config.name = f"explore-over-rollout-{datetime.now().strftime('%Y%m%d%H%M%S')}"
        # Over-rollout rate of 50%: only wait for 2 (4 * 50%) tasks in each step.
        self.config.explorer.over_rollout_rate = 0.5
        self.config.algorithm.algorithm_type = "grpo"
        self.config.algorithm.advantage_fn = "grpo"
        self.config.algorithm.advantage_fn_args = {
            "epsilon": 1e-6,
        }
        self.config.synchronizer.sync_style = SyncStyle.DYNAMIC_BY_EXPLORER
        self.config.synchronizer.sync_interval = 1
        self.config.check_and_update()
        both(self.config)

        parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
        rollout_metrics = parser.metric_list("rollout")
        # Rollout metrics must exist: the first one is indexed just below.
        self.assertGreater(len(rollout_metrics), 0)
        eval_metrics = parser.metric_list("eval")
        self.assertEqual(len(eval_metrics), 0)
        self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 2)

        self.assertTrue(parser.metric_exist("experience_pipeline/experience_count"))
        experience_counts = parser.metric_values("experience_pipeline/experience_count")
        self.assertEqual(len(experience_counts), 2)
        for count in experience_counts:
            self.assertEqual(count, 2 * 4)  # only process 2 tasks in each step, repeat_times is 4

        pg_loss = parser.metric_values("actor/pg_loss")
        self.assertEqual(len(pg_loss), 1)  # trainer only has 1 step

        exp_save_path = self.config.buffer.trainer_input.experience_buffer.path
        with open(exp_save_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
        # step * repeat_times * batch_size * (1-over_rollout_rate)
        # assertEqual, not assertTrue: assertTrue would treat the expected
        # count as the failure message and pass for any non-empty file.
        self.assertEqual(len(lines), 2 * 4 * 2)
13 changes: 13 additions & 0 deletions trinity/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,9 @@ class ExplorerConfig:
service_status_check_interval: int = 60
# keep at least 1 model in running status
min_running_model_num: int = 1
# Experimental: set to a positive value to enable over rollout
# If set, explorer will only wait for (1 - over_rollout_rate) * batch_size of tasks at each step
over_rollout_rate: float = 0.0


@dataclass
Expand Down Expand Up @@ -1198,6 +1201,16 @@ def check_and_update(self) -> Config: # noqa: C901
for args in rollout_args + length_args:
set_if_none(aux_model, args, getattr(self.model, args))

if not (0.0 <= self.explorer.over_rollout_rate < 1.0):
raise ValueError("over_rollout_rate should be in [0.0, 1.0)")
if (
self.explorer.over_rollout_rate > 0.0
and self.synchronizer.sync_style == SyncStyle.FIXED
):
raise ValueError(
"over_rollout_rate is not compatible with fixed sync_style, please set sync_style to `dynamic_by_explorer` or `dynamic_by_trainer`."
)

# for lora configs
if self.model.lora_configs is not None:
self.explorer.rollout_model.enable_lora = True
Expand Down
22 changes: 15 additions & 7 deletions trinity/explorer/explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import asyncio
import math
import os
import time
import traceback
Expand Down Expand Up @@ -62,10 +63,15 @@ def __init__(self, config: Config):
role=self.config.explorer.name,
config=config,
)
self.batch_size = config.buffer.batch_size
self.update_interval = (
self.config.synchronizer.sync_interval * self.config.buffer.batch_size
)
if config.explorer.over_rollout_rate > 0.0:
self.min_wait_num = math.ceil(
config.buffer.batch_size * (1 - config.explorer.over_rollout_rate)
)
self.logger.info(
f"Over rollout is enabled. Explorer will only wait for {self.min_wait_num} tasks in each step."
)
else:
self.min_wait_num = None
self.use_nccl_sync = self.config.synchronizer.sync_method == SyncMethod.NCCL
self.pending_eval_tasks = deque()

Expand Down Expand Up @@ -357,12 +363,14 @@ async def _finish_steps(self, start_step: int, end_step: int, model_version: int
async def _finish_explore_step(self, step: int, model_version: int) -> None:
    """Collect the results of one explore step, process them and log metrics.

    Args:
        step: Batch id passed to the scheduler and step index used for logging.
        model_version: Logged as ``rollout/model_version``.
    """
    metric = {"rollout/model_version": model_version}
    # Only the wait for scheduler results is timed.
    with Timer(metric, "time/wait_explore_step"):
        # A single fetch; ``min_wait_num`` is forwarded as ``min_num``.
        statuses, exps = await self.scheduler.get_results(
            batch_id=step, min_num=self.min_wait_num
        )
    pipeline_metrics = await self.experience_pipeline.process.remote(exps)
    self.taskset.update(pipeline_metrics)
    metric.update(pipeline_metrics)
    if statuses:
        # The first entry of each status's ``metrics`` list is aggregated.
        metric.update(gather_metrics([status.metrics[0] for status in statuses], "rollout"))
    self.monitor.log(metric, step=step)

async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eval") -> None:
Expand All @@ -378,7 +386,7 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva
eval_results, _ = await self.scheduler.get_results(f"{step}/{eval_task_name}")
metric.update(
gather_metrics(
[status.metric for status in eval_results], f"{prefix}/{eval_task_name}"
[status.metrics[0] for status in eval_results], f"{prefix}/{eval_task_name}"
)
)
if self.eval_start_time is not None:
Expand Down
Loading