diff --git a/.github/workflows/unittest.yaml b/.github/workflows/unittest.yaml
index e03426c46b..e488067f28 100644
--- a/.github/workflows/unittest.yaml
+++ b/.github/workflows/unittest.yaml
@@ -18,6 +18,7 @@ jobs:
     steps:
       - uses: actions/checkout@v4
         with:
+          fetch-depth: 0
           path: trinity-${{ github.run_id }}
           ref: refs/pull/${{ github.event.issue.number }}/head
diff --git a/README.md b/README.md
index 0684d5a364..eee0b6ca45 100644
--- a/README.md
+++ b/README.md
@@ -152,6 +152,12 @@ It is designed to support diverse application scenarios and serve as a unified p
 
 ### Step 1: installation
 
+Requirements:
+- Python version >= 3.10, <= 3.12
+- CUDA version >= 12.4, <= 12.8
+- At least 2 GPUs
+
+
 Installation from source **(recommended)**:
 
 ```shell
@@ -181,13 +187,15 @@ pip install -e .[flash_attn]
 # for zsh
 pip install -e .\[flash_attn\]
 
 # Try the following command if you encounter errors during flash-attn installation
-# pip install flash-attn -v --no-build-isolation
+# pip install flash-attn==2.8.0.post2 -v --no-build-isolation
 ```
 
 Installation using pip:
 
 ```shell
 pip install trinity-rft==0.2.0
+# install flash-attn separately
+pip install flash-attn==2.8.0.post2
 ```
 
 Installation from docker:
@@ -206,13 +214,6 @@ docker build -f scripts/docker/Dockerfile -t trinity-rft:latest .
 docker run -it --gpus all --shm-size="64g" --rm -v $PWD:/workspace -v :/data trinity-rft:latest
 ```
 
-
-**Requirements:**
-Python version >= 3.10,
-CUDA version >= 12.4,
-and at least 2 GPUs.
-
-
 ### Step 2: prepare dataset and model
diff --git a/README_zh.md b/README_zh.md
index 0f652aa22f..6d0f8df8f2 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -151,6 +151,11 @@ Trinity-RFT是一个通用、灵活且易于使用的大语言模型强化微调
 
 ### 第一步:安装
 
+环境要求:
+- Python >= 3.10, <= 3.12
+- CUDA >= 12.4, <= 12.8
+- 至少 2 块 GPU
+
 
 源码安装 **(推荐)**:
 
@@ -181,13 +186,15 @@ pip install -e .[flash_attn]
 # 适用于 zsh
 pip install -e .\[flash_attn\]
 
 # 如果安装 flash-attn 时遇到错误,可以尝试以下命令
-# pip install flash-attn -v --no-build-isolation
+# pip install flash-attn==2.8.0.post2 -v --no-build-isolation
 ```
 
 使用 pip 安装:
 
 ```shell
 pip install trinity-rft==0.2.0
+# flash-attn 需要单独安装
+pip install flash-attn==2.8.0.post2
 ```
 
 使用 Docker 安装:
@@ -207,12 +214,6 @@ docker run -it --gpus all --shm-size="64g" --rm -v $PWD:/workspace -v :/data trinity-rft:latest
 ```
 
-**环境要求:**
-Python 版本 >= 3.10,
-CUDA 版本 >= 12.4,
-以及至少 2 块 GPU。
-
-
 ### 第二步:准备数据集和模型
diff --git a/docs/sphinx_doc/source/tutorial/example_mix_algo.md b/docs/sphinx_doc/source/tutorial/example_mix_algo.md
index 0e08b8648b..52aa212b0a 100644
--- a/docs/sphinx_doc/source/tutorial/example_mix_algo.md
+++ b/docs/sphinx_doc/source/tutorial/example_mix_algo.md
@@ -142,7 +142,8 @@ We also need to add an `is_expert_mask` field when transforming to DataProto to
     cumsum = torch.cumsum(attention_mask, dim=-1)
     position_ids = torch.clip(cumsum - 1, 0, None).long()
     batch_dict = {
-        "uid": np.array(experiences.group_ids),
+        "uid": np.array([eid.tid for eid in experiences.eids]),
+        "unique_ids": np.array([eid.uid for eid in experiences.eids]),
         "position_ids": position_ids,
         "input_ids": experiences.tokens.long(),
         "responses": experiences.tokens[:, experiences.prompt_length :].long(),
diff --git a/docs/sphinx_doc/source/tutorial/faq.md b/docs/sphinx_doc/source/tutorial/faq.md
index e84639255c..cc6c3c461b 100644
--- a/docs/sphinx_doc/source/tutorial/faq.md
+++ b/docs/sphinx_doc/source/tutorial/faq.md
@@ -65,7 +65,7 @@ File ".../flash_attn/flash_attn_interface.py", line 15, in <module>
 ImportError: ...
 ```
 
-**A:** The `flash-attn` module is not properly installed. Try to fix it by running `pip install flash-attn` or `pip install flash-attn -v --no-build-isolation`.
+**A:** The `flash-attn` module is not properly installed. Try to fix it by running `pip install flash-attn==2.8.0.post2` or `pip install flash-attn==2.8.0.post2 -v --no-build-isolation`.
 
 ---
diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md
index 230e55592e..ea4b966d47 100644
--- a/docs/sphinx_doc/source/tutorial/trinity_configs.md
+++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -90,6 +90,7 @@ algorithm:
   kl_penalty_fn: "none"
   kl_loss_fn: "k2"
   entropy_loss_fn: "default"
+  add_strategy: null
 ```
 
 - `algorithm_type`: Type of reinforcement learning algorithm. Supported types: `ppo`, `grpo`, `opmd`, `dpo`, `sft`, `mix`.
@@ -99,7 +100,7 @@ algorithm:
 - `kl_penalty_fn`: The KL penalty function used for computing KL penalty applied in reward.
 - `kl_loss_fn`: The KL loss function used for computing KL loss.
 - `entropy_loss_fn`: The entropy loss function used for computing entropy loss.
-
+- `add_strategy`: Strategy for adding new experiences to the experience buffer. If set, explorer will collect experiences from workflow runners and pre-process them before adding to the buffer.
 
 ---
diff --git a/pyproject.toml b/pyproject.toml
index 621e5a989e..53ae5e7f4e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,7 +23,7 @@ requires-python = ">=3.10"
 dependencies = [
     "verl==0.4.1",
     "ray[default]>=2.45.0",
-    "vllm==0.9.2",
+    "vllm>=0.9.1",
     "tensordict==0.6.2",
     "wandb",
     "omegaconf",
diff --git a/tests/buffer/queue_test.py b/tests/buffer/queue_test.py
index 702e8b8ca1..11526c223f 100644
--- a/tests/buffer/queue_test.py
+++ b/tests/buffer/queue_test.py
@@ -30,7 +30,7 @@ class TestQueueBuffer(RayUnittestBaseAysnc):
     )
     async def test_queue_buffer(self, name, use_priority_queue):
         meta = StorageConfig(
-            name="test_buffer",
+            name=name,
             algorithm_type="ppo",
             storage_type=StorageType.QUEUE,
             max_read_timeout=3,
@@ -60,7 +60,6 @@ async def test_queue_buffer(self, name, use_priority_queue):
         exps = [
             Experience(
                 tokens=torch.tensor([float(j) for j in range(i + 1)]),
-                prompt_length=i,
                 reward=float(i),
                 logprobs=torch.tensor([0.1]),
                 action_mask=torch.tensor([j % 2 for j in range(i + 1)]),
diff --git a/tests/buffer/sql_test.py b/tests/buffer/sql_test.py
index 22b1c739a6..7d43d04168 100644
--- a/tests/buffer/sql_test.py
+++ b/tests/buffer/sql_test.py
@@ -38,7 +38,6 @@ async def test_create_sql_buffer(self) -> None:
                 prompt_length=i,
                 reward=float(i),
                 logprobs=torch.tensor([0.1]),
-                action_mask=torch.tensor([j % 2 for j in range(i + 1)]),
             )
             for i in range(1, put_batch_size + 1)
         ]
@@ -54,7 +53,6 @@ async def test_create_sql_buffer(self) -> None:
             [
                 Experience(
                     tokens=torch.tensor([float(j) for j in range(i + 1)]),
-                    prompt_length=i,
                     reward=float(i),
                     logprobs=torch.tensor([0.1]),
                     action_mask=torch.tensor([j % 2 for j in range(i + 1)]),
diff --git a/tests/common/experience_test.py b/tests/common/experience_test.py
index 947c4d4ecb..44f96e16f7 100644
--- a/tests/common/experience_test.py
+++ b/tests/common/experience_test.py
@@ -6,12 +6,141 @@
 import torch
 
 from trinity.buffer.schema.sql_schema import ExperienceModel
-from trinity.common.experience import Experience, Experiences
+from trinity.common.experience import EID, Experience, Experiences
 
 db_url = os.path.join(os.path.dirname(__file__), "tmp", "test.db")
 dataset_path = os.path.join(os.path.dirname(__file__), "data")
 
 
+class TestEID(unittest.TestCase):
+    def test_eid_properties(self):
+        # test
properties + eid = EID(batch=1, task=2, run=3, step=4, suffix="abc123") + self.assertEqual(eid.uid, "1/2/3/4/abc123") + self.assertEqual(eid.sid, "1/2/4") + self.assertEqual(eid.rid, "1/2/3") + self.assertEqual(eid.tid, "1/2") + self.assertEqual(str(eid), "1/2/3/4/abc123") + self.assertIn("EID(batch=1, task=2, run=3, step=4, uuid=abc123)", repr(eid)) + + # test unique + eid1 = EID(batch=1, task=2, run=3, step=4) + eid2 = EID(batch=1, task=2, run=3, step=4) + self.assertNotEqual(eid1.suffix, eid2.suffix) + self.assertNotEqual(eid1.uid, eid2.uid) + + # test default + eid = EID() + eid2 = EID() + self.assertIsInstance(eid.suffix, str) + self.assertEqual(eid.batch, 0) + self.assertEqual(eid.task, 0) + self.assertEqual(eid.run, 0) + self.assertEqual(eid.step, 0) + self.assertNotEqual(eid.uid, eid2.uid) + + +class TestExperience(unittest.TestCase): + def test_single_turn_experience(self): + tokens = torch.tensor([10, 11, 12], dtype=torch.int32) + logprobs = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float32) + exp = Experience(tokens=tokens, logprobs=logprobs, reward=1.0, prompt_length=1) + self.assertEqual(exp.experience_type.name, "SINGLE_TURN") + self.assertTrue(torch.equal(exp.tokens, tokens)) + self.assertTrue(torch.equal(exp.logprobs, logprobs)) + self.assertEqual(exp.reward, 1.0) + self.assertEqual(exp.prompt_length, 1) + self.assertTrue(torch.equal(exp.action_mask, torch.tensor([0, 1, 1], dtype=torch.bool))) + + def test_multi_turn_experience(self): + tokens = torch.tensor([1, 2, 3, 4]) + logprobs = torch.tensor([0.1, 0.2, 0.3, 0.4]) + action_mask = torch.tensor([1, 0, 1, 0], dtype=torch.bool) + exp = Experience(tokens=tokens, logprobs=logprobs, reward=2.0, action_mask=action_mask) + self.assertEqual(exp.experience_type.name, "MULTI_TURN") + self.assertTrue(torch.equal(exp.action_mask, action_mask)) + self.assertEqual(exp.prompt_length, 1) + + def test_dpo_experience(self): + tokens = torch.tensor([1, 2]) + chosen = torch.tensor([3, 4]) + rejected = torch.tensor([5, 6]) + exp = Experience(tokens=tokens, chosen=chosen, rejected=rejected, reward=0.5) + self.assertEqual(exp.experience_type.name, "DPO") + self.assertTrue(torch.equal(exp.chosen, chosen)) + self.assertTrue(torch.equal(exp.rejected, rejected)) + self.assertEqual(exp.prompt_length, 2) + + def test_serialize_deserialize(self): + tokens = torch.tensor([1, 2, 3]) + exp = Experience(tokens=tokens, reward=1.23, prompt_length=1) + data = exp.serialize() + exp2 = Experience.deserialize(data) + self.assertTrue(torch.equal(exp.tokens, exp2.tokens)) + self.assertEqual(exp.reward, exp2.reward) + self.assertEqual(exp.prompt_length, exp2.prompt_length) + self.assertEqual(exp.experience_type, exp2.experience_type) + + def test_to_dict(self): + tokens = torch.tensor([1, 2, 3]) + exp = Experience( + tokens=tokens, reward=2.5, prompt_length=1, prompt_text="hi", response_text="yo" + ) + d = exp.to_dict() + self.assertIn("eid", d) + self.assertIn("type", d) + self.assertIn("reward", d) + self.assertEqual(d["prompt_text"], "hi") + self.assertEqual(d["response_text"], "yo") + self.assertEqual(d["reward"], 2.5) + + def test_gather(self): + # test empty gathering + batch = Experiences.gather_experiences([]) + self.assertEqual(batch.tokens.numel(), 0) + self.assertEqual(batch.rewards.numel(), 0) + self.assertEqual(batch.eids, []) + + # test single experience gathering + exp = Experience(tokens=torch.tensor([1, 2, 3]), reward=1.0, prompt_length=1) + batch = Experiences.gather_experiences([exp]) + self.assertEqual(batch.batch_size, 1) + self.assertTrue( + 
torch.equal(batch.tokens[0], torch.tensor([0, 1, 2, 3], dtype=torch.int64)[-3:]) + ) + self.assertEqual(batch.prompt_length, 1) + self.assertEqual(batch.rewards[0], 1.0) + + # test multiple experiences gathering + exps = [ + Experience(tokens=torch.tensor([1, 2]), reward=0.1, prompt_length=1), + Experience(tokens=torch.tensor([3, 4, 5]), reward=0.2, prompt_length=2), + ] + batch = Experiences.gather_experiences(exps) + self.assertEqual(batch.batch_size, 2) + self.assertEqual(batch.prompt_length, 2) + self.assertEqual(batch.tokens.shape[1], 3) + self.assertEqual(batch.rewards[0], 0.1) + self.assertEqual(batch.rewards[1], 0.2) + + def test_action_mask_and_logprobs_type(self): + exp = Experience(tokens=[1, 2, 3], logprobs=[0.1, 0.2, 0.3], prompt_length=1) + self.assertIsInstance(exp.tokens, torch.Tensor) + self.assertIsInstance(exp.logprobs, torch.Tensor) + self.assertIsInstance(exp.action_mask, torch.Tensor) + + def test_assertions(self): + # prompt_length must be > 0 + with self.assertRaises(AssertionError): + Experience(tokens=[1, 2, 3], prompt_length=0) + # tokens must be longer than prompt_length for single-turn + with self.assertRaises(AssertionError): + Experience(tokens=[1, 2], prompt_length=2) + # DPO: tokens must match prompt_length + exp = Experience(tokens=[1, 2], chosen=[3], rejected=[4], prompt_length=1) + exp.prompt_length = 2 # should automatically adjust + + class TestExperienceConversion(unittest.TestCase): """Test cases for ExperienceModel""" @@ -21,22 +150,20 @@ def test_experience_model_experience_conversion(self): reward = 0.6 prompt_length = 2 logprobs = torch.tensor([0, 0, 0.1], dtype=torch.float32) - action_mask = torch.tensor([1, 0, 1], dtype=torch.bool) experience = Experience( tokens=tokens, reward=reward, prompt_length=prompt_length, logprobs=logprobs, - action_mask=action_mask, ) model = ExperienceModel.from_experience(experience) - experience = model.to_experience() - self.assertTrue(torch.equal(experience.tokens, tokens)) - self.assertEqual(experience.prompt_length, prompt_length) - self.assertEqual(experience.reward, reward) - self.assertTrue(torch.equal(experience.logprobs, logprobs)) - self.assertTrue(torch.equal(experience.action_mask, action_mask)) + new_experience = model.to_experience() + self.assertTrue(torch.equal(new_experience.tokens, tokens)) + self.assertEqual(new_experience.prompt_length, prompt_length) + self.assertEqual(new_experience.reward, reward) + self.assertTrue(torch.equal(new_experience.logprobs, logprobs)) + self.assertTrue(torch.equal(new_experience.action_mask, experience.action_mask)) def test_batch_conversion(self): exps = [ @@ -45,33 +172,72 @@ def test_batch_conversion(self): prompt_length=1, reward=float(0.1), logprobs=torch.tensor([0, 0.1]), - action_mask=torch.tensor([1, 0]), ), Experience( tokens=torch.tensor([1, 2, 3]), prompt_length=2, reward=float(0.2), logprobs=torch.tensor([0, 0, 0.1]), - action_mask=torch.tensor([1, 0, 1]), ), + ] + batch = Experiences.gather_experiences(exps) + self.assertEqual(batch.batch_size, 2) + self.assertEqual(batch.prompt_length, 2) + prompt_length = batch.prompt_length + for i in range(batch.batch_size): + self.assertEqual(batch.rewards[i], exps[i].reward) + self.assertTrue( + torch.all( + batch.tokens[i][ + prompt_length + - exps[i].prompt_length : prompt_length + - exps[i].prompt_length + + exps[i].tokens.size(0) + ] + == exps[i].tokens + ) + ) + self.assertTrue( + torch.all( + batch.logprobs[i][ + prompt_length + - exps[i].prompt_length : prompt_length + + exps[i].tokens.size(0) + - 
exps[i].prompt_length + ] + == exps[i].logprobs + ) + ) + self.assertTrue( + torch.all( + batch.action_masks[i][ + prompt_length + - exps[i].prompt_length : prompt_length + - exps[i].prompt_length + + exps[i].action_mask.size(0) + ] + == exps[i].action_mask + ) + ) + + def test_multiturn_experience_batch_converstion(self): + exps = [ Experience( tokens=torch.tensor([1, 2, 3, 4]), - prompt_length=2, reward=float(0.3), logprobs=torch.tensor([0, 0, 0.1, 0.2]), action_mask=torch.tensor([1, 0, 1, 0]), ), Experience( tokens=torch.tensor([1, 2, 3, 4]), - prompt_length=3, reward=float(0.4), logprobs=torch.tensor([0, 0, 0, 0.1]), - action_mask=torch.tensor([1, 0, 1, 0]), + action_mask=torch.tensor([1, 0, 0, 1]), ), ] batch = Experiences.gather_experiences(exps) - self.assertEqual(batch.batch_size, 4) - self.assertEqual(batch.prompt_length, 3) + self.assertEqual(batch.batch_size, 2) + self.assertEqual(batch.prompt_length, 1) prompt_length = batch.prompt_length for i in range(batch.batch_size): self.assertEqual(batch.rewards[i], exps[i].reward) @@ -109,6 +275,37 @@ def test_batch_conversion(self): ) ) + def test_dpo_experience_batch_conversion(self): + exps = [ + Experience( + tokens=torch.tensor([1, 2]), + chosen=torch.tensor([3, 4]), + rejected=torch.tensor([5, 6]), + ), + Experience( + tokens=torch.tensor([7, 8, 9]), + chosen=torch.tensor([10, 11]), + rejected=torch.tensor([12, 13]), + ), + ] + batch = Experiences.gather_experiences(exps) + self.assertEqual(batch.batch_size, 4) + self.assertEqual(batch.prompt_length, 3) + prompt_length = batch.prompt_length + for i in range(batch.batch_size): + j = i // 2 + self.assertTrue( + torch.all( + batch.tokens[i][ + prompt_length + - exps[j].prompt_length : prompt_length + - exps[j].prompt_length + + exps[j].tokens.size(0) + ] + == exps[j].tokens + ) + ) + if __name__ == "__main__": unittest.main() diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py index cd697606cf..2c9c94f1bc 100644 --- a/tests/explorer/explorer_test.py +++ b/tests/explorer/explorer_test.py @@ -11,6 +11,7 @@ get_template_config, get_unittest_dataset_config, ) +from trinity.buffer import get_buffer_reader from trinity.cli.launcher import explore @@ -71,3 +72,56 @@ def test_explorer(self): eval_metrics = parser.metric_list("eval") self.assertTrue(len(eval_metrics) == 0) self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8) + + +class TestExplorerWithAddStrategy(BaseExplorerCase): + def test_explorer(self): + import ray + + from trinity.explorer.explorer import Explorer + + self.config.algorithm.repeat_times = 2 + self.config.buffer.total_epochs = 1 + self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown") + self.config.buffer.explorer_input.add_strategy = "random" + self.config.name = f"explore-add-strategy-{datetime.now().strftime('%Y%m%d%H%M%S')}" + # some step may be skipped due to same reward + self.config.algorithm.add_strategy = "reward_variance" + self.config.check_and_update() + explorer = ( + ray.remote(Explorer) + .options( + name=self.config.explorer.name, + namespace=ray.get_runtime_context().namespace, + ) + .remote(self.config) + ) + ray.get(explorer.prepare.remote()) + ray.get(explorer.sync_weight.remote()) + ray.get(explorer.explore.remote()) + parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard")) + rollout_metrics = parser.metric_list("rollout") + self.assertTrue(len(rollout_metrics) > 0) + eval_metrics = parser.metric_list("eval") + self.assertTrue(len(eval_metrics) == 0) + 
self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4) + self.assertTrue(parser.metric_exist("rollout/experience_count")) + experience_counts = parser.metric_values("rollout/experience_count") + self.assertTrue(len(experience_counts) == 4) + for count in experience_counts: + self.assertTrue(count >= 0) + self.assertTrue(count <= 2 * 4) # repeat_times * batch_size + self.assertTrue(count % 2 == 0) # should be multiple of repeat_times + + reader = get_buffer_reader( + self.config.buffer.trainer_input.experience_buffer, self.config.buffer + ) + exps = [] + try: + batch = reader.read() + exps.extend(batch) + except StopIteration: + pass + self.assertTrue(len(exps) <= 4 * 2 * 4) # step * repeat_times * batch_size + self.assertTrue(len(exps) % (2 * 4) == 0) # should be multiple of repeat_times + ray.get(explorer.shutdown.remote()) diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index d25c394bff..6a88e2a125 100644 --- a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -19,8 +19,8 @@ @WORKFLOWS.register_module("dummy_workflow") class DummyWorkflow(Workflow): - def __init__(self, model, task, auxiliary_models): - super().__init__(model, task, auxiliary_models) + def __init__(self, *, task, model, auxiliary_models): + super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) self.error_type = task.raw_task.get("error_type", "") self.seconds = None self.repeat_times = task.rollout_args.n @@ -176,13 +176,14 @@ async def test_get_results(self): tasks = generate_tasks(8) scheduler.schedule(tasks, batch_id=0) - results = await scheduler.get_results(batch_id=0, min_num=8, timeout=20) - self.assertEqual(len(results), 8) + statuses, exps = await scheduler.get_results(batch_id=0, min_num=8, timeout=20) + self.assertEqual(len(statuses), 8) + self.assertEqual(len(exps), 0) self.assertEqual(len(self.queue.read(batch_size=8)), 8) with self.assertRaises(TimeoutError): self.queue.read(batch_size=1) - for result in results: + for result in statuses: self.assertTrue(result.ok) for batch_id in range(1, 4): @@ -191,8 +192,9 @@ async def test_get_results(self): for batch_id in range(1, 4): self.assertTrue(scheduler.has_step(batch_id)) - results = await scheduler.get_results(batch_id=batch_id, min_num=4, timeout=10) - self.assertEqual(len(results), 4) + statuses, exps = await scheduler.get_results(batch_id=batch_id, min_num=4, timeout=10) + self.assertEqual(len(statuses), 4) + self.assertEqual(len(exps), 0) self.assertFalse(scheduler.has_step(batch_id)) self.assertEqual(len(self.queue.read(batch_size=4)), 4) with self.assertRaises(TimeoutError): @@ -201,8 +203,9 @@ async def test_get_results(self): tasks = generate_tasks(3) scheduler.schedule(tasks, batch_id=4) self.assertTrue(scheduler.has_step(4)) - results = await scheduler.get_results(batch_id=4) - self.assertEqual(len(results), 3) + statuses, exps = await scheduler.get_results(batch_id=4) + self.assertEqual(len(statuses), 3) + self.assertEqual(len(exps), 0) self.assertFalse(scheduler.has_step(4)) self.assertEqual(len(self.queue.read(batch_size=3)), 3) @@ -211,11 +214,11 @@ async def test_get_results(self): scheduler.schedule(tasks, batch_id=0) start_time = time.time() - results = await scheduler.get_results(batch_id=0, min_num=4, timeout=3) + statuses, exps = await scheduler.get_results(batch_id=0, min_num=4, timeout=3) end_time = time.time() self.assertLessEqual(end_time - start_time, 5) - self.assertEqual(len(results), 2) + self.assertEqual(len(statuses), 2) 
self.assertEqual(len(self.queue.read(batch_size=2)), 2) # test run tasks after timeout @@ -223,10 +226,10 @@ async def test_get_results(self): scheduler.schedule(tasks, batch_id=0) # actor restart is slow, set a big timeout - results = await scheduler.get_results(batch_id=0, timeout=20) - self.assertEqual(len(results), 4) + statuses, exps = await scheduler.get_results(batch_id=0, timeout=20) + self.assertEqual(len(statuses), 4) - success_count = sum(1 for r in results if r.ok) + success_count = sum(1 for r in statuses if r.ok) self.assertEqual(success_count, 4) self.assertEqual(len(self.queue.read(batch_size=4)), 4) with self.assertRaises(TimeoutError): @@ -235,10 +238,10 @@ async def test_get_results(self): # test exception tasks tasks = generate_tasks(1, exception_num=3) scheduler.schedule(tasks, batch_id=1) - results = await scheduler.get_results(batch_id=1, timeout=5) - self.assertEqual(len(results), 4) + statuses, exps = await scheduler.get_results(batch_id=1, timeout=5) + self.assertEqual(len(statuses), 4) - success_count = sum(1 for r in results if r.ok) + success_count = sum(1 for r in statuses if r.ok) self.assertEqual(success_count, 1) self.assertEqual(len(self.queue.read(batch_size=1)), 1) with self.assertRaises(TimeoutError): @@ -247,11 +250,16 @@ async def test_get_results(self): # test clear_timeout_tasks tasks = generate_tasks(3, timeout_num=1, timeout_seconds=3) scheduler.schedule(tasks, batch_id=2) - results = await scheduler.get_results(batch_id=2, timeout=2, clear_timeout_tasks=False) - self.assertEqual(len(results), 3) + statuses, exps = await scheduler.get_results( + batch_id=2, timeout=2, clear_timeout_tasks=False + ) + self.assertEqual(len(statuses), 3) self.assertEqual(len(self.queue.read(batch_size=3)), 3) - results = await scheduler.get_results(batch_id=2, timeout=2, clear_timeout_tasks=False) - self.assertEqual(len(results), 1) + statuses, exps = await scheduler.get_results( + batch_id=2, timeout=2, clear_timeout_tasks=False + ) + self.assertEqual(len(statuses), 1) + self.assertEqual(len(exps), 0) self.assertEqual(len(self.queue.read(batch_size=1)), 1) with self.assertRaises(TimeoutError): self.queue.read(batch_size=1) @@ -277,10 +285,10 @@ async def test_wait_all(self): self.assertEqual(len(scheduler.pending_tasks), 0) self.assertEqual(len(scheduler.running_tasks), 0) - results0 = await scheduler.get_results(batch_id=0, min_num=4, timeout=1) - results1 = await scheduler.get_results(batch_id=1, min_num=3, timeout=1) - self.assertEqual(len(results0), 4) - self.assertEqual(len(results1), 3) + status0, exps0 = await scheduler.get_results(batch_id=0, min_num=4, timeout=1) + status1, exps1 = await scheduler.get_results(batch_id=1, min_num=3, timeout=1) + self.assertEqual(len(status0), 4) + self.assertEqual(len(status1), 3) # test timeout tasks = generate_tasks(2, timeout_num=2, timeout_seconds=10) @@ -342,9 +350,9 @@ async def schedule_tasks(batch_id, num_tasks): schedule_tasks(2, 2), ) - self.assertEqual(len(results[0]), 3) - self.assertEqual(len(results[1]), 4) - self.assertEqual(len(results[2]), 2) + self.assertEqual(len(results[0][0]), 3) + self.assertEqual(len(results[1][0]), 4) + self.assertEqual(len(results[2][0]), 2) await scheduler.stop() @@ -354,47 +362,55 @@ async def test_scheduler_restart_after_stop(self): await scheduler.start() tasks = generate_tasks(2) scheduler.schedule(tasks, batch_id=0) - results = await scheduler.get_results(batch_id=0, min_num=2, timeout=10) + results, exps = await scheduler.get_results(batch_id=0, min_num=2, timeout=10) 
self.assertEqual(len(results), 2) + self.assertEqual(len(exps), 0) await scheduler.stop() + self.config.explorer.collect_experiences = True await scheduler.start() - tasks = generate_tasks(3) + tasks = generate_tasks(3, repeat_times=2) scheduler.schedule(tasks, batch_id=1) - results = await scheduler.get_results(batch_id=1, min_num=3, timeout=10) + results, exps = await scheduler.get_results(batch_id=1, min_num=3, timeout=10) self.assertEqual(len(results), 3) + self.assertEqual(len(exps), 3 * 2) await scheduler.stop() async def test_scheduler_all_methods(self): + self.config.explorer.collect_experiences = True scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) await scheduler.start() tasks = generate_tasks(8) scheduler.schedule(tasks, batch_id=0) self.assertTrue(scheduler.has_step(0)) - results = await scheduler.get_results(batch_id=0, min_num=8, timeout=20) - self.assertEqual(len(results), 8) + statuses, exps = await scheduler.get_results(batch_id=0, min_num=8, timeout=20) + self.assertEqual(len(statuses), 8) + self.assertEqual(len(exps), 8) scheduler.schedule(tasks, batch_id=1) scheduler.schedule(tasks[:4], batch_id=2) self.assertFalse(scheduler.has_step(0)) - results = await scheduler.get_results(batch_id=0, min_num=8) + statuses, exps = await scheduler.get_results(batch_id=0, min_num=8) self.assertFalse(scheduler.has_step(0)) - self.assertEqual(len(results), 0) # batch_id 0 has no more tasks + self.assertEqual(len(statuses), 0) # batch_id 0 has no more tasks + self.assertEqual(len(exps), 0) self.assertFalse(scheduler.has_step(0)) self.assertTrue(scheduler.has_step(1)) self.assertTrue(scheduler.has_step(2)) await scheduler.wait_all() st = time.time() - results = await scheduler.get_results(batch_id=1) + statuses, exps = await scheduler.get_results(batch_id=1) et = time.time() self.assertTrue(et - st < 1.0) - self.assertEqual(len(results), 8) + self.assertEqual(len(statuses), 8) + self.assertEqual(len(exps), 8) self.assertFalse(scheduler.has_step(1)) self.assertTrue(scheduler.has_step(2)) st = time.time() - results = await scheduler.get_results(batch_id=2) + statuses, exps = await scheduler.get_results(batch_id=2) et = time.time() self.assertTrue(et - st < 1.0) - self.assertEqual(len(results), 4) + self.assertEqual(len(statuses), 4) + self.assertEqual(len(exps), 4) self.assertFalse(scheduler.has_step(2)) await scheduler.stop() @@ -407,8 +423,8 @@ async def test_split_tasks(self): tasks = generate_tasks(4, repeat_times=8) # ceil(8 / 2) == 4 scheduler.schedule(tasks, batch_id=1) - results = await scheduler.get_results(batch_id=1) - self.assertEqual(len(results), 4 * 4) + statuses, exps = await scheduler.get_results(batch_id=1) + self.assertEqual(len(statuses), 4 * 4) exps = self.queue.read(batch_size=4 * 8) self.assertEqual(len(exps), 4 * 8) exp_list.extend(exps) @@ -417,8 +433,8 @@ async def test_split_tasks(self): tasks = generate_tasks(4, repeat_times=5) # ceil(5 / 2) == 3 scheduler.schedule(tasks, batch_id=2) - results = await scheduler.get_results(batch_id=2) - self.assertEqual(len(results), 4 * 3) + statuses, exps = await scheduler.get_results(batch_id=2) + self.assertEqual(len(statuses), 4 * 3) exps = self.queue.read(batch_size=4 * 5) self.assertEqual(len(exps), 4 * 5) exp_list.extend(exps) @@ -427,18 +443,18 @@ async def test_split_tasks(self): tasks = generate_tasks(3, repeat_times=1) # ceil(1 / 2) == 1 scheduler.schedule(tasks, batch_id=3) - results = await scheduler.get_results(batch_id=3) - self.assertEqual(len(results), 3 * 1) + statuses, exps = 
await scheduler.get_results(batch_id=3) + self.assertEqual(len(statuses), 3 * 1) exps = self.queue.read(batch_size=3 * 1) self.assertEqual(len(exps), 3 * 1) exp_list.extend(exps) with self.assertRaises(TimeoutError): self.queue.read(batch_size=1) - # test group_id and unique_id - group_ids = [exp.group_id for exp in exp_list] + # test task_id and unique_id + group_ids = [exp.eid.tid for exp in exp_list] self.assertEqual(len(set(group_ids)), 11) # 4 + 4 + 3 - unique_ids = [exp.unique_id for exp in exp_list] + unique_ids = [exp.eid.uid for exp in exp_list] self.assertEqual(len(unique_ids), len(set(unique_ids))) await scheduler.stop() diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py index 2132a591a0..cc85eaf8c7 100644 --- a/tests/explorer/workflow_test.py +++ b/tests/explorer/workflow_test.py @@ -31,7 +31,7 @@ class MockResponse: class DummyWorkflow(Workflow): def __init__(self, model, task: Task, auxiliary_models=None): - super().__init__(model, task, auxiliary_models) + super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) self.obj = task.raw_task self.output_format = task.workflow_args["output_format"] diff --git a/tests/utils/plugin_test.py b/tests/utils/plugin_test.py index 01aa2f3967..e35e6eab0d 100644 --- a/tests/utils/plugin_test.py +++ b/tests/utils/plugin_test.py @@ -11,7 +11,7 @@ class PluginActor: def run(self): my_plugin_cls = WORKFLOWS.get("my_workflow") - return my_plugin_cls(None, None).run() + return my_plugin_cls(task=None, model=None).run() class TestPluginLoader(unittest.TestCase): @@ -22,7 +22,7 @@ def test_load_plugins(self): load_plugins(Path(__file__).resolve().parent / "plugins") my_plugin_cls = WORKFLOWS.get("my_workflow") self.assertIsNotNone(my_plugin_cls) - my_plugin = my_plugin_cls(None, None, None) + my_plugin = my_plugin_cls(task=None, model=None, auxiliary_models=None) self.assertTrue(my_plugin.__module__.startswith("trinity.plugins")) res = my_plugin.run() self.assertEqual(res[0], "Hello world") diff --git a/tests/utils/plugins/my_workflow.py b/tests/utils/plugins/my_workflow.py index b999590a01..624969ee89 100644 --- a/tests/utils/plugins/my_workflow.py +++ b/tests/utils/plugins/my_workflow.py @@ -5,8 +5,8 @@ @WORKFLOWS.register_module("my_workflow") class MyWorkflow(Workflow): - def __init__(self, model, task, auxiliary_models=None): - super().__init__(model, task, auxiliary_models) + def __init__(self, *, task, model, auxiliary_models=None): + super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) def run(self) -> List: return ["Hello world", "Hi"] diff --git a/trinity/algorithm/__init__.py b/trinity/algorithm/__init__.py index 667aa10d74..b5b03d2075 100644 --- a/trinity/algorithm/__init__.py +++ b/trinity/algorithm/__init__.py @@ -1,3 +1,4 @@ +from trinity.algorithm.add_strategy import ADD_STRATEGY, AddStrategy from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn from trinity.algorithm.algorithm import ALGORITHM_TYPE, AlgorithmType from trinity.algorithm.entropy_loss_fn import ENTROPY_LOSS_FN, EntropyLossFn @@ -18,4 +19,6 @@ "ENTROPY_LOSS_FN", "SampleStrategy", "SAMPLE_STRATEGY", + "AddStrategy", + "ADD_STRATEGY", ] diff --git a/trinity/algorithm/add_strategy/__init__.py b/trinity/algorithm/add_strategy/__init__.py new file mode 100644 index 0000000000..d1bbc84e1c --- /dev/null +++ b/trinity/algorithm/add_strategy/__init__.py @@ -0,0 +1,11 @@ +from trinity.algorithm.add_strategy.add_strategy import ( + ADD_STRATEGY, + AddStrategy, + RewardVarianceAddStrategy, +) + 
+__all__ = [
+    "ADD_STRATEGY",
+    "AddStrategy",
+    "RewardVarianceAddStrategy",
+]
diff --git a/trinity/algorithm/add_strategy/add_strategy.py b/trinity/algorithm/add_strategy/add_strategy.py
new file mode 100644
index 0000000000..ca81d6927b
--- /dev/null
+++ b/trinity/algorithm/add_strategy/add_strategy.py
@@ -0,0 +1,85 @@
+from abc import ABC, abstractmethod
+from typing import Dict, List, Literal
+
+import numpy as np
+
+from trinity.buffer import BufferWriter
+from trinity.common.experience import Experience
+from trinity.utils.registry import Registry
+
+ADD_STRATEGY = Registry("add_strategy")
+
+
+class AddStrategy(ABC):
+    def __init__(self, writer: BufferWriter, **kwargs) -> None:
+        self.writer = writer
+
+    @abstractmethod
+    async def add(self, experiences: List[Experience], step: int) -> int:
+        """Add experiences to the buffer.
+
+        Args:
+            experiences (`Experience`): The experiences to be added.
+            step (`int`): The current step number.
+
+        Returns:
+            `int`: The number of experiences added to the buffer.
+        """
+
+    @classmethod
+    @abstractmethod
+    def default_args(cls) -> dict:
+        """Get the default arguments of the add strategy.
+
+        Returns:
+            `dict`: The default arguments.
+        """
+
+
+@ADD_STRATEGY.register_module("reward_variance")
+class RewardVarianceAddStrategy(AddStrategy):
+    """An example AddStrategy that filters experiences based on a reward variance threshold."""
+
+    def __init__(self, writer: BufferWriter, variance_threshold: float = 0.0, **kwargs) -> None:
+        super().__init__(writer)
+        self.variance_threshold = variance_threshold
+
+    async def add(self, experiences: List[Experience], step: int) -> int:
+        cnt = 0
+        grouped_experiences = group_by(experiences, id_type="task")
+        for _, group_exps in grouped_experiences.items():
+            if len(group_exps) < 2:
+                continue
+            # check if the rewards are the same
+            rewards = [exp.reward for exp in group_exps]
+            variance = np.var(rewards)
+            if variance <= self.variance_threshold:
+                continue
+            cnt += len(group_exps)
+            await self.writer.write_async(group_exps)
+        return cnt
+
+    @classmethod
+    def default_args(cls) -> dict:
+        return {"variance_threshold": 0.0}
+
+
+def group_by(
+    experiences: List[Experience], id_type: Literal["task", "run", "step"]
+) -> Dict[str, List[Experience]]:
+    """Group experiences by ID."""
+    if id_type == "task":
+        id_type = "tid"
+    elif id_type == "run":
+        id_type = "rid"
+    elif id_type == "step":
+        id_type = "sid"
+    else:
+        raise ValueError(f"Unknown id_type: {id_type}")
+    grouped = {}
+    for exp in experiences:
+        group_id = getattr(exp.eid, id_type)
+        if group_id not in grouped:
+            grouped[group_id] = []
+        grouped[group_id].append(exp)
+    return grouped
diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py
index 54f5c3d296..cf2aaa823c 100644
--- a/trinity/algorithm/algorithm.py
+++ b/trinity/algorithm/algorithm.py
@@ -145,7 +145,7 @@ class DPOAlgorithm(AlgorithmType):
     @classmethod
     def default_config(cls) -> Dict:
         return {
-            "sample_strategy": "dpo",
+            "sample_strategy": "warmup",
             "policy_loss_fn": "dpo",
             "kl_loss_fn": "k2",
             "entropy_loss_fn": "default",
@@ -153,7 +153,7 @@ def default_config(cls) -> Dict:
 
     @classmethod
     def check_config(cls, config: Config) -> None:
-        if config.model == "train":
+        if config.mode == "train":
             if (
                 config.buffer.trainer_input.experience_buffer is None
                 or not config.buffer.trainer_input.experience_buffer.path
diff --git a/trinity/algorithm/sample_strategy/mix_sample_strategy.py b/trinity/algorithm/sample_strategy/mix_sample_strategy.py
index 80a4af7d49..60f908afe2
100644 --- a/trinity/algorithm/sample_strategy/mix_sample_strategy.py +++ b/trinity/algorithm/sample_strategy/mix_sample_strategy.py @@ -91,8 +91,8 @@ def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor): cumsum = torch.cumsum(attention_mask, dim=-1) position_ids = torch.clip(cumsum - 1, 0, None).long() batch_dict = { - "uid": np.array(experiences.group_ids), - "unique_ids": np.array(experiences.unique_ids), + "uid": np.array([eid.tid for eid in experiences.eids]), + "unique_ids": np.array([eid.uid for eid in experiences.eids]), "position_ids": position_ids, "input_ids": experiences.tokens.long(), "responses": experiences.tokens[:, experiences.prompt_length :].long(), diff --git a/trinity/algorithm/sample_strategy/sample_strategy.py b/trinity/algorithm/sample_strategy/sample_strategy.py index 6e530d32ce..b923ab17a6 100644 --- a/trinity/algorithm/sample_strategy/sample_strategy.py +++ b/trinity/algorithm/sample_strategy/sample_strategy.py @@ -120,23 +120,3 @@ def warmup_state(self, step: int) -> Tuple[bool, bool]: @classmethod def default_args(cls) -> dict: return {} - - -@SAMPLE_STRATEGY.register_module("dpo") -class DPOSampleStrategy(WarmupSampleStrategy): - def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]: - metrics = {} - with Timer(metrics, "read_time"): - if step <= self.sft_warmup_steps: - exp_list = self.sft_buffer.read() - else: - exp_list = self.exp_buffer.read() - repr_samples = representative_sample(exp_list) - with Timer(metrics, "gather_time"): - exps = Experiences.gather_dpo_experiences(exp_list, pad_token_id=self.pad_token_id) # type: ignore - if self.trainer_type == "verl": - with Timer(metrics, "convert_time"): - data = to_data_proto(exps) - return data, metrics, repr_samples - else: - raise NotImplementedError(f"backend {self.trainer_type} is not supported") diff --git a/trinity/algorithm/sample_strategy/utils.py b/trinity/algorithm/sample_strategy/utils.py index f9df00ee4e..cba97e6d9e 100644 --- a/trinity/algorithm/sample_strategy/utils.py +++ b/trinity/algorithm/sample_strategy/utils.py @@ -13,8 +13,8 @@ def to_data_proto(experiences: Experiences) -> DataProto: cumsum = torch.cumsum(attention_mask, dim=-1) position_ids = torch.clip(cumsum - 1, 0, None).long() batch_dict = { - "uid": np.array(experiences.group_ids), - "unique_ids": np.array(experiences.unique_ids), + "uid": np.array([eid.tid for eid in experiences.eids]), + "unique_ids": np.array([eid.uid for eid in experiences.eids]), "position_ids": position_ids, "input_ids": experiences.tokens.long(), "responses": experiences.tokens[:, experiences.prompt_length :].long(), diff --git a/trinity/buffer/__init__.py b/trinity/buffer/__init__.py index 7e11b73a44..e7cc4c1b9b 100644 --- a/trinity/buffer/__init__.py +++ b/trinity/buffer/__init__.py @@ -1,7 +1,15 @@ -from trinity.buffer.buffer import Buffer, get_buffer_reader, get_buffer_writer +from trinity.buffer.buffer import ( + Buffer, + BufferReader, + BufferWriter, + get_buffer_reader, + get_buffer_writer, +) __all__ = [ "Buffer", + "BufferReader", + "BufferWriter", "get_buffer_reader", "get_buffer_writer", ] diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py index 1f5e29a2b9..de3481224c 100644 --- a/trinity/buffer/reader/file_reader.py +++ b/trinity/buffer/reader/file_reader.py @@ -132,12 +132,12 @@ def read( tokens = self.tokenizer.apply_chat_template( messages, add_generation_prompt=False, return_tensors="pt" )[0] - prompt_tokens = self.tokenizer.apply_chat_template( + prompt_tokens_ids = 
self.tokenizer.apply_chat_template( messages[:-1], add_generation_prompt=True, return_tensors="pt" )[0] experience = Experience( tokens=tokens, - prompt_length=len(prompt_tokens), + prompt_length=len(prompt_tokens_ids), ) exp_list.append(experience) @@ -155,13 +155,13 @@ def read( full_messages, add_generation_prompt=False, return_tensors="pt" )[0] - prompt_tokens = self.tokenizer.apply_chat_template( + prompt_tokens_ids = self.tokenizer.apply_chat_template( prompt_messages, add_generation_prompt=True, return_tensors="pt" )[0] experience = Experience( tokens=tokens, - prompt_length=len(prompt_tokens), + prompt_length=len(prompt_tokens_ids), ) exp_list.append(experience) @@ -171,10 +171,10 @@ def read( prompt = sample[self.prompt_key] response = sample[self.response_key] tokens = self.tokenizer(prompt + response, return_tensors="pt")["input_ids"][0] - prompt_tokens = self.tokenizer(prompt, return_tensors="pt")["input_ids"][0] + prompt_tokens_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"][0] experience = Experience( tokens=tokens, - prompt_length=len(prompt_tokens), + prompt_length=len(prompt_tokens_ids), ) exp_list.append(experience) else: @@ -252,7 +252,6 @@ def read( )[0][prompt_length:] experience = Experience( tokens=prompt_tokens, - prompt_length=len(prompt_tokens), chosen=chosen_tokens, rejected=rejected_tokens, ) diff --git a/trinity/buffer/schema/sql_schema.py b/trinity/buffer/schema/sql_schema.py index 21289c7768..5ac3da2666 100644 --- a/trinity/buffer/schema/sql_schema.py +++ b/trinity/buffer/schema/sql_schema.py @@ -85,14 +85,13 @@ def from_messages( """Convert a list of messages into a single instance of SFT data.""" from trinity.common.models.utils import tokenize_and_mask_messages_hf - token_ids, action_mask = tokenize_and_mask_messages_hf( + tokens, action_mask = tokenize_and_mask_messages_hf( tokenizer=tokenizer, messages=messages, chat_template=chat_template, ) exp = Experience( - tokens=token_ids, - prompt_length=0, + tokens=tokens, action_mask=action_mask, info={"response_num": sum([1 if m["role"] == "assistant" else 0 for m in messages])}, ) diff --git a/trinity/common/config.py b/trinity/common/config.py index 5a9494504d..13b37ac26d 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -226,6 +226,11 @@ class AlgorithmConfig: # for GRPO-like algorithms, repeat each task for `repeat_times` times repeat_times: int = 1 + # the strategy for adding experiences to the buffer + add_strategy: Optional[str] = None + add_strategy_args: Optional[dict] = None + + # the strategy for sampling experiences from the buffer sample_strategy: Optional[str] = None sample_strategy_args: Optional[dict] = None @@ -341,6 +346,12 @@ class ExplorerConfig: # for benchmark bench_on_latest_checkpoint: bool = False # only benchmark the latest checkpoint + # ! 
DO NOT SET + # Explorer collects experiences from workflow runners + # some algorithms (e.g., DAPO) need to collect experiences generated by the same task and do some post-processing + # will automatically set to True if `algorithm.add_strategy` is not None + collect_experiences: bool = False + @dataclass class TrainerConfig: @@ -638,6 +649,7 @@ def _check_buffer(self) -> None: # noqa: C901 def _check_algorithm(self) -> None: from trinity.algorithm import ( + ADD_STRATEGY, ADVANTAGE_FN, ENTROPY_LOSS_FN, KL_FN, @@ -661,42 +673,23 @@ def _check_algorithm(self) -> None: if getattr(self.algorithm, key, None) is None: setattr(self.algorithm, key, value) - # TODO: simplify the following code - sample_strategy_cls = SAMPLE_STRATEGY.get(self.algorithm.sample_strategy) - if sample_strategy_cls is None: - raise ValueError(f"Invalid sample_strategy: {self.algorithm.sample_strategy}") - if self.algorithm.sample_strategy_args is None: - self.algorithm.sample_strategy_args = sample_strategy_cls.default_args() - - policy_fn_cls = POLICY_LOSS_FN.get(self.algorithm.policy_loss_fn) - if policy_fn_cls is None: - raise ValueError(f"Invalid policy_loss_fn: {self.algorithm.policy_loss_fn}") - if self.algorithm.policy_loss_fn_args is None: - self.algorithm.policy_loss_fn_args = policy_fn_cls.default_args() - - advantage_fn_cls = ADVANTAGE_FN.get(self.algorithm.advantage_fn) - if advantage_fn_cls is None: - raise ValueError(f"Invalid advantage_fn: {self.algorithm.advantage_fn}") - if self.algorithm.advantage_fn_args is None: - self.algorithm.advantage_fn_args = advantage_fn_cls.default_args() - - kl_loss_fn_cls = KL_FN.get(self.algorithm.kl_loss_fn) - if kl_loss_fn_cls is None: - raise ValueError(f"Invalid kl_loss_fn: {self.algorithm.kl_loss_fn}") - if self.algorithm.kl_loss_fn_args is None: - self.algorithm.kl_loss_fn_args = kl_loss_fn_cls.default_args() - - kl_penalty_fn_cls = KL_FN.get(self.algorithm.kl_penalty_fn) - if kl_penalty_fn_cls is None: - raise ValueError(f"Invalid kl_penalty_fn: {self.algorithm.kl_penalty_fn}") - if self.algorithm.kl_penalty_fn_args is None: - self.algorithm.kl_penalty_fn_args = kl_penalty_fn_cls.default_args() - - entropy_loss_fn_cls = ENTROPY_LOSS_FN.get(self.algorithm.entropy_loss_fn) - if entropy_loss_fn_cls is None: - raise ValueError(f"Invalid entropy_loss_fn: {self.algorithm.entropy_loss_fn}") - if self.algorithm.entropy_loss_fn_args is None: - self.algorithm.entropy_loss_fn_args = entropy_loss_fn_cls.default_args() + def check_and_set(name, registry, args_attr): + fn_cls = registry.get(getattr(self.algorithm, name)) + if fn_cls is None: + raise ValueError(f"Invalid {name}: {getattr(self.algorithm, name)}") + if getattr(self.algorithm, args_attr) is None: + setattr(self.algorithm, args_attr, fn_cls.default_args()) + return fn_cls + + if self.algorithm.add_strategy is not None: + check_and_set("add_strategy", ADD_STRATEGY, "add_strategy_args") + self.explorer.collect_experiences = True + check_and_set("sample_strategy", SAMPLE_STRATEGY, "sample_strategy_args") + check_and_set("policy_loss_fn", POLICY_LOSS_FN, "policy_loss_fn_args") + check_and_set("advantage_fn", ADVANTAGE_FN, "advantage_fn_args") + check_and_set("kl_loss_fn", KL_FN, "kl_loss_fn_args") + check_and_set("kl_penalty_fn", KL_FN, "kl_penalty_fn_args") + check_and_set("entropy_loss_fn", ENTROPY_LOSS_FN, "entropy_loss_fn_args") def check_and_update(self) -> None: # noqa: C901 """Check and update the config.""" diff --git a/trinity/common/experience.py b/trinity/common/experience.py index 56bdd1c4c5..0c5f98e89c 100644 
--- a/trinity/common/experience.py +++ b/trinity/common/experience.py @@ -3,8 +3,9 @@ from __future__ import annotations import pickle -from dataclasses import dataclass -from itertools import chain, repeat +import uuid +from dataclasses import dataclass, field +from enum import Enum from typing import List, Optional import torch @@ -12,59 +13,195 @@ @dataclass -class Experience: - """A single experience.""" +class EID(dict): + """Experience ID class to uniquely identify an experience. - tokens: Tensor # [seq] - prompt_length: int - logprobs: Optional[Tensor] = None # [seq] + To enable the full functionality of the experience grouping, user should manually set the `run` and `step` fields in custom workflows. + """ + + # TODO: do we need to add project/name here to make it unique across different projects? + # Batch number, e.g., the explorer step num + # Automatically set by the workflow runner + batch: int = 0 + # Task number, e.g., the task sequence in the batch, the first task in the batch has task=0 + # Automatically set by the workflow runner + task: int = 0 # Task sequence in the batch, e.g., the first task in the batch has task=0 + # Run id, e.g., the first run in the task has run=0 + # User should set this field in custom workflows when creating experiences + run: int = 0 + # Step number when running the task, e.g., the first step in the task has step=0 + # User should set this field in custom workflows when creating experiences + step: int = 0 + suffix: str = field( + default_factory=lambda: uuid.uuid4().hex[:6] + ) # Unique identifier suffix, e.g., a UUID + + @property + def uid(self) -> str: + """An unique identifier for the experience.""" + return f"{self.batch}/{self.task}/{self.run}/{self.step}/{self.suffix}" + + @property + def sid(self) -> str: + """Step ID of the experience. + + For example, experiences generated by all runs of a same task at the same step will have the same sid. + """ + return f"{self.batch}/{self.task}/{self.step}" + + @property + def rid(self) -> str: + """Run ID of the experience. + + For example, experiences generated by one run of a task at all steps will have the same run_id. + """ + return f"{self.batch}/{self.task}/{self.run}" + + @property + def tid(self) -> str: + """Task ID for the experience. + + For example, experiences generated by a all run of a same task in GRPO-like algorithms will have the same tid. 
+ """ + return f"{self.batch}/{self.task}" + + def __str__(self): + return self.uid + + def __repr__(self): + return f"EID(batch={self.batch}, task={self.task}, run={self.run}, step={self.step}, uuid={self.suffix})" + + +class ExperienceType(Enum): + """Enum for experience types.""" + + SINGLE_TURN = "single_turn" # Single-turn experience, e.g., a prompt-response pair + MULTI_TURN = "multi_turn" # Multi-turn experience, e.g., a conversation history + DPO = "dpo" # DPO experience, e.g., a chosen and rejected response pair + + +@dataclass +class Experience: + eid: EID = field(default_factory=EID) # Unique identifier for the experience + tokens: Optional[Tensor] = None # [seq_length] + logprobs: Optional[Tensor] = None # [seq_length] reward: Optional[float] = None - prompt_text: Optional[str] = None - response_text: Optional[str] = None - action_mask: Optional[Tensor] = None - chosen: Optional[Tensor] = None # for dpo - rejected: Optional[Tensor] = None # for dpo - info: Optional[dict] = None - metrics: Optional[dict[str, float]] = None - group_id: str = "" # for grpo - unique_id: str = "" - - def __post_init__(self): - if self.action_mask is not None: + # Type of the experience, automatically set based on the presence of action_mask or chosen/rejected + experience_type: ExperienceType = ExperienceType.SINGLE_TURN + info: Optional[dict] = field( + default_factory=dict + ) # Additional information about the experience + metrics: Optional[dict[str, float]] = field( + default_factory=dict + ) # Metrics associated with the experience + + # for single-turn experiences + prompt_length: int = 1 # Length of the prompt in tokens, used for generating attention masks + response_text: Optional[str] = None # Text of the response + prompt_text: Optional[str] = None # Text of the prompt + + # for multi-turn experiences + action_mask: Optional[ + Tensor + ] = None # Action mask which indicates which tokens are generated by the model + messages: Optional[List[dict]] = None # List of messages + + # for dpo experiences + chosen: Optional[Tensor] = None # Token ids of the chosen response [resp_length] + rejected: Optional[Tensor] = None # Token ids of the rejected response [resp_length] + chosen_text: Optional[str] = None # Text of the chosen response + rejected_text: Optional[str] = None # Text of the rejected response + + def __init__( + self, + *, + eid=None, + tokens, + logprobs=None, + reward=None, + info=None, + metrics=None, + prompt_length=1, + response_text=None, + prompt_text=None, + action_mask=None, + messages=None, + chosen=None, + rejected=None, + chosen_text=None, + rejected_text=None, + ): + if action_mask is not None: + experience_type = ExperienceType.MULTI_TURN + elif chosen is not None and rejected is not None: + experience_type = ExperienceType.DPO + else: + experience_type = ExperienceType.SINGLE_TURN + + if experience_type == ExperienceType.SINGLE_TURN: + assert ( + prompt_length > 0 + ), "Prompt length must be greater than 0 for single-turn experiences." assert ( - self.action_mask.shape == self.tokens.shape - ), "The provided action_mask must have the same shape as tokens." + len(tokens) > prompt_length + ), f"Token ids must be longer than the prompt length. Got len(tokens)={len(tokens)}, prompt_length={prompt_length}." 
+ action_mask = torch.zeros(len(tokens), dtype=torch.bool) + action_mask[prompt_length:] = 1 + elif experience_type == ExperienceType.MULTI_TURN: + prompt_length = 1 + elif experience_type == ExperienceType.DPO: + prompt_length = len(tokens) + + self.eid = eid or EID() + self.tokens = tokens + self.logprobs = logprobs + self.reward = reward + self.experience_type = experience_type + self.info = info or {} + self.metrics = metrics or {} + self.prompt_length = prompt_length + self.response_text = response_text + self.prompt_text = prompt_text + self.action_mask = action_mask + self.messages = messages + self.chosen = chosen + self.rejected = rejected + self.chosen_text = chosen_text + self.rejected_text = rejected_text - # explicit type cast if not isinstance(self.tokens, Tensor): - self.tokens = Tensor(self.tokens) + self.tokens = torch.tensor(self.tokens) if self.logprobs is not None and not isinstance(self.logprobs, Tensor): - self.logprobs = Tensor(self.logprobs) + self.logprobs = torch.tensor(self.logprobs) if self.action_mask is not None and not isinstance(self.action_mask, Tensor): - self.action_mask = Tensor(self.action_mask) + self.action_mask = torch.tensor(self.action_mask) if self.chosen is not None and not isinstance(self.chosen, Tensor): - self.chosen = Tensor(self.chosen) + self.chosen = torch.tensor(self.chosen) if self.rejected is not None and not isinstance(self.rejected, Tensor): - self.rejected = Tensor(self.rejected) + self.rejected = torch.tensor(self.rejected) def serialize(self) -> bytes: """Serialize the experience to bytes.""" return pickle.dumps(self) - @staticmethod - def deserialize(data: bytes) -> Experience: - """Deserialize the experience from bytes.""" + @classmethod + def deserialize(cls, data: bytes) -> Experience: return pickle.loads(data) def to_dict(self) -> dict: """Convert the experience to a dictionary.""" res = { - "prompt_text": self.prompt_text, + "eid": self.eid, + "type": self.experience_type.value, "info": self.info, "metrics": self.metrics, } + if self.prompt_text is not None: + res["prompt_text"] = self.prompt_text if self.response_text is not None: res["response_text"] = self.response_text + if self.messages is not None: + res["messages"] = self.messages if self.chosen is not None: res["chosen"] = self.chosen.tolist() if self.rejected is not None: @@ -73,6 +210,91 @@ def to_dict(self) -> dict: res["reward"] = float(self.reward) return res + @classmethod + def gather(cls, experiences: List[Experience], pad_token_id: int = 0) -> Experiences: + if len(experiences) == 0: + return empty_experiences() + exp_type = experiences[0].experience_type + if exp_type == ExperienceType.DPO: + experiences = split_dpo_experience_to_single_turn(experiences) + max_prompt_length = max([exp.prompt_length for exp in experiences]) # type: ignore [type-var] + max_response_length = max([len(exp.tokens) - exp.prompt_length for exp in experiences]) # type: ignore [arg-type] + eids = [exp.eid for exp in experiences] + + # Gather tokens + tokens = gather_token_ids(experiences, max_prompt_length, max_response_length, pad_token_id) + + # Gather rewards + if experiences[0].reward is not None: + rewards = torch.tensor([exp.reward for exp in experiences], dtype=torch.float) + else: + rewards = None + + # gather action_masks + action_masks = gather_action_masks(experiences, max_prompt_length, max_response_length) + + # gather attention_masks + attention_masks = gather_attention_masks( + experiences, max_prompt_length, max_response_length + ) + + # gather logprobs + + if 
all(exp.logprobs is not None for exp in experiences): + logprobs = gather_logprobs(experiences, max_prompt_length, max_response_length) + else: + logprobs = None + + return Experiences( + eids=eids, + tokens=tokens, + rewards=rewards, + attention_masks=attention_masks, + action_masks=action_masks, + prompt_length=max_prompt_length, + logprobs=logprobs, + ) + + +def split_dpo_experience_to_single_turn(experiences: List[Experience]) -> List[Experience]: + single_turn_experiences = [] + for exp in experiences: + single_turn_experiences.append( + Experience( + eid=EID( + batch=exp.eid.batch, + task=exp.eid.task, + step=exp.eid.step, + run=exp.eid.run, + ), + tokens=torch.cat([exp.tokens, exp.chosen]), + reward=exp.reward, + info=exp.info, + metrics=exp.metrics, + prompt_length=len(exp.tokens), # type: ignore [arg-type] + prompt_text=exp.prompt_text, + response_text=exp.chosen_text, + ) + ) + single_turn_experiences.append( + Experience( + eid=EID( + batch=exp.eid.batch, + task=exp.eid.task, + step=exp.eid.step, + run=exp.eid.run, + ), + tokens=torch.cat([exp.tokens, exp.rejected]), + reward=exp.reward, + info=exp.info, + metrics=exp.metrics, + prompt_length=len(exp.tokens), # type: ignore [arg-type] + prompt_text=exp.prompt_text, + response_text=exp.rejected_text, + ) + ) + return single_turn_experiences + @dataclass(frozen=True) class Experiences: @@ -90,14 +312,13 @@ class Experiences: >>> exp2: |......1111111111111|1111111........| """ + eids: List[EID] # Experience IDs of each experience in the batch tokens: Tensor rewards: Tensor attention_masks: Tensor action_masks: Optional[Tensor] prompt_length: int logprobs: Optional[Tensor] - group_ids: List[str] - unique_ids: List[str] @property def batch_size(self) -> int: @@ -113,207 +334,103 @@ def gather_experiences( This method will automatically pad the `tokens` and `logprobs` of input experiences to the same length. 
""" if len(experiences) == 0: - return Experiences( - tokens=torch.empty(0, dtype=torch.int32), - rewards=torch.empty(0, dtype=torch.float32), - attention_masks=torch.empty(0, dtype=torch.bool), - action_masks=torch.empty(0, dtype=torch.bool), - logprobs=torch.empty(0, dtype=torch.float32), - prompt_length=torch.empty(0, dtype=torch.int32), - group_ids=[], - unique_ids=[], - ) - max_prompt_length = max([exp.prompt_length for exp in experiences]) - max_response_length = max([len(exp.tokens) - exp.prompt_length for exp in experiences]) - group_ids = [exp.group_id for exp in experiences] - unique_ids = [exp.unique_id for exp in experiences] - tokens_dtype = experiences[0].tokens.dtype - tokens = torch.stack( - [ - torch.cat( - [ - torch.full( - (max_prompt_length - exp.prompt_length,), - pad_token_id, - dtype=tokens_dtype, - ), - exp.tokens, - torch.full( - (max_response_length + exp.prompt_length - len(exp.tokens),), - pad_token_id, - dtype=tokens_dtype, - ), - ] - ) - for exp in experiences - ] - ) - if experiences[0].reward is not None: - rewards = torch.tensor([exp.reward for exp in experiences], dtype=torch.float) - else: - rewards = None - - # Calculate the action_masks according to the provided experience.action_mask - if experiences[0].action_mask is not None: - action_mask_dtype = experiences[0].action_mask.dtype - action_masks = torch.stack( + return empty_experiences() + return experiences[0].__class__.gather(experiences, pad_token_id=pad_token_id) + + +def empty_experiences() -> Experiences: + return Experiences( + tokens=torch.empty(0, dtype=torch.int32), + rewards=torch.empty(0, dtype=torch.float32), + attention_masks=torch.empty(0, dtype=torch.bool), + action_masks=torch.empty(0, dtype=torch.bool), + logprobs=torch.empty(0, dtype=torch.float32), + prompt_length=torch.empty(0, dtype=torch.int32), + eids=[], + ) + + +def gather_token_ids( + experiences, max_prompt_length: int, max_response_length: int, pad_token_id: int +) -> Tensor: + token_ids_dtype = experiences[0].tokens.dtype + return torch.stack( + [ + torch.cat( [ - torch.cat( - [ - torch.full( - (max_prompt_length - exp.prompt_length,), - 0, - dtype=action_mask_dtype, - ), - exp.action_mask, - torch.full( - (max_response_length + exp.prompt_length - len(exp.tokens),), - 0, - dtype=action_mask_dtype, - ), - ] - ) - for exp in experiences + torch.full( + (max_prompt_length - exp.prompt_length,), + pad_token_id, + dtype=token_ids_dtype, + ), + exp.tokens, + torch.full( + (max_response_length + exp.prompt_length - len(exp.tokens),), + pad_token_id, + dtype=token_ids_dtype, + ), ] ) - else: - action_masks = None - attention_masks = torch.zeros( - (len(experiences), max_prompt_length + max_response_length), dtype=torch.bool - ) - for i, exp in enumerate(experiences): - start = max_prompt_length - exp.prompt_length - end = start + len(exp.tokens) - attention_masks[i, start:end] = 1 + for exp in experiences + ] + ) - if all(exp.logprobs is not None for exp in experiences): - logprob_dtype = experiences[0].logprobs.dtype # type: ignore [union-attr] - logprobs = torch.stack( + +def gather_action_masks(experiences, max_prompt_length: int, max_response_length: int) -> Tensor: + return torch.stack( + [ + torch.cat( [ - torch.cat( - [ - torch.full( - (max_prompt_length - exp.prompt_length,), - 0.0, - dtype=logprob_dtype, - ), - exp.logprobs, - torch.full( - (max_response_length + exp.prompt_length - len(exp.tokens),), - 0.0, - dtype=logprob_dtype, - ), - ] - ) - for exp in experiences + torch.full( + (max_prompt_length - 
exp.prompt_length,), + 0, + dtype=torch.bool, + ), + exp.action_mask, + torch.full( + (max_response_length + exp.prompt_length - len(exp.tokens),), + 0, + dtype=torch.bool, + ), ] ) - else: - logprobs = None + for exp in experiences + ] + ) - return cls( - group_ids=group_ids, - unique_ids=unique_ids, - tokens=tokens, - rewards=rewards, - attention_masks=attention_masks, - action_masks=action_masks, - prompt_length=max_prompt_length, - logprobs=logprobs, - ) - @classmethod - def gather_dpo_experiences( - cls, experiences: list[Experience], pad_token_id: int = 0 - ) -> Experiences: - """Gather a batch of dpo experiences from a list of experiences. +def gather_attention_masks(experiences, max_prompt_length: int, max_response_length: int) -> Tensor: + attention_masks = torch.zeros( + (len(experiences), max_prompt_length + max_response_length), dtype=torch.bool + ) - Reference: https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py#L849 + for i, exp in enumerate(experiences): + start = max_prompt_length - exp.prompt_length + end = start + len(exp.tokens) + attention_masks[i, start:end] = 1 - Note: We arrange inputs in the order of (chosen, rejected, chosen, rejected, ...) - to ensure that each pair of (chosen, rejected) is not split by subsequent operations + return attention_masks - Args: - Experiences: `(list[Experience])` - - `"prompt"`: token ids of the prompt - - `"chosen"`: token ids of the chosen response - - `"rejected"`: token ids of the rejected response - pad_token_id: `(int)` - The pad token id. - Returns: - Experiences: - - `"tokens"`: Concatenated chosen and rejected completion input IDs of shape `(2 * batch_size, max_completion_length)`. - - `"attention_masks"`: Concatenated chosen and rejected attention masks of shape `(2 * batch_size, max_completion_length)`. 
- """ - if len(experiences) == 0: - return Experiences( - tokens=torch.empty(0, dtype=torch.int32), - rewards=torch.empty(0, dtype=torch.float32), - attention_masks=torch.empty(0, dtype=torch.bool), - action_masks=torch.empty(0, dtype=torch.bool), - logprobs=torch.empty(0, dtype=torch.float32), - prompt_length=torch.empty(0, dtype=torch.int32), - group_ids=[], - unique_ids=[], - ) - - # TODO: exp.tokens in DPO are prompt tokens - prompt_tokens = list(chain.from_iterable([repeat(exp.tokens, 2) for exp in experiences])) - max_prompt_length = max([exp.prompt_length for exp in experiences]) - - chosen_tokens = [exp.chosen for exp in experiences] - rejected_tokens = [exp.rejected for exp in experiences] - response_tokens = list(chain.from_iterable(zip(chosen_tokens, rejected_tokens))) - max_response_length = max([len(response) for response in response_tokens]) # type: ignore - - group_ids = list(chain.from_iterable([repeat(exp.group_id, 2) for exp in experiences])) - unique_ids = list( - chain.from_iterable( - [(f"{exp.unique_id}/1", f"{exp.unique_id}/0") for exp in experiences] +def gather_logprobs(experiences, max_prompt_length: int, max_response_length: int) -> Tensor: + logprob_dtype = experiences[0].logprobs.dtype # type: ignore [union-attr] + return torch.stack( + [ + torch.cat( + [ + torch.full( + (max_prompt_length - exp.prompt_length,), + 0.0, + dtype=logprob_dtype, + ), + exp.logprobs, + torch.full( + (max_response_length + exp.prompt_length - len(exp.tokens),), + 0.0, + dtype=logprob_dtype, + ), + ] ) - ) - tokens_dtype = experiences[0].tokens.dtype - tokens = torch.stack( - [ - torch.cat( - [ - torch.full( - (max_prompt_length - len(prompt),), - pad_token_id, - dtype=tokens_dtype, - ), - prompt, - response, - torch.full( - (max_response_length - len(response),), # type: ignore - pad_token_id, - dtype=tokens_dtype, - ), - ] - ) - for prompt, response in zip(prompt_tokens, response_tokens) - ] - ) - - attention_masks = torch.zeros( - (len(tokens), max_prompt_length + max_response_length), dtype=torch.bool - ) - - for (i, prompt), response in zip(enumerate(prompt_tokens), response_tokens): - start = max_prompt_length - len(prompt) - end = max_prompt_length + len(response) # type: ignore - attention_masks[i, start:end] = 1 - - assert len(tokens) == 2 * len(experiences) - - return cls( - group_ids=group_ids, - unique_ids=unique_ids, - tokens=tokens, - attention_masks=attention_masks, - prompt_length=max_prompt_length, - rewards=None, - action_masks=None, - logprobs=None, - ) + for exp in experiences + ] + ) diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index 2568c88005..96d700678e 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -4,7 +4,7 @@ import socket import time from abc import ABC, abstractmethod -from typing import Any, List, Tuple, Union +from typing import Any, List, Sequence, Tuple, Union import openai import ray @@ -18,11 +18,11 @@ class InferenceModel(ABC): """A model for high performance for rollout inference.""" - async def generate(self, prompt: str, **kwargs) -> List[Experience]: + async def generate(self, prompt: str, **kwargs) -> Sequence[Experience]: """Generate a responses from a prompt in async.""" raise NotImplementedError - async def chat(self, messages: List[dict], **kwargs) -> List[Experience]: + async def chat(self, messages: List[dict], **kwargs) -> Sequence[Experience]: """Generate experiences from a list of history chat messages in async.""" raise NotImplementedError @@ -207,7 +207,6 @@ def 
convert_api_output_to_experience( ) ), prompt_length=len(output.prompt_token_ids), - prompt_text=None, response_text=choice.message.content, ) for choice in output.choices diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index 01b8135511..30c3e00b3e 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -3,7 +3,7 @@ import os import re -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import aiohttp import ray @@ -29,7 +29,6 @@ class vLLMRolloutModel(InferenceModel): Args: config (Config): The config. - kwargs (dict): The keyword arguments for the engine. """ def __init__( @@ -103,7 +102,7 @@ def __init__( self.api_server_host = None self.api_server_port = None - async def chat(self, messages: List[Dict], **kwargs) -> List[Experience]: + async def chat(self, messages: List[Dict], **kwargs) -> Sequence[Experience]: """Chat with the model with a list of messages in async. Args: @@ -134,7 +133,7 @@ async def chat(self, messages: List[Dict], **kwargs) -> List[Experience]: ) return await self.generate(prompt=prompt, **kwargs) - async def generate(self, prompt: str, **kwargs) -> List[Experience]: + async def generate(self, prompt: str, **kwargs) -> Sequence[Experience]: """Generate a response from the provided prompt in async. Args: @@ -224,7 +223,6 @@ async def convert_messages_to_experience(self, messages: List[dict]) -> Experien logprobs = await self.logprobs(token_ids=token_ids.tolist()) return Experience( tokens=token_ids, - prompt_length=len(token_ids), logprobs=logprobs, action_mask=action_mask, ) diff --git a/trinity/common/workflows/math_rm_workflow.py b/trinity/common/workflows/math_rm_workflow.py index 45940fdfae..f8e1d41720 100644 --- a/trinity/common/workflows/math_rm_workflow.py +++ b/trinity/common/workflows/math_rm_workflow.py @@ -19,14 +19,15 @@ class MathRMWorkflow(SimpleWorkflow): def __init__( self, - model: ModelWrapper, + *, task: Task, + model: ModelWrapper, auxiliary_models: Optional[List[openai.OpenAI]] = None, ): self.reset(task) super().__init__( - model=model, task=task, + model=model, auxiliary_models=auxiliary_models, ) diff --git a/trinity/common/workflows/step_wise_workflow.py b/trinity/common/workflows/step_wise_workflow.py new file mode 100644 index 0000000000..db571a7b71 --- /dev/null +++ b/trinity/common/workflows/step_wise_workflow.py @@ -0,0 +1,126 @@ +from abc import abstractmethod + +import openai + +from trinity.common.experience import Experience +from trinity.common.models.model import ModelWrapper +from trinity.common.workflows.workflow import Task, Workflow + + +class StepWiseRewardWorkflow(Workflow): + """A workflow that implements step-wise rewards for tasks.""" + + def __init__(self, *, task: Task, model: ModelWrapper, auxiliary_models=None): + super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) + assert model.enable_history, ( + "Rollout Model must have history enabled for step-wise rewards, please " + "set `explorer.rollout_model.enable_history` to `True` in your config." 
+ ) + # use the rollout model's OpenAI client to write your agent application + self.client: openai.OpenAI = model.get_openai_client() + + def run(self) -> list[Experience]: + """Run the workflow and return a list of experiences with step-wise rewards.""" + experiences = [] + for step in range(self.max_step_num): + # Run a single step of the agent application + continue_run = self.step(step_num=step) + # Collect experiences data of the current step + exps = self.model.extract_experience_from_history() + # Calculate the reward for the current step + reward = self.reward(exps, step_num=step) + for exp in exps: + exp.reward = reward + # set the step number in each experience + exp.eid.step = step + # Store the step experiences + experiences.extend(exps) + if not continue_run: + break + + return experiences + + @abstractmethod + def step(self, step_num: int) -> bool: + """Run a single step of your agent application. + + Args: + step_num (int): The current step number. + + Returns: + bool: Whether to continue running the agent application. + + Tips: + You can use the openai client (`self.client`) to migrate your existing + applications at low cost. + """ + pass + + @abstractmethod + def reward(self, exps: list[Experience], step_num: int) -> float: + """Calculate the reward for the given experiences at the specified step.""" + pass + + @property + @abstractmethod + def max_step_num(self): + """Return the maximum number of steps in the task.""" + + +class RewardPropagationWorkflow(Workflow): + """A workflow that propagates rewards across multiple turns.""" + + def __init__(self, *, task: Task, model: ModelWrapper, auxiliary_models=None): + super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) + assert model.enable_history, ( + "Rollout Model must have history enabled for step-wise rewards, please " + "set `explorer.rollout_model.enable_history` to `True` in your config." + ) + # use the rollout model's OpenAI client to write your agent application + self.client: openai.OpenAI = model.get_openai_client() + + def run(self) -> list[Experience]: + """Run the workflow and return a list of experiences with step-wise rewards.""" + experiences = [] + for step in range(self.max_step_num): + # Run a single step of the agent application + continue_run = self.step(step_num=step) + # Collect experiences data of the current step + exps = self.model.extract_experience_from_history() + # set the step number in each experience + for exp in exps: + exp.eid.step = step + # Store the step experiences + experiences.extend(exps) + if not continue_run: + break + reward = self.reward(experiences) + for exp in experiences: + exp.reward = reward + return experiences + + @abstractmethod + def step(self, step_num: int) -> bool: + """Run a single step of your agent application. + + Args: + step_num (int): The current step number. + + Returns: + bool: Whether to continue running the agent application. + + Tips: + You can use the openai client (`self.client`) to migrate your existing + applications at low cost. 
+ """ + pass + + @abstractmethod + def reward(self, exps: list[Experience]) -> float: + """Calculate the reward for the given experiences of the entire run.""" + pass + + @property + @abstractmethod + def max_step_num(self): + """Return the maximum number of steps in the task.""" diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py index e9549d9e2e..ceabdc771a 100644 --- a/trinity/common/workflows/workflow.py +++ b/trinity/common/workflows/workflow.py @@ -8,7 +8,6 @@ from typing import Any, List, Optional, Type, Union import openai -import torch from trinity.common.config import FormatConfig, GenerationConfig from trinity.common.experience import Experience @@ -25,7 +24,7 @@ @dataclass -class Task: +class Task(dict): """A Task class that defines a task and its associated reward function / workflow.""" workflow: Type[Workflow] @@ -37,7 +36,9 @@ class Task: reward_fn: Optional[Type[RewardFn]] = None raw_task: Optional[dict] = None # The raw data sample - group_id: Optional[str] = None # for GRPO-like algorithms, automatically assigned + # automatically assigned ids + batch_id: int = 0 + task_id: int = 0 def to_workflow( self, model: Any, auxiliary_models: Optional[List[openai.OpenAI]] = None @@ -84,10 +85,12 @@ class Workflow(ABC): def __init__( self, - model: ModelWrapper, + *, task: Task, + model: ModelWrapper, auxiliary_models: Optional[List[openai.OpenAI]] = None, ): + self.task = task self.model = model self.auxiliary_models = auxiliary_models @@ -102,6 +105,7 @@ def reset(self, task: Task): @abstractmethod def run(self) -> List[Experience]: """Run workflow and return a list of experiences.""" + raise NotImplementedError class MultiTurnWorkflow(Workflow): @@ -111,13 +115,14 @@ class MultiTurnWorkflow(Workflow): def __init__( self, - model: ModelWrapper, + *, task: Task, + model: ModelWrapper, auxiliary_models: Optional[List[openai.OpenAI]] = None, ): super().__init__( - model=model, task=task, + model=model, auxiliary_models=auxiliary_models, ) @@ -135,8 +140,6 @@ def process_messages_to_experience(self, messages, reward, info={}) -> Experienc log_probs = log_probs * generation_mask assert tokens.shape == log_probs.shape - # set prompt length to the first 1 in the gen_mask - prompt_length = torch.where(generation_mask == 1)[0][0].item() metrics = {} for k, v in info.items(): @@ -145,7 +148,6 @@ def process_messages_to_experience(self, messages, reward, info={}) -> Experienc experience = Experience( tokens=tokens, - prompt_length=prompt_length, action_mask=generation_mask, reward=reward, logprobs=log_probs, @@ -161,14 +163,15 @@ class SimpleWorkflow(Workflow): def __init__( self, - model: ModelWrapper, + *, task: Task, + model: ModelWrapper, auxiliary_models: Optional[List[openai.OpenAI]] = None, ): self.reset(task) super().__init__( - model=model, task=task, + model=model, auxiliary_models=auxiliary_models, ) @@ -236,14 +239,15 @@ class MathWorkflow(SimpleWorkflow): def __init__( self, - model: ModelWrapper, + *, task: Task, + model: ModelWrapper, auxiliary_models: Optional[List[openai.OpenAI]] = None, ): self.reset(task) super().__init__( - model=model, task=task, + model=model, auxiliary_models=auxiliary_models, ) diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index c22516be42..dc2b796246 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -11,6 +11,7 @@ import torch +from trinity.algorithm import ADD_STRATEGY from trinity.algorithm.algorithm_manager import AlgorithmManager from trinity.buffer 
import get_buffer_writer from trinity.buffer.buffer import get_buffer_reader @@ -80,6 +81,15 @@ def __init__(self, config: Config): self.status = RunningStatus.RUNNING self.logger.info("Finished initializing Explorer.") self._ready_to_sync_condition = asyncio.Condition() + self.collect_experiences = self.config.explorer.collect_experiences + self.generated_experience_cnt = 0 + if self.collect_experiences: + assert ( + self.experience_buffer is not None + ), "Experience buffer is required when collect_experiences is True." + self.add_strategy = ADD_STRATEGY.get(self.config.algorithm.add_strategy)( + self.experience_buffer, **self.config.algorithm.add_strategy_args + ) async def setup_weight_sync_group( self, master_address: str, master_port: int, state_dict_meta: List = None @@ -280,12 +290,12 @@ async def benchmark(self) -> bool: if self.config.explorer.bench_on_latest_checkpoint: self.explore_step_num = await self._checkpoint_weights_update() self.eval() - await self._log_eval_metrics(prefix="bench") + await self._finish_eval_step(prefix="bench") return True # benchmark on base model if self.config.explorer.eval_on_startup: - await self._log_eval_metrics(prefix="bench") + await self._finish_eval_step(prefix="bench") # benchmark on all checkpoints all_ckp_steps = sorted( @@ -299,16 +309,17 @@ async def benchmark(self) -> bool: for step_num in all_ckp_steps: self.explore_step_num = await self._checkpoint_weights_update(step_num=step_num) self.eval() - await self._log_eval_metrics(prefix="bench") + await self._finish_eval_step(prefix="bench") return True async def save_checkpoint(self, sync_weight: bool = False) -> None: - # wait for all tasks to complete - self.logger.info("Waiting for all tasks to complete") - await self.scheduler.wait_all() - self.logger.info(f"All tasks before step {self.explore_step_num} have completed.") + if not self.config.explorer.collect_experiences: + # wait for all tasks to complete + self.logger.info("Waiting for all tasks to complete") + await self.scheduler.wait_all() + self.logger.info(f"All tasks before step {self.explore_step_num} have completed.") log_task = asyncio.create_task( - self._log_metrics(self.last_sync_step + 1, self.explore_step_num) + self._finish_steps(self.last_sync_step + 1, self.explore_step_num) ) if sync_weight: @@ -335,19 +346,24 @@ async def sync_weight(self) -> None: # call this method before training start to load the latest model weights await self.save_checkpoint(sync_weight=True) - async def _log_metrics(self, start_step: int, end_step: int) -> None: + async def _finish_steps(self, start_step: int, end_step: int) -> None: for step in range(start_step, end_step + 1): self.logger.info(f"Log metrics of step {step}") - await self._log_explore_metrics(step=step) - await self._log_eval_metrics(step=step) + await self._finish_explore_step(step=step) + await self._finish_eval_step(step=step) - async def _log_explore_metrics(self, step: int) -> None: - results = await self.scheduler.get_results(batch_id=step) - if results: - metric = gather_metrics([status.metric for status in results], "rollout") + async def _finish_explore_step(self, step: int) -> None: + statuses, exps = await self.scheduler.get_results(batch_id=step) + metric = {} + if self.config.explorer.collect_experiences: + exp_cnt = await self.add_strategy.add(exps, step) + self.generated_experience_cnt += exp_cnt + metric["rollout/experience_count"] = exp_cnt + if statuses: + metric.update(gather_metrics([status.metric for status in statuses], "rollout")) 
self.monitor.log(metric, step=step) - async def _log_eval_metrics(self, step: Optional[int] = None, prefix: str = "eval") -> None: + async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eval") -> None: if not self.pending_eval_tasks: return step = step or self.explore_step_num @@ -358,7 +374,7 @@ async def _log_eval_metrics(self, step: Optional[int] = None, prefix: str = "eva if eval_step != step: return self.pending_eval_tasks.popleft() - eval_results = await self.scheduler.get_results(f"{step}/{eval_task_name}") + eval_results, _ = await self.scheduler.get_results(f"{step}/{eval_task_name}") metric.update( gather_metrics( [status.metric for status in eval_results], f"{prefix}/{eval_task_name}" diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py index 459793d2c8..5580d20fa0 100644 --- a/trinity/explorer/scheduler.py +++ b/trinity/explorer/scheduler.py @@ -11,6 +11,7 @@ import ray from trinity.common.config import Config +from trinity.common.experience import Experience from trinity.common.models import InferenceModel from trinity.common.workflows import Task from trinity.explorer.workflow_runner import Status, WorkflowRunner @@ -59,21 +60,23 @@ def _create_runner(self): .remote(self.config, self.rollout_model, self.auxiliary_models, self.runner_id) ) - async def run_with_retry(self, task: TaskWrapper) -> Tuple[Status, int]: + async def run_with_retry(self, task: TaskWrapper) -> Tuple[Status, List, int]: """ Returns: `Status`: The return status of the task. + `List`: The experiences generated by the task. `int`: The runner_id of current runner. """ last_exception_msg = None await self.runner.__ray_ready__.remote() start_time = time.time() status = Status(ok=False, metric=dict()) + exps = [] try: for attempt in range(self.retry_times + 1): try: task.task.rollout_args.n = task.repeat_times - status = await asyncio.wait_for( + status, exps = await asyncio.wait_for( self.runner.run_task.remote(task.task), self.timeout ) if status.ok: @@ -93,7 +96,7 @@ async def run_with_retry(self, task: TaskWrapper) -> Tuple[Status, int]: finally: end_time = time.time() status.metric["task_run_time"] = end_time - start_time - return status, self.runner_id + return status, exps, self.runner_id def restart_runner(self): old_runner = self.runner @@ -147,7 +150,9 @@ def __init__( set ) # batch_id -> futures self.running_task_map: Dict[asyncio.Future, TaskWrapper] = dict() # future -> task - self.completed_tasks: Dict[Union[int, str], deque[Status]] = defaultdict( + self.completed_tasks: Dict[ + Union[int, str], deque[Tuple[Status, List[Experience]]] + ] = defaultdict( deque ) # batch_id -> results @@ -225,13 +230,11 @@ def task_done_callback(self, async_task: asyncio.Task): self.logger.error(f"Task {task.task_id} failed: {async_task.exception()}") return else: - task_result, runner_id = async_task.result() - self.completed_tasks[task.batch_id].appendleft(task_result) + status, exps, runner_id = async_task.result() + self.completed_tasks[task.batch_id].appendleft((status, exps)) self.busy_runners.pop(runner_id) self.idle_runners.add(runner_id) - self.logger.debug( - f"Task completed (batch_id {task.batch_id}), success: {task_result.ok}" - ) + self.logger.debug(f"Task completed (batch_id {task.batch_id}), success: {status.ok}") if task.batch_id in self.running_tasks: self.running_tasks[task.batch_id].remove(async_task) @@ -294,8 +297,8 @@ def schedule(self, tasks: List[Task], batch_id: Union[int, str]) -> None: def _split_and_submit_tasks(self, tasks: List[Task], 
batch_id: Union[int, str]) -> None: for i, task in enumerate(tasks): - group_id = f"{batch_id}/{i}" - task.group_id = group_id + task.batch_id = batch_id + task.task_id = i if self.max_repeat_times is None: self.pending_tasks[batch_id].appendleft( TaskWrapper( @@ -321,7 +324,7 @@ async def get_results( min_num: Optional[int] = None, timeout: Optional[float] = None, clear_timeout_tasks: bool = True, - ) -> List[Status]: + ) -> Tuple[List[Status], List[Experience]]: """Get the result of tasks at the specific batch_id. Args: @@ -333,18 +336,19 @@ async def get_results( timeout = timeout or self.default_timeout start_time = time.time() if min_num is None: - min_num = 0 - if batch_id in self.pending_tasks: - min_num += len(self.pending_tasks[batch_id]) - if batch_id in self.running_tasks: - min_num += len(self.running_tasks[batch_id]) - if batch_id in self.completed_tasks: - min_num += len(self.completed_tasks[batch_id]) + min_num = sum( + len(tasks) # type: ignore [misc] + for tasks in ( + self.pending_tasks.get(batch_id, []), + self.running_tasks.get(batch_id, []), + self.completed_tasks.get(batch_id, []), + ) + ) self.logger.debug(f"Waiting for {min_num} tasks to complete...") - while time.time() - start_time < timeout: - completed_count = len(self.completed_tasks[batch_id]) + while time.time() - start_time <= timeout: + completed_count = len(self.completed_tasks.get(batch_id, [])) if completed_count >= min_num: break await asyncio.sleep(0.1) @@ -353,25 +357,32 @@ async def get_results( self.logger.error(f"Timed out waiting for tasks to complete after {timeout} seconds") if clear_timeout_tasks: self._clear_timeout_tasks(batch_id=batch_id) - for runner_id in list(self.busy_runners.keys()): - if self.busy_runners[runner_id].batch_id == batch_id: + for runner_id, task in list(self.busy_runners.items()): + if task.batch_id == batch_id: self._restart_runner(runner_id) - results = [] + statuses = [] + experiences = [] + completed_queue = self.completed_tasks.get(batch_id, deque()) for _ in range(min_num): - if len(self.completed_tasks[batch_id]) > 0: - results.append(self.completed_tasks[batch_id].pop()) - - if not self.completed_tasks[batch_id]: + if completed_queue: + status, exps = completed_queue.pop() + statuses.append(status) + if isinstance(exps, list): + experiences.extend(exps) + else: + experiences.append(exps) + + if batch_id in self.completed_tasks and not self.completed_tasks[batch_id]: del self.completed_tasks[batch_id] - completed_count = len(results) + completed_count = len(statuses) if completed_count < min_num: self.logger.warning( f"Timeout reached, only {completed_count}/{min_num} tasks completed" ) - return results + return statuses, experiences def has_step(self, batch_id: Union[int, str]) -> bool: return ( diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py index b973a114d2..608db16abb 100644 --- a/trinity/explorer/workflow_runner.py +++ b/trinity/explorer/workflow_runner.py @@ -1,11 +1,10 @@ # -*- coding: utf-8 -*- -"""The Workflow Runner Moudle.""" +"""The Workflow Runner Module.""" import time import traceback -import uuid from collections import defaultdict from dataclasses import dataclass -from typing import List, Optional +from typing import List, Optional, Tuple from trinity.buffer import get_buffer_writer from trinity.common.config import Config @@ -25,7 +24,7 @@ class Status: class WorkflowRunner: - """A Ray remote actor to run the workflow and put the returned experiences into the buffer.""" + """A Ray remote actor to run the 
workflow and generate experiences.""" def __init__( self, @@ -55,6 +54,7 @@ def __init__( self.logger = get_logger(__name__) self.workflow_instance = None self.runner_id = runner_id + self.return_experiences = self.config.explorer.collect_experiences def is_alive(self): return True @@ -73,20 +73,18 @@ def _run_task(self, task: Task) -> List[Experience]: self.workflow_instance.reset(task) return self.workflow_instance.run() - def run_task(self, task: Task) -> Status: + def run_task(self, task: Task) -> Tuple[Status, List[Experience]]: """Run the task and return the states.""" + # TODO: avoid sending the experiences back to the scheduler to reduce the communication overhead try: st = time.time() exps = self._run_task(task) assert exps is not None and len(exps) > 0, "An empty experience is generated" metrics: dict[str, List[float]] = defaultdict(list) # set group id - for idx, exp in enumerate(exps): - setattr(exp, "group_id", task.group_id) - setattr( - exp, "unique_id", f"{task.group_id}/{self.runner_id}/{str(uuid.uuid4())[:6]}" - ) - + for _, exp in enumerate(exps): + exp.eid.batch = task.batch_id + exp.eid.task = task.task_id if not hasattr(exp, "info") or exp.info is None: exp.info = {} exp.info["model_version"] = self.model_wrapper.model_version @@ -102,10 +100,17 @@ def run_task(self, task: Task) -> Status: if metrics: for k, v in metrics.items(): metric[k] = sum(v) / len(v) # type: ignore - if not task.is_eval: + + if task.is_eval: + # If the task is an evaluation task, we do not record the experiences to the buffer + return Status(True, metric=metric), [] + elif self.return_experiences: + return Status(True, metric=metric), exps + else: self.experience_buffer.write(exps) - return Status(True, metric=metric) + return Status(True, metric=metric), [] + except Exception as e: error_trace_back = traceback.format_exc() self.logger.error(f"WorkflowRunner run task error: {e}\nTraceback:\n{error_trace_back}") - return Status(False, metric={"time_per_task": time.time() - st}, message=str(e)) + return Status(False, metric={"time_per_task": time.time() - st}, message=str(e)), [] diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py index 75e6c57f46..00a67ee002 100644 --- a/trinity/trainer/verl/fsdp_workers.py +++ b/trinity/trainer/verl/fsdp_workers.py @@ -27,8 +27,7 @@ import torch import torch.distributed import torch.distributed as dist - -# import vllm # noqa: F401 ; import vllm to set NCCL_CUMEM_ENABLE automatically. +import vllm # noqa: F401 ; import vllm to set NCCL_CUMEM_ENABLE automatically. from codetiming import Timer from omegaconf import DictConfig, OmegaConf, open_dict from peft import LoraConfig, TaskType, get_peft_model
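
The experience refactor above replaces the old `group_ids`/`unique_ids` string fields with structured `EID` objects and routes batching through a module-level `gather_experiences` that dispatches to the experience class's own `gather`. Below is a minimal sketch of the new call pattern, assuming `EID` accepts the `batch`/`task`/`run`/`step` keywords used elsewhere in this patch and that `gather_experiences` is importable from `trinity.common.experience`; constructor defaults outside the diff are assumptions.

```python
# Sketch only: the EID-based gather path introduced by this patch.
# Exact Experience/EID constructor defaults are assumptions.
import torch

from trinity.common.experience import EID, Experience, gather_experiences

exps = [
    Experience(
        eid=EID(batch=0, task=i, run=0, step=0),  # structured id replaces "0/i"-style strings
        tokens=torch.tensor([float(j) for j in range(i + 2)]),
        prompt_length=1,
        reward=float(i),
        logprobs=torch.tensor([0.1] * (i + 2)),
        action_mask=torch.tensor([1] * (i + 2)),
    )
    for i in range(2)
]

batch = gather_experiences(exps, pad_token_id=0)  # dispatches to exps[0].__class__.gather
print(batch.batch_size, [eid.task for eid in batch.eids])
```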
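
The new `StepWiseRewardWorkflow` and `RewardPropagationWorkflow` bases in `trinity/common/workflows/step_wise_workflow.py` leave `step`, `reward`, and `max_step_num` abstract, so a concrete workflow only supplies the agent loop and the scoring rule. A rough sketch of a subclass follows; the `"question"` task field, the served model name lookup, and the reward rule are illustrative assumptions, not part of this patch.

```python
# Hypothetical subclass of the new StepWiseRewardWorkflow base class.
# The "question" task field, the model name lookup, and the reward rule
# are illustrative assumptions.
from trinity.common.experience import Experience
from trinity.common.workflows.step_wise_workflow import StepWiseRewardWorkflow


class CountdownStepWorkflow(StepWiseRewardWorkflow):
    """Queries the rollout model once per step and scores each step."""

    def step(self, step_num: int) -> bool:
        # self.client is the rollout model's OpenAI-compatible client set up by the base class.
        response = self.client.chat.completions.create(
            model=getattr(self.model, "model_path", "rollout-model"),  # assumed attribute / placeholder name
            messages=[
                {
                    "role": "user",
                    "content": f"Step {step_num}: {self.task.raw_task['question']}",
                }
            ],
        )
        content = response.choices[0].message.content or ""
        return "DONE" not in content  # keep going until the model says it is done

    def reward(self, exps: list[Experience], step_num: int) -> float:
        # Illustrative scoring rule: later steps are worth slightly more.
        return 0.1 * (step_num + 1)

    @property
    def max_step_num(self) -> int:
        return 3
```

At run time the base class's `run()` loops over `step`, pulls each step's experiences from the model history, tags them with `eid.step`, and attaches the per-step reward, so the subclass never touches the experience bookkeeping.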
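
When `explorer.collect_experiences` is enabled, the explorer now receives experiences from the scheduler and hands them to an add strategy looked up from `ADD_STRATEGY`, constructed with the experience buffer plus `algorithm.add_strategy_args` and awaited via `add(exps, step)`. The sketch below mirrors only that call pattern; the registry's registration API and any required base class are not shown in this patch and are assumptions.

```python
# Sketch of an add strategy matching the call pattern visible in explorer.py:
# constructed as strategy_cls(experience_buffer, **add_strategy_args) and
# awaited as `await strategy.add(exps, step)`, returning a count.
# How it would be registered in ADD_STRATEGY is not shown in this patch.
from trinity.common.experience import Experience


class KeepAllAddStrategy:
    """Writes every collected experience to the buffer unchanged."""

    def __init__(self, buffer_writer, **kwargs):
        self.writer = buffer_writer

    async def add(self, exps: list[Experience], step: int) -> int:
        if exps:
            # Buffer writers expose write(exps), as used in WorkflowRunner above.
            self.writer.write(exps)
        return len(exps)
```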