diff --git a/pyproject.toml b/pyproject.toml
index 61059b8c4e..7322f2e4c6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -54,7 +54,7 @@ agent = [
     "agentscope"
 ]
 rm_gallery = [
-    "rm-gallery"
+    "rm-gallery>=0.1.1"
 ]
 dev = [
     "pre-commit>=2.17.0",
diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py
index ef8077ff79..20ecb179c9 100644
--- a/tests/explorer/scheduler_test.py
+++ b/tests/explorer/scheduler_test.py
@@ -403,31 +403,44 @@ async def test_split_tasks(self):
         self.config.check_and_update()
         scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()])
         await scheduler.start()
+        exp_list = []
 
         tasks = generate_tasks(4, repeat_times=8)  # ceil(8 / 2) == 4
         scheduler.schedule(tasks, batch_id=1)
         results = await scheduler.get_results(batch_id=1)
         self.assertEqual(len(results), 4 * 4)
-        self.assertEqual(len(self.queue.read(batch_size=4 * 8)), 4 * 8)
+        exps = self.queue.read(batch_size=4 * 8)
+        self.assertEqual(len(exps), 4 * 8)
+        exp_list.extend(exps)
         with self.assertRaises(TimeoutError):
             self.queue.read(batch_size=1)
 
         tasks = generate_tasks(4, repeat_times=5)  # ceil(5 / 2) == 3
-        scheduler.schedule(tasks, batch_id=1)
-        results = await scheduler.get_results(batch_id=1)
+        scheduler.schedule(tasks, batch_id=2)
+        results = await scheduler.get_results(batch_id=2)
         self.assertEqual(len(results), 4 * 3)
-        self.assertEqual(len(self.queue.read(batch_size=4 * 5)), 4 * 5)
+        exps = self.queue.read(batch_size=4 * 5)
+        self.assertEqual(len(exps), 4 * 5)
+        exp_list.extend(exps)
         with self.assertRaises(TimeoutError):
             self.queue.read(batch_size=1)
 
         tasks = generate_tasks(3, repeat_times=1)  # ceil(1 / 2) == 1
-        scheduler.schedule(tasks, batch_id=1)
-        results = await scheduler.get_results(batch_id=1)
+        scheduler.schedule(tasks, batch_id=3)
+        results = await scheduler.get_results(batch_id=3)
         self.assertEqual(len(results), 3 * 1)
-        self.assertEqual(len(self.queue.read(batch_size=3 * 1)), 3 * 1)
+        exps = self.queue.read(batch_size=3 * 1)
+        self.assertEqual(len(exps), 3 * 1)
+        exp_list.extend(exps)
         with self.assertRaises(TimeoutError):
             self.queue.read(batch_size=1)
 
+        # test group_id and unique_id
+        group_ids = [exp.group_id for exp in exp_list]
+        self.assertEqual(len(set(group_ids)), 11)  # 4 + 4 + 3
+        unique_ids = [exp.unique_id for exp in exp_list]
+        self.assertEqual(len(unique_ids), len(set(unique_ids)))
+
         await scheduler.stop()
 
     def tearDown(self):
diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py
index 944d9573a0..20a8c5064c 100644
--- a/tests/explorer/workflow_test.py
+++ b/tests/explorer/workflow_test.py
@@ -22,6 +22,7 @@ class MockResponse:
     reward: float = 0.0
     metrics: Optional[Dict[str, float]] = None
     info: Optional[Dict] = None
+    unique_id: Optional[str] = "0"
 
 
 class DummyWorkflow(Workflow):
@@ -237,7 +238,6 @@ def test_gsm8k_workflow(self) -> None:
         self.assertEqual(experiences[2].reward, -0.1)
         self.assertEqual(experiences[3].reward, 1.1)
 
-    @unittest.skip("Skip for now, need to fix import issues of RM-Gallery")
     def test_rm_gallery_workflow(self) -> None:
         model = MagicMock()
         model.chat.return_value = [
diff --git a/trinity/algorithm/sample_strategy/mix_sample_strategy.py b/trinity/algorithm/sample_strategy/mix_sample_strategy.py
index 0db24ab3a5..c6858931b1 100644
--- a/trinity/algorithm/sample_strategy/mix_sample_strategy.py
+++ b/trinity/algorithm/sample_strategy/mix_sample_strategy.py
@@ -91,6 +91,7 @@ def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor) -> DataProto:
     position_ids = torch.clip(cumsum - 1, 0, None).long()
     batch_dict = {
         "uid": np.array(experiences.group_ids),
+        "unique_ids": np.array(experiences.unique_ids),
         "position_ids": position_ids,
         "input_ids": experiences.tokens.long(),
         "responses": experiences.tokens[:, experiences.prompt_length :].long(),
diff --git a/trinity/algorithm/sample_strategy/utils.py b/trinity/algorithm/sample_strategy/utils.py
index d578248c1d..f9df00ee4e 100644
--- a/trinity/algorithm/sample_strategy/utils.py
+++ b/trinity/algorithm/sample_strategy/utils.py
@@ -14,6 +14,7 @@ def to_data_proto(experiences: Experiences) -> DataProto:
     position_ids = torch.clip(cumsum - 1, 0, None).long()
     batch_dict = {
         "uid": np.array(experiences.group_ids),
+        "unique_ids": np.array(experiences.unique_ids),
         "position_ids": position_ids,
         "input_ids": experiences.tokens.long(),
         "responses": experiences.tokens[:, experiences.prompt_length :].long(),
diff --git a/trinity/common/experience.py b/trinity/common/experience.py
index fd499c1ba6..56bdd1c4c5 100644
--- a/trinity/common/experience.py
+++ b/trinity/common/experience.py
@@ -27,6 +27,7 @@ class Experience:
     info: Optional[dict] = None
     metrics: Optional[dict[str, float]] = None
     group_id: str = ""  # for grpo
+    unique_id: str = ""
 
     def __post_init__(self):
         if self.action_mask is not None:
@@ -96,6 +97,7 @@ class Experiences:
     prompt_length: int
     logprobs: Optional[Tensor]
     group_ids: List[str]
+    unique_ids: List[str]
 
     @property
     def batch_size(self) -> int:
@@ -119,10 +121,12 @@ def gather_experiences(
                 logprobs=torch.empty(0, dtype=torch.float32),
                 prompt_length=torch.empty(0, dtype=torch.int32),
                 group_ids=[],
+                unique_ids=[],
             )
         max_prompt_length = max([exp.prompt_length for exp in experiences])
         max_response_length = max([len(exp.tokens) - exp.prompt_length for exp in experiences])
         group_ids = [exp.group_id for exp in experiences]
+        unique_ids = [exp.unique_id for exp in experiences]
         tokens_dtype = experiences[0].tokens.dtype
         tokens = torch.stack(
             [
@@ -209,6 +213,7 @@ def gather_experiences(
 
         return cls(
             group_ids=group_ids,
+            unique_ids=unique_ids,
             tokens=tokens,
             rewards=rewards,
             attention_masks=attention_masks,
@@ -250,6 +255,7 @@ def gather_dpo_experiences(
                 logprobs=torch.empty(0, dtype=torch.float32),
                 prompt_length=torch.empty(0, dtype=torch.int32),
                 group_ids=[],
+                unique_ids=[],
             )
 
         # TODO: exp.tokens in DPO are prompt tokens
@@ -262,6 +268,11 @@ def gather_dpo_experiences(
 
         max_response_length = max([len(response) for response in response_tokens])  # type: ignore
         group_ids = list(chain.from_iterable([repeat(exp.group_id, 2) for exp in experiences]))
+        unique_ids = list(
+            chain.from_iterable(
+                [(f"{exp.unique_id}/1", f"{exp.unique_id}/0") for exp in experiences]
+            )
+        )
         tokens_dtype = experiences[0].tokens.dtype
         tokens = torch.stack(
             [
@@ -298,6 +309,7 @@ def gather_dpo_experiences(
 
         return cls(
             group_ids=group_ids,
+            unique_ids=unique_ids,
             tokens=tokens,
             attention_masks=attention_masks,
             prompt_length=max_prompt_length,
diff --git a/trinity/common/rewards/reward_fn.py b/trinity/common/rewards/reward_fn.py
index ee443ed085..c6fe5bb58a 100644
--- a/trinity/common/rewards/reward_fn.py
+++ b/trinity/common/rewards/reward_fn.py
@@ -69,7 +69,7 @@ def _build_sample_from_experience(
         ]
 
         sample = DataSample(
-            unique_id="0",  # TODO: Generate unique ID
+            unique_id=experience.unique_id,
             input=to_rm_gallery_messages(messages),
             output=output,
             metadata=experience.info,
diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py
index ffe684ef61..459793d2c8 100644
--- a/trinity/explorer/scheduler.py
+++ b/trinity/explorer/scheduler.py
@@ -56,7 +56,7 @@ def _create_runner(self):
                     "env_vars": self.config.explorer.env_vars,
                 },
             )
-            .remote(self.config, self.rollout_model, self.auxiliary_models)
+            .remote(self.config, self.rollout_model, self.auxiliary_models, self.runner_id)
         )
 
     async def run_with_retry(self, task: TaskWrapper) -> Tuple[Status, int]:
diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py
index e7eec9240f..b973a114d2 100644
--- a/trinity/explorer/workflow_runner.py
+++ b/trinity/explorer/workflow_runner.py
@@ -2,6 +2,7 @@
 """The Workflow Runner Moudle."""
 import time
 import traceback
+import uuid
 from collections import defaultdict
 from dataclasses import dataclass
 from typing import List, Optional
@@ -31,6 +32,7 @@ def __init__(
         config: Config,
         model: InferenceModel,
         auxiliary_models: Optional[List[InferenceModel]] = None,
+        runner_id: Optional[int] = None,
     ) -> None:
         self.config = config
         self.experience_buffer = get_buffer_writer(
@@ -52,6 +54,7 @@ def __init__(
             self.auxiliary_models.append(api_client)
         self.logger = get_logger(__name__)
         self.workflow_instance = None
+        self.runner_id = runner_id
 
     def is_alive(self):
         return True
@@ -78,8 +81,11 @@ def run_task(self, task: Task) -> Status:
 
             assert exps is not None and len(exps) > 0, "An empty experience is generated"
             metrics: dict[str, List[float]] = defaultdict(list)
             # set group id
-            for exp in exps:
+            for idx, exp in enumerate(exps):
                 setattr(exp, "group_id", task.group_id)
+                setattr(
+                    exp, "unique_id", f"{task.group_id}/{self.runner_id}/{str(uuid.uuid4())[:6]}"
+                )
                 if not hasattr(exp, "info") or exp.info is None:
                     exp.info = {}