diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md
index 7a388390d8..430f872489 100644
--- a/docs/sphinx_doc/source/tutorial/trinity_configs.md
+++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -105,6 +105,10 @@ buffer:
       format:
         prompt_key: 'question'
         response_key: 'answer'
+      rollout_args:
+        repeat_times: 1
+        temperature: 1.0
+        logprobs: 0
     eval_tasksets: []
     default_workflow_type: 'math_workflow'
     default_reward_fn_type: 'countdown_reward'
@@ -123,6 +127,9 @@ buffer:
 - `buffer.explorer_input.taskset.path`: The path to the taskset.
 - `buffer.explorer_input.taskset.split`: The split name of the taskset used for training. Default is `train`.
 - `buffer.explorer_input.taskset.format`: The format of the taskset. It includes `prompt_key`, `response_key`, `workflow_key` and `reward_fn_key`.
+- `buffer.explorer_input.taskset.rollout_args.repeat_times`: The number of times to repeat each task, used for GRPO-like algorithms. Default is `1`.
+- `buffer.explorer_input.taskset.rollout_args.temperature`: The sampling temperature passed to vLLM. Default is `1.0`.
+- `buffer.explorer_input.taskset.rollout_args.logprobs`: The number of log probabilities requested from vLLM; vLLM returns `logprobs + 1` elements. Default is `0`.
 - `buffer.explorer_input.eval_tasksets`: The configuration of the eval tasksets. It is a list of tasksets which will be used for evaluation. And it is empty by default.
 - `buffer.explorer_input.default_workflow_type`: The default workflow type for `taskset` and `eval_tasksets`.
 - `buffer.explorer_input.default_reward_fn_type`: The default reward function type for `taskset` and `eval_tasksets`.
@@ -145,10 +152,7 @@ explorer:
   enable_prefix_caching: false
   enforce_eager: true
   dtype: bfloat16
-  temperature: 1.0
   seed: 42
-  logprobs: 0
-  repeat_times: 5
   use_ray: false
   backend: 'nccl'
   max_pending_requests: 32
@@ -162,10 +166,7 @@ explorer:
 - `explorer.enable_prefix_caching`: Whether to enable prefix caching. Default is `False`.
 - `explorer.enforce_eager`: Whether to enforce eager mode. Default is `True`.
 - `explorer.dtype`: The data type used in vLLM. Default is `bfloat16`.
-- `explorer.temperature`: The temperature used in vLLM. Default is `1.0`.
 - `explorer.seed`: The seed used in vLLM. Default is `42`.
-- `explorer.logprobs`: The logprobs used in vLLM. Default is `0`.
-- `explorer.repeat_times`: The number of times to repeat each task, used for GRPO-like algorithms. Default is `5`.
 - `explorer.use_ray`: Whether to use Ray. Default is `False`.
 - `explorer.backend`: The backend used in vLLM. Default is `nccl`.
 - `explorer.max_pending_requests`: The maximum number of pending requests. Default is `32`.
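For configs written against the old layout, the migration is mechanical: the sampling settings move from the `explorer` section into `buffer.explorer_input.taskset.rollout_args`. A minimal sketch of the equivalent programmatic update (the file path is illustrative; `load_config` is the loader imported by the tests in this patch, and its exact signature is assumed here):

```python
from trinity.common.config import GenerationConfig, load_config

# Load an existing config file (path is illustrative).
config = load_config("examples/grpo_gsm8k/gsm8k.yaml")

# Before this patch these values lived on config.explorer
# (explorer.repeat_times / temperature / logprobs); they are now
# grouped under the taskset's rollout_args as a GenerationConfig.
config.buffer.explorer_input.taskset.rollout_args = GenerationConfig(
    repeat_times=8,   # rollouts per task, used by GRPO-like algorithms
    temperature=1.0,  # sampling temperature passed to vLLM
    logprobs=0,       # vLLM returns logprobs + 1 elements
)
```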
diff --git a/examples/async_gsm8k/explorer.yaml b/examples/async_gsm8k/explorer.yaml
index 65dc61ce94..673da76a59 100644
--- a/examples/async_gsm8k/explorer.yaml
+++ b/examples/async_gsm8k/explorer.yaml
@@ -23,6 +23,10 @@ buffer:
       format:
         prompt_key: 'question'
         response_key: 'answer'
+      rollout_args:
+        repeat_times: 8
+        temperature: 1.0
+        logprobs: 0
     default_workflow_type: 'math_workflow'
   trainer_input:
     experience_buffer:
@@ -37,10 +41,7 @@ explorer:
   enable_prefix_caching: false
   enforce_eager: true
   dtype: bfloat16
-  temperature: 1.0
   seed: 42
-  logprobs: 0
-  repeat_times: 8
   use_ray: false
   backend: 'nccl'
   max_pending_requests: 32
diff --git a/examples/async_gsm8k/trainer.yaml b/examples/async_gsm8k/trainer.yaml
index 85a9afbf11..df193c3f37 100644
--- a/examples/async_gsm8k/trainer.yaml
+++ b/examples/async_gsm8k/trainer.yaml
@@ -22,6 +22,10 @@ buffer:
       format:
         prompt_key: 'question'
         response_key: 'answer'
+      rollout_args:
+        repeat_times: 8
+        temperature: 1.0
+        logprobs: 0
     default_workflow_type: 'math_workflow'
   trainer_input:
     experience_buffer:
@@ -36,10 +40,7 @@ explorer:
   enable_prefix_caching: false
   enforce_eager: true
   dtype: bfloat16
-  temperature: 1.0
   seed: 42
-  logprobs: 0
-  repeat_times: 8
   use_ray: false
   backend: 'nccl'
   max_pending_requests: 32
diff --git a/examples/dpo_humanlike/dpo.yaml b/examples/dpo_humanlike/dpo.yaml
index 1e812201a0..de459f9230 100644
--- a/examples/dpo_humanlike/dpo.yaml
+++ b/examples/dpo_humanlike/dpo.yaml
@@ -23,22 +23,6 @@ buffer:
       prompt_key: prompt
       chosen_key: chosen
       rejected_key: rejected
-explorer:
-  engine_type: vllm_async
-  engine_num: 0
-  runner_num: 32
-  tensor_parallel_size: 1
-  enable_prefix_caching: false
-  enforce_eager: true
-  dtype: bfloat16
-  temperature: 1.0
-  seed: 42
-  logprobs: 0
-  repeat_times: 1 # NOTE
-  use_ray: false
-  backend: 'nccl'
-  max_pending_requests: 32
-  max_waiting_steps: 4
 synchronizer:
   sync_method: 'checkpoint'
   sync_interval: 30
diff --git a/examples/grpo_alfworld/alfworld.yaml b/examples/grpo_alfworld/alfworld.yaml
index b083d7874d..18dc2595e6 100644
--- a/examples/grpo_alfworld/alfworld.yaml
+++ b/examples/grpo_alfworld/alfworld.yaml
@@ -19,6 +19,10 @@ buffer:
       path: 'scripts/data_prepare/alfworld_data'
       format:
         prompt_key: 'game_file'
+      rollout_args:
+        repeat_times: 8
+        temperature: 1.0
+        logprobs: 0
     default_workflow_type: 'alfworld_workflow'
   trainer_input:
     experience_buffer:
@@ -33,10 +37,7 @@ explorer:
   enable_prefix_caching: false
   enforce_eager: true
   dtype: bfloat16
-  temperature: 1.0
   seed: 42
-  logprobs: 0
-  repeat_times: 8
   use_ray: false
   backend: 'nccl'
   max_pending_requests: 32
diff --git a/examples/grpo_gsm8k/gsm8k.yaml b/examples/grpo_gsm8k/gsm8k.yaml
index 7e25bfaaee..748a5ac0e5 100644
--- a/examples/grpo_gsm8k/gsm8k.yaml
+++ b/examples/grpo_gsm8k/gsm8k.yaml
@@ -36,6 +36,10 @@ buffer:
       format:
         prompt_key: 'question'
         response_key: 'answer'
+      rollout_args:
+        repeat_times: 8
+        temperature: 1.0
+        logprobs: 0
     eval_tasksets:
     - name: gsm8k-eval
       storage_type: file
@@ -65,10 +69,7 @@ explorer:
   enable_prefix_caching: false
   enforce_eager: true
   dtype: bfloat16
-  temperature: 1.0
   seed: 42
-  logprobs: 0
-  repeat_times: 8
   use_ray: false
   backend: 'nccl'
   max_pending_requests: 32
diff --git a/examples/grpo_math/math.yaml b/examples/grpo_math/math.yaml
index 06f88fe818..b22291b09a 100644
--- a/examples/grpo_math/math.yaml
+++ b/examples/grpo_math/math.yaml
@@ -21,6 +21,10 @@ buffer:
       format:
         prompt_key: 'question'
         response_key: 'gt_answer'
+      rollout_args:
+        repeat_times: 8
+        temperature: 1.0
+        logprobs: 0
     default_workflow_type: 'math_workflow'
   trainer_input:
     experience_buffer:
@@ -35,10 +39,7 @@ explorer:
   enable_prefix_caching: false
   enforce_eager: true
   dtype: bfloat16
-  temperature: 1.0
   seed: 42
-  logprobs: 0
-  repeat_times: 8
   use_ray: false
   backend: 'nccl'
   max_pending_requests: 32
diff --git a/examples/grpo_sciworld/sciworld.yaml b/examples/grpo_sciworld/sciworld.yaml
index ffe30d44f0..6ba88d51e2 100644
--- a/examples/grpo_sciworld/sciworld.yaml
+++ b/examples/grpo_sciworld/sciworld.yaml
@@ -19,6 +19,10 @@ buffer:
       path: 'scripts/data_prepare/sciworld_data'
       format:
         prompt_key: 'game_file'
+      rollout_args:
+        repeat_times: 8
+        temperature: 1.0
+        logprobs: 0
     default_workflow_type: 'sciworld_workflow'
   trainer_input:
     experience_buffer:
@@ -33,10 +37,7 @@ explorer:
   enable_prefix_caching: false
   enforce_eager: true
   dtype: bfloat16
-  temperature: 1.0
   seed: 42
-  logprobs: 0
-  repeat_times: 8
   use_ray: false
   backend: 'nccl'
   max_pending_requests: 32
diff --git a/examples/grpo_webshop/webshop.yaml b/examples/grpo_webshop/webshop.yaml
index 451495433f..d5b59d67b0 100644
--- a/examples/grpo_webshop/webshop.yaml
+++ b/examples/grpo_webshop/webshop.yaml
@@ -19,6 +19,10 @@ buffer:
       path: 'scripts/data_prepare/webshop_data'
       format:
         prompt_key: 'task_id'
+      rollout_args:
+        repeat_times: 8
+        temperature: 1.0
+        logprobs: 0
     default_workflow_type: 'webshop_workflow'
   trainer_input:
     experience_buffer:
@@ -33,10 +37,7 @@ explorer:
   enable_prefix_caching: false
   enforce_eager: true
   dtype: bfloat16
-  temperature: 1.0
   seed: 42
-  logprobs: 0
-  repeat_times: 8
   use_ray: false
   backend: 'nccl'
   max_pending_requests: 32
diff --git a/examples/opmd_gsm8k/opmd_gsm8k.yaml b/examples/opmd_gsm8k/opmd_gsm8k.yaml
index dcfeee47db..4739400f1a 100644
--- a/examples/opmd_gsm8k/opmd_gsm8k.yaml
+++ b/examples/opmd_gsm8k/opmd_gsm8k.yaml
@@ -20,6 +20,10 @@ buffer:
       format:
         prompt_key: 'question'
         response_key: 'answer'
+      rollout_args:
+        repeat_times: 8
+        temperature: 1.0
+        logprobs: 0
     default_workflow_type: 'math_workflow'
   trainer_input:
     experience_buffer:
@@ -34,10 +38,7 @@ explorer:
   enable_prefix_caching: false
   enforce_eager: true
   dtype: bfloat16
-  temperature: 1.0
   seed: 42
-  logprobs: 0
-  repeat_times: 8
   use_ray: false
   backend: 'nccl'
   max_pending_requests: 32
diff --git a/examples/ppo_countdown/countdown.yaml b/examples/ppo_countdown/countdown.yaml
index 941c0ef97b..f7ad9c4362 100644
--- a/examples/ppo_countdown/countdown.yaml
+++ b/examples/ppo_countdown/countdown.yaml
@@ -21,6 +21,10 @@ buffer:
       format:
         prompt_key: 'question'
         response_key: 'answer'
+      rollout_args:
+        repeat_times: 5
+        temperature: 1.0
+        logprobs: 0
     default_workflow_type: 'math_workflow'
     default_reward_fn_type: 'countdown_reward'
   trainer_input:
@@ -36,10 +40,7 @@ explorer:
   enable_prefix_caching: false
   enforce_eager: true
   dtype: bfloat16
-  temperature: 1.0
   seed: 42
-  logprobs: 0
-  repeat_times: 5
   use_ray: false
   backend: 'nccl'
   max_pending_requests: 32
diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py
index 9b1d5f9997..090cb6f2fd 100644
--- a/tests/common/vllm_test.py
+++ b/tests/common/vllm_test.py
@@ -85,8 +85,9 @@ def get_model_path() -> str:
 class BaseTestModelWrapper:
     def test_generate(self):
         prompts = ["Hello, world!", "Hello, my name is"]
-        results = self.model_wrapper.generate(prompts)
-        self.assertEqual(len(results), len(prompts) * self.config.explorer.repeat_times)
+        repeat_times = self.config.buffer.explorer_input.taskset.rollout_args.repeat_times
+        results = self.model_wrapper.generate(prompts, n=repeat_times, temperature=1.0)
+        self.assertEqual(len(results), len(prompts) * repeat_times)
         messages = [
             {"role": "system", "content": "You are a helpful assistant."},
             {"role": "user", "content": "What's the weather like today?"},
@@ -96,8 +97,8 @@ def test_generate(self):
             },
             {"role": "user", "content": "OK, thanks!"},
         ]
-        results = self.model_wrapper.chat(messages)
-        self.assertEqual(len(results), self.config.explorer.repeat_times)
+        results = self.model_wrapper.chat(messages, n=repeat_times, temperature=1.0)
+        self.assertEqual(len(results), repeat_times)
         for result in results:
             input_logprobs = result.logprobs[: result.prompt_length]
             output_logprobs = result.logprobs[result.prompt_length :]
@@ -135,7 +136,7 @@ def setUp(self):
         self.config.explorer.engine_type = "vllm"
         self.config.explorer.tensor_parallel_size = 1
         self.config.explorer.engine_num = 2
-        self.config.explorer.repeat_times = 2
+        self.config.buffer.explorer_input.taskset.rollout_args.repeat_times = 2
         self.config.explorer.use_v1 = False
         self.config.explorer.chat_template = CHAT_TEMPLATE
         self.engines = create_rollout_models(self.config)
@@ -149,7 +150,7 @@ def setUp(self):
         self.config.explorer.engine_type = "vllm_async"
         self.config.explorer.engine_num = 2
         self.config.explorer.tensor_parallel_size = 1
-        self.config.explorer.repeat_times = 2
+        self.config.buffer.explorer_input.taskset.rollout_args.repeat_times = 2
         self.config.explorer.use_v1 = False
         self.config.explorer.chat_template = CHAT_TEMPLATE
         self.engines = create_rollout_models(self.config)
@@ -176,7 +177,7 @@ def setUp(self):
         self.config.explorer.engine_type = "vllm_async"
         self.config.explorer.engine_num = 2
         self.config.explorer.tensor_parallel_size = 2
-        self.config.explorer.repeat_times = 2
+        self.config.buffer.explorer_input.taskset.rollout_args.repeat_times = 2
         self.config.explorer.use_v1 = True
         self.config.explorer.chat_template = CHAT_TEMPLATE
         self.engines = create_rollout_models(self.config)
diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py
index 013aa29f64..f4570571a1 100644
--- a/tests/explorer/explorer_test.py
+++ b/tests/explorer/explorer_test.py
@@ -22,7 +22,7 @@ def setUp(self):
         self.config.global_config.batch_size = 4
         self.config.model.model_path = get_model_path()
         self.config.explorer.engine_type = "vllm_async"
-        self.config.explorer.repeat_times = 2
+        self.config.buffer.explorer_input.taskset.rollout_args.repeat_times = 2
         self.config.monitor.monitor_type = MonitorType.TENSORBOARD
         self.config.monitor.project = "Trinity-unittest"
         self.config.model.checkpoint_path = get_checkpoint_path()
diff --git a/tests/explorer/runner_pool_test.py b/tests/explorer/runner_pool_test.py
index 49f6f3d924..5d63fcdc63 100644
--- a/tests/explorer/runner_pool_test.py
+++ b/tests/explorer/runner_pool_test.py
@@ -6,12 +6,13 @@
 import ray
 import torch

+from tests.tools import get_unittest_dataset_config
 from trinity.buffer.reader.queue_reader import QueueReader
 from trinity.common.config import StorageConfig, load_config
 from trinity.common.constants import AlgorithmType, StorageType
 from trinity.common.experience import Experience
 from trinity.common.models.model import InferenceModel
-from trinity.common.task import Task
+from trinity.common.workflows import Task
 from trinity.common.workflows.workflow import WORKFLOWS, Workflow
 from trinity.explorer.runner_pool import RunnerPool

@@ -20,9 +21,9 @@

 @WORKFLOWS.register_module("dummy_workflow")
 class DummyWorkflow(Workflow):
-    def __init__(self, model, **kwargs):
-        super().__init__(model)
-        self.error_type = kwargs.get("task_desc")
+    def __init__(self, model, task):
+        super().__init__(model, task)
+        self.error_type = task.task_desc
         self.seconds = None
         if "timeout" in self.error_type:
             self.seconds = int(self.error_type.split("_")[-1])
@@ -81,30 +82,61 @@ def setUp(self):

     def test_runner_pool(self):
         pool = RunnerPool(self.config, [DummyModel.remote(), DummyModel.remote()])
+        taskset_config = get_unittest_dataset_config("countdown")
         tasks = [
             Task(
-                task_desc="timeout_100",
                 workflow=DummyWorkflow,
+                format_args=taskset_config.format,
+                rollout_args=taskset_config.rollout_args,
+                is_eval=False,
+                raw_task={
+                    taskset_config.format.prompt_key: "timeout_100",
+                },
             ),
             Task(
-                task_desc="exception",
                 workflow=DummyWorkflow,
+                format_args=taskset_config.format,
+                rollout_args=taskset_config.rollout_args,
+                is_eval=False,
+                raw_task={
+                    taskset_config.format.prompt_key: "exception",
+                },
             ),
             Task(
-                task_desc="timeout_2",
                 workflow=DummyWorkflow,
+                format_args=taskset_config.format,
+                rollout_args=taskset_config.rollout_args,
+                is_eval=False,
+                raw_task={
+                    taskset_config.format.prompt_key: "timeout_2",
+                },
             ),
             Task(
-                task_desc="success",
                 workflow=DummyWorkflow,
+                format_args=taskset_config.format,
+                rollout_args=taskset_config.rollout_args,
+                is_eval=False,
+                raw_task={
+                    taskset_config.format.prompt_key: "success",
+                },
             ),
             Task(
-                task_desc="timeout_101",
                 workflow=DummyWorkflow,
+                format_args=taskset_config.format,
+                rollout_args=taskset_config.rollout_args,
+                is_eval=False,
+                raw_task={
+                    taskset_config.format.prompt_key: "timeout_101",
+                },
             ),
             Task(
-                task_desc="exit",
                 workflow=DummyWorkflow,
+                format_args=taskset_config.format,
+                rollout_args=taskset_config.rollout_args,
+                is_eval=False,
+                raw_task={
+                    taskset_config.format.prompt_key: "exit",
+                },
             ),
         ]
diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py
index e27c75a95d..3c00733e54 100644
--- a/tests/explorer/workflow_test.py
+++ b/tests/explorer/workflow_test.py
@@ -4,7 +4,9 @@
 from dataclasses import dataclass
 from unittest.mock import MagicMock

+from tests.tools import get_unittest_dataset_config
 from trinity.common.workflows import MathWorkflow
+from trinity.common.workflows.workflow import Task


 @dataclass
@@ -27,9 +29,19 @@ def test_math_workflow(self) -> None:
             MockResponse("\nOnly thinking\n"),
             MockResponse("ThinkingAnswer is not end1"),
         ]
-        workflow = MathWorkflow(model=model, task_desc="1+1=", truth="2")
+        taskset_config = get_unittest_dataset_config("countdown")
+        task = Task(
+            workflow=MathWorkflow,
+            format_args=taskset_config.format,
+            rollout_args=taskset_config.rollout_args,
+            is_eval=False,
+            raw_task={
+                taskset_config.format.prompt_key: "1+1=",
+                taskset_config.format.response_key: "2",
+            },
+        )
+        workflow = task.to_workflow(model=model)
         experiences = workflow.run()
-        print(experiences)
         self.assertEqual(len(experiences), 9)
         self.assertEqual(experiences[0].reward, 0.9)
         self.assertEqual(experiences[1].reward, -0.1)
@@ -51,7 +63,18 @@ def test_math_fraction_workflow(self) -> None:
             MockResponse(r"\boxed{\frac{1} {10}}"),
             MockResponse(r"The answer is \boxed{\frac{40}{400}}"),
         ]
-        workflow = MathWorkflow(model=model, task_desc=r"\frac{40}{400}", truth=r"\frac{40}{400}")
+        taskset_config = get_unittest_dataset_config("countdown")
+        task = Task(
+            workflow=MathWorkflow,
+            format_args=taskset_config.format,
+            rollout_args=taskset_config.rollout_args,
+            is_eval=False,
+            raw_task={
+                taskset_config.format.prompt_key: r"\frac{40}{400}",
+                taskset_config.format.response_key: r"\frac{40}{400}",
+            },
+        )
+        workflow = task.to_workflow(model=model)
         experiences = workflow.run()
         self.assertEqual(len(experiences), 6)
         self.assertEqual(experiences[0].reward, 0.9)
@@ -68,11 +91,18 @@ def test_math_complex_workflow(self) -> None:
                 r"$\boxed{\dfrac{108 + 31\sqrt{5}}{216}} \quad \text{and} \quad \boxed{\dfrac{108 - 31\sqrt{5}}{216}}$"
             ),
         ]
-        workflow = MathWorkflow(
-            model=model,
-            task_desc="",
-            truth=r"$x_{1}=\frac{1}{2}+\frac{31\sqrt{5}}{216},\quadx_{2}=\frac{1}{2}-\frac{31\sqrt{5}}{216}$",
+        taskset_config = get_unittest_dataset_config("countdown")
+        task = Task(
+            workflow=MathWorkflow,
+            format_args=taskset_config.format,
+            rollout_args=taskset_config.rollout_args,
+            is_eval=False,
+            raw_task={
+                taskset_config.format.prompt_key: "",
+                taskset_config.format.response_key: r"$x_{1}=\frac{1}{2}+\frac{31\sqrt{5}}{216},\quadx_{2}=\frac{1}{2}-\frac{31\sqrt{5}}{216}$",
+            },
         )
+        workflow = task.to_workflow(model=model)
         experiences = workflow.run()
         self.assertEqual(len(experiences), 1)
         self.assertEqual(experiences[0].reward, 0.9)
@@ -84,11 +114,18 @@ def test_gsm8k_workflow(self) -> None:
             MockResponse(" 36.0 "),
             MockResponse("Kim's total points are 6 + 30 = 36 "),
         ]
-        workflow = MathWorkflow(
-            model=model,
-            task_desc="",
-            truth=r"36",
+        taskset_config = get_unittest_dataset_config("countdown")
+        task = Task(
+            workflow=MathWorkflow,
+            format_args=taskset_config.format,
+            rollout_args=taskset_config.rollout_args,
+            is_eval=False,
+            raw_task={
+                taskset_config.format.prompt_key: "",
+                taskset_config.format.response_key: r"36",
+            },
         )
+        workflow = task.to_workflow(model=model)
         experiences = workflow.run()
         # self.assertEqual(len(experiences), 1)
         self.assertEqual(experiences[0].reward, 1.1)
diff --git a/tests/template/config.yaml b/tests/template/config.yaml
index aee283bccb..0cf3ba3cf9 100644
--- a/tests/template/config.yaml
+++ b/tests/template/config.yaml
@@ -18,7 +18,7 @@ buffer:
     taskset:
       name: taskset
      storage_type: file
-      path: ''
+      path: 'placeholder'
       split: 'train'
     default_workflow_type: ''
     default_reward_fn_type: ''
@@ -26,14 +26,11 @@ explorer:
   engine_type: vllm_async
   engine_num: 2
   runner_num: 4
-  repeat_times: 1
   tensor_parallel_size: 1
   enable_prefix_caching: false
   enforce_eager: true
   dtype: bfloat16
-  temperature: 1.0
   seed: 42
-  logprobs: 0
   backend: nccl
   use_ray: false
   use_v1: true
diff --git a/tests/test_data/template.yaml b/tests/test_data/template.yaml
index 2ac9addf28..21fe407842 100644
--- a/tests/test_data/template.yaml
+++ b/tests/test_data/template.yaml
@@ -25,6 +25,4 @@ explorer:
   enable_prefix_caching: false
   enforce_eager: true
   dtype: bfloat16
-  temperature: 1.0
   seed: 42
-  logprobs: 0
diff --git a/tests/tools.py b/tests/tools.py
index be415fba75..a650638a0d 100644
--- a/tests/tools.py
+++ b/tests/tools.py
@@ -6,7 +6,13 @@
 import ray
 from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

-from trinity.common.config import Config, FormatConfig, StorageConfig, load_config
+from trinity.common.config import (
+    Config,
+    FormatConfig,
+    GenerationConfig,
+    StorageConfig,
+    load_config,
+)


 def get_template_config() -> Config:
@@ -45,6 +51,11 @@ def get_unittest_dataset_config(
             prompt_key="question",
             response_key="answer",
         ),
+        rollout_args=GenerationConfig(
+            repeat_times=1,
+            temperature=1.0,
+            logprobs=0,
+        ),
         default_workflow_type="math_workflow",
         default_reward_fn_type="countdown_reward",
     )
diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index 62f4abf745..b0ee7cc089 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -26,7 +26,7 @@ def setUp(self):
         self.config.global_config.batch_size = 4
         self.config.model.model_path = get_model_path()
         self.config.explorer.engine_type = "vllm_async"
-        self.config.explorer.repeat_times = 3
+        self.config.buffer.explorer_input.taskset.rollout_args.repeat_times = 3
         self.config.explorer.use_v1 = False
         self.config.monitor.name = f"trainer-{datetime.now().strftime('%Y%m%d%H%M%S')}"
         self.config.monitor.monitor_type = MonitorType.TENSORBOARD
diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py
index 0ca202b2fa..316d3ae297 100644
--- a/trinity/buffer/reader/file_reader.py
+++ b/trinity/buffer/reader/file_reader.py
@@ -11,8 +11,7 @@
 from trinity.common.constants import AlgorithmType, PromptType, ReadStrategy, TaskType
 from trinity.common.experience import Experience
 from trinity.common.rewards import REWARD_FUNCTIONS
-from trinity.common.task import Task
-from trinity.common.workflows import WORKFLOWS
+from trinity.common.workflows import WORKFLOWS, Task
 from trinity.utils.registry import Registry

 FILE_READERS = Registry("file_readers")
@@ -173,6 +172,7 @@ def read(self, strategy: Optional[ReadStrategy] = None) -> List:
 @FILE_READERS.register_module("rollout")
 class RolloutDataReader(BufferReader):
     def __init__(self, meta: StorageConfig, config: BufferConfig):
+        self.meta = meta
         self.name = meta.name
         self.split = meta.split
         subset_name = meta.subset_name
@@ -206,8 +206,6 @@ def read(self, strategy: Optional[ReadStrategy] = None):
             if self.index >= len(self.dataset) * self.total_epochs:
                 raise StopIteration
             sample = self.dataset[self.index % len(self.dataset)]
-            task_desc = sample[self.prompt_key] if self.prompt_key in sample else None
-            truth = sample[self.response_key] if self.response_key in sample else None
             workflow_class = (
                 WORKFLOWS.get(sample[self.workflow_key])
                 if self.workflow_key in sample
@@ -220,12 +218,12 @@ def read(self, strategy: Optional[ReadStrategy] = None):
             )
             assert workflow_class is not None, "`default_reward_fn_type` or `workflow_key` is required"
             task = Task(
-                task_desc=task_desc,
-                truth=truth,
                 workflow=workflow_class,
+                format_args=self.meta.format,
+                rollout_args=self.meta.rollout_args,
+                is_eval=self.meta.task_type == TaskType.EVAL,
                 reward_fn=reward_fn,
-                raw=sample,
-                task_type=self.task_type,
+                raw_task=sample,
             )
             self.index += 1
             if self.task_type == TaskType.EVAL and self.index == len(self.dataset):
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 9c9ea72017..2e8e830007 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -28,7 +28,10 @@ class FormatConfig:
     prompt_key: str = "prompt"
     response_key: str = "response"
     messages_key: str = "message"
-    chat_template: str = ""
+    chat_template: str = ""  # deprecated
+
+    system_prompt: Optional[str] = None
+    reply_prefix: Optional[str] = None

     # for sample-level task controlling
     reward_fn_key: str = ""
@@ -47,6 +50,17 @@ class FormatConfig:
     label_key: str = ""


+@dataclass
+class GenerationConfig:
+    # repeat each task for `repeat_times` times (for GRPO-like algorithms)
+    repeat_times: int = 1
+
+    temperature: float = 1.0
+    top_p: float = 1.0
+    top_k: int = -1
+    logprobs: int = 0  # vLLM returns `logprobs + 1` elements
+
+
 @dataclass
 class StorageConfig:
     """Storage config."""
@@ -63,13 +77,11 @@ class StorageConfig:
     index: int = 0

     # used for algorithm_type is None
-    task_type: TaskType = TaskType.EXPLORE
+    task_type: TaskType = TaskType.EXPLORE  # automatically set
     default_workflow_type: Optional[str] = None
     default_reward_fn_type: Optional[str] = None
     total_epochs: int = 1  # automatically set
-    # used for algorithm_type is None and TaskType.EVAL
-    eval_repeat_times: int = 1  # TODO
-    eval_temperature: float = 0.1  # TODO
+    rollout_args: GenerationConfig = field(default_factory=GenerationConfig)


 @dataclass
@@ -138,8 +150,11 @@ class ExplorerInput:
     taskset: StorageConfig = field(default_factory=StorageConfig)
     eval_tasksets: List[StorageConfig] = field(default_factory=list)
+    # The following args provide default values for the corresponding args in `taskset` and `eval_tasksets`
     default_workflow_type: Optional[str] = None
     default_reward_fn_type: Optional[str] = None
+    system_prompt: Optional[str] = None
+    reply_prefix: Optional[str] = None


 @dataclass
@@ -180,9 +195,6 @@ class ExplorerConfig:
     # For async engine (vllm_async), it can be larger than `engine_num`, e.g. 16 * `engine_num`
     runner_num: int = 1

-    # repeat each task for `repeat_times` times (for GPRO-like algorithms)
-    repeat_times: int = 1
-
     # for rollout tokneize
     chat_template: Optional[str] = None

@@ -191,11 +203,7 @@ class ExplorerConfig:
     enable_prefix_caching: bool = False
     enforce_eager: bool = True
     dtype: str = "bfloat16"
-    temperature: float = 0.0
-    top_p: float = 1.0
-    top_k: int = -1
     seed: int = 42
-    logprobs: int = 0  # vLLM return `logprobs + 1` elements
     backend: str = "nccl"
     use_ray: bool = False
     gpu_memory_utilization: float = 0.9
@@ -327,10 +335,12 @@ def _check_interval(self) -> None:

     def _check_buffer(self) -> None:  # noqa: C901
         # check explorer_input
-        if self.mode != "train" and self.buffer.explorer_input.taskset.path is None:
+        if self.mode != "train" and not self.buffer.explorer_input.taskset.path:
             raise ValueError(
                 "`buffer.explorer_input.taskset.path` is required, please set it to the path of the taskset."
             )
+        if not self.buffer.explorer_input.taskset.name:
+            self.buffer.explorer_input.taskset.name = "taskset"
         self.buffer.explorer_input.taskset.task_type = TaskType.EXPLORE
         self.buffer.explorer_input.taskset.total_epochs = self.global_config.total_epochs
         if self.buffer.explorer_input.taskset.default_workflow_type is None:
@@ -341,13 +351,33 @@ def _check_buffer(self) -> None:  # noqa: C901
             self.buffer.explorer_input.taskset.default_reward_fn_type = (
                 self.buffer.explorer_input.default_reward_fn_type
             )
+        if self.buffer.explorer_input.taskset.format.system_prompt is None:
+            self.buffer.explorer_input.taskset.format.system_prompt = (
+                self.buffer.explorer_input.system_prompt
+            )
+        if self.buffer.explorer_input.taskset.format.reply_prefix is None:
+            self.buffer.explorer_input.taskset.format.reply_prefix = (
+                self.buffer.explorer_input.reply_prefix
+            )

-        for dataset in self.buffer.explorer_input.eval_tasksets:
+        remained_tasksets = []
+        for idx, dataset in enumerate(self.buffer.explorer_input.eval_tasksets):
+            if not dataset.path:
+                logger.warning(f"Eval dataset [{dataset}]'s path is not configured. Skip.")
+                continue
             dataset.task_type = TaskType.EVAL
+            if not dataset.name:
+                dataset.name = f"eval_taskset_{idx}"
             if dataset.default_workflow_type is None:
                 dataset.default_workflow_type = self.buffer.explorer_input.default_workflow_type
             if dataset.default_reward_fn_type is None:
                 dataset.default_reward_fn_type = self.buffer.explorer_input.default_reward_fn_type
+            if dataset.format.system_prompt is None:
+                dataset.format.system_prompt = self.buffer.explorer_input.system_prompt
+            if dataset.format.reply_prefix is None:
+                dataset.format.reply_prefix = self.buffer.explorer_input.reply_prefix
+            remained_tasksets.append(dataset)
+        self.buffer.explorer_input.eval_tasksets = remained_tasksets

         # check trainer_input.experience_buffer
         if self.mode == "both":
@@ -387,7 +417,10 @@ def _check_buffer(self) -> None:  # noqa: C901
             self.buffer.trainer_input.sft_warmup_dataset.algorithm_type = AlgorithmType.SFT

         # set read_batch_size / pad_token_id / tokenizer_path
-        self.buffer.read_batch_size = self.global_config.batch_size * self.explorer.repeat_times
+        self.buffer.read_batch_size = (
+            self.global_config.batch_size
+            * self.buffer.explorer_input.taskset.rollout_args.repeat_times
+        )
         if self.buffer.pad_token_id is None:
             from transformers import AutoTokenizer
diff --git a/trinity/common/models/vllm_async_model.py b/trinity/common/models/vllm_async_model.py
index 2f9249e0e9..fc39c3c303 100644
--- a/trinity/common/models/vllm_async_model.py
+++ b/trinity/common/models/vllm_async_model.py
@@ -47,6 +47,7 @@ def __init__(
         self.use_v1 = config.explorer.use_v1
         if config.explorer.tensor_parallel_size != 1:
             os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+            os.environ["VLLM_RAY_BUNDLE_INDICES"] = config.explorer.bundle_indices
         if not vllm.envs.is_set("VLLM_USE_V1"):
             self.logger.info(f"Using vLLM v{int(config.explorer.use_v1)} engine")
             os.environ["VLLM_USE_V1"] = str(int(config.explorer.use_v1))
@@ -55,15 +56,15 @@ def __init__(
             os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
             os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
         self.default_sampling_params = vllm.SamplingParams(
-            n=config.explorer.repeat_times,
-            temperature=config.explorer.temperature,
+            n=1,
+            temperature=0.0,
             max_tokens=config.model.max_response_tokens,
             min_tokens=1,
             truncate_prompt_tokens=config.model.max_prompt_tokens,
             skip_special_tokens=True,
             include_stop_str_in_output=False,
             output_kind=RequestOutputKind.FINAL_ONLY,
-            logprobs=config.explorer.logprobs,
+            logprobs=0,
         )
         self.enable_thinking = config.model.enable_thinking
         self.request_id = 0
diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py
index c4baf567a6..1e3efe6d22 100644
--- a/trinity/common/models/vllm_model.py
+++ b/trinity/common/models/vllm_model.py
@@ -33,6 +33,9 @@ class vLLMRolloutModel(InferenceModel):
     def __init__(self, config: Config, **kwargs):
         self.logger = get_logger(__name__)
         self.config = config
+        if config.explorer.tensor_parallel_size != 1:
+            os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+            os.environ["VLLM_RAY_BUNDLE_INDICES"] = config.explorer.bundle_indices
         if not vllm.envs.is_set("VLLM_USE_V1"):
             self.logger.info(f"Using vLLM v{int(config.explorer.use_v1)} engine")
             os.environ["VLLM_USE_V1"] = str(int(config.explorer.use_v1))
@@ -41,14 +44,14 @@ def __init__(self, config: Config, **kwargs):
             os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
             os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
         self.default_sampling_params = SamplingParams(
-            n=config.explorer.repeat_times,
-            temperature=config.explorer.temperature,
+            n=1,
+            temperature=0.0,
             max_tokens=config.model.max_response_tokens,
             min_tokens=1,
             truncate_prompt_tokens=config.model.max_prompt_tokens,
             skip_special_tokens=True,
             include_stop_str_in_output=False,
-            logprobs=config.explorer.logprobs,
+            logprobs=0,
         )
         self.llm = LLM(
             # TODO: check checkpoint path
@@ -151,7 +154,7 @@ def generate(self, prompts: List[str], **kwargs) -> List:

         Example:

-            >>> # config.explorer.repeat_times == 2 or kwargs["repeat_times"] == 2
+            >>> # config.buffer.explorer_input.taskset.rollout_args.repeat_times == 2 or kwargs["repeat_times"] == 2
             >>>
             >>> prompts = [
             >>>     "Hello, world!",
@@ -172,7 +175,7 @@ def generate(self, prompts: List[str], **kwargs) -> List:
         )
         experiences = []
         for output in outputs:
-            for i in range(self.config.explorer.repeat_times):
+            for i in range(self.config.buffer.explorer_input.taskset.rollout_args.repeat_times):
                 experiences.append(
                     Experience(
                         tokens=torch.cat(
diff --git a/trinity/common/task.py b/trinity/common/task.py
deleted file mode 100644
index b0582ab936..0000000000
--- a/trinity/common/task.py
+++ /dev/null
@@ -1,51 +0,0 @@
-# -*- coding: utf-8 -*-
-"""Task Class."""
-from __future__ import annotations
-
-from dataclasses import dataclass
-from typing import Any, Optional, Type
-
-from trinity.common.config import Config
-from trinity.common.constants import TaskType
-from trinity.common.rewards.reward_fn import RewardFn
-from trinity.common.workflows.workflow import Workflow
-from trinity.utils.log import get_logger
-
-logger = get_logger(__name__)
-
-
-@dataclass
-class Task:
-    """A Task class that defines a task and its associated reward function / workflow."""
-
-    task_desc: str
-    workflow: Type[Workflow]
-    reward_fn: Optional[Type[RewardFn]] = None
-    truth: Optional[str] = None
-    raw: Optional[dict] = None  # The raw data sample
-    task_type: Optional[TaskType] = None
-
-    def to_workflow(self, model: Any, config: Config) -> Workflow:
-        """Convert the task to a workflow.
-
-        Args:
-            model (ModelWrapper): The rollout model for the workflow.
-            config (Config): The global configuration.
-
-        Returns:
-            Workflow: The generated workflow object.
- """ - if self.task_type == TaskType.EVAL: - repeat_times = 1 - else: - repeat_times = config.explorer.repeat_times - return self.workflow( - model=model, - task_desc=self.task_desc, - truth=self.truth, - reward_fn=self.reward_fn, - raw=self.raw, - repeat_times=repeat_times, - config=config, - is_eval=self.task_type == TaskType.EVAL, - ) diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py index d44440c826..944037ac2e 100644 --- a/trinity/common/verl_config.py +++ b/trinity/common/verl_config.py @@ -301,10 +301,14 @@ def synchronize_config(self, config: Config) -> None: self.trainer.default_local_dir = config.model.checkpoint_path self.trainer.sft_warmup_steps = config.trainer.sft_warmup_steps self.actor_rollout_ref.actor.ppo_mini_batch_size = config.global_config.batch_size - self.actor_rollout_ref.rollout.temperature = config.explorer.temperature - self.actor_rollout_ref.rollout.n = config.explorer.repeat_times + self.actor_rollout_ref.rollout.temperature = ( + config.buffer.explorer_input.taskset.rollout_args.temperature + ) + self.actor_rollout_ref.rollout.n = ( + config.buffer.explorer_input.taskset.rollout_args.repeat_times + ) self.critic.ppo_mini_batch_size = config.global_config.batch_size - self.critic.rollout_n = config.explorer.repeat_times + self.critic.rollout_n = self.actor_rollout_ref.rollout.n self.actor_rollout_ref.actor.algorithm_type = config.trainer.algorithm_type if config.trainer.algorithm_type == AlgorithmType.PPO: diff --git a/trinity/common/workflows/__init__.py b/trinity/common/workflows/__init__.py index a8bcd886a2..92bf29a64e 100644 --- a/trinity/common/workflows/__init__.py +++ b/trinity/common/workflows/__init__.py @@ -3,9 +3,10 @@ from .envs.alfworld.alfworld_workflow import AlfworldWorkflow from .envs.sciworld.sciworld_workflow import SciWorldWorkflow from .envs.webshop.webshop_workflow import WebShopWorkflow -from .workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow +from .workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow, Task __all__ = [ + "Task", "WORKFLOWS", "SimpleWorkflow", "MathWorkflow", diff --git a/trinity/common/workflows/envs/alfworld/alfworld_workflow.py b/trinity/common/workflows/envs/alfworld/alfworld_workflow.py index a7e45e4a7f..9f8b389858 100644 --- a/trinity/common/workflows/envs/alfworld/alfworld_workflow.py +++ b/trinity/common/workflows/envs/alfworld/alfworld_workflow.py @@ -3,7 +3,7 @@ from trinity.common.experience import Experience from trinity.common.models.model import ModelWrapper -from trinity.common.workflows.workflow import WORKFLOWS, MultiTurnWorkflow +from trinity.common.workflows.workflow import WORKFLOWS, MultiTurnWorkflow, Task EXAMPLE_PROMPT = """ Observation: @@ -96,13 +96,17 @@ def parse_action(response): class AlfworldWorkflow(MultiTurnWorkflow): """A workflow for alfworld task.""" - def __init__(self, model: ModelWrapper, **kwargs): - super().__init__(model) - self.system_prompt = kwargs.get("system_prompt", None) # Unuse here - self.task_desc: str = kwargs.get("task_desc") - self.truth = kwargs.get("truth") # Unuse here - self.reward_fn = None # Unuse here - self.repeat_times = kwargs.get("repeat_times", 1) + def __init__( + self, + model: ModelWrapper, + task: Task, + ): + super().__init__( + model=model, + task=task, + ) + self.task_desc = task.task_desc or "0" + self.repeat_times = task.rollout_args.repeat_times self.max_env_steps = 30 def get_model_response(self, messages): diff --git a/trinity/common/workflows/envs/sciworld/sciworld_workflow.py 
b/trinity/common/workflows/envs/sciworld/sciworld_workflow.py index 06a0748f4c..b3669f01d0 100644 --- a/trinity/common/workflows/envs/sciworld/sciworld_workflow.py +++ b/trinity/common/workflows/envs/sciworld/sciworld_workflow.py @@ -4,7 +4,7 @@ from trinity.common.experience import Experience from trinity.common.models.model import ModelWrapper -from trinity.common.workflows.workflow import WORKFLOWS, MultiTurnWorkflow +from trinity.common.workflows.workflow import WORKFLOWS, MultiTurnWorkflow, Task SCIWORLD_SYSTEM_PROMPT = """ You are an agent, you job is to do some scientific experiment in a virtual test-based environments. @@ -59,13 +59,17 @@ def parse_action(response): class SciWorldWorkflow(MultiTurnWorkflow): """A workflow for sciworld task.""" - def __init__(self, model: ModelWrapper, **kwargs): - super().__init__(model) - self.system_prompt = kwargs.get("system_prompt", None) # Unuse here - self.task_desc: str = kwargs.get("task_desc") - self.truth = kwargs.get("truth") # Unuse here - self.reward_fn = None # Unuse here - self.repeat_times = kwargs.get("repeat_times", 1) + def __init__( + self, + model: ModelWrapper, + task: Task, + ): + super().__init__( + model=model, + task=task, + ) + self.task_desc = task.task_desc or "0" + self.repeat_times = task.rollout_args.repeat_times self.max_env_steps = 30 # should be less than 100 def get_model_response(self, messages): diff --git a/trinity/common/workflows/envs/webshop/webshop_workflow.py b/trinity/common/workflows/envs/webshop/webshop_workflow.py index 741fcd4b05..55035fd7b4 100644 --- a/trinity/common/workflows/envs/webshop/webshop_workflow.py +++ b/trinity/common/workflows/envs/webshop/webshop_workflow.py @@ -3,7 +3,7 @@ from trinity.common.experience import Experience from trinity.common.models.model import ModelWrapper -from trinity.common.workflows.workflow import WORKFLOWS, MultiTurnWorkflow +from trinity.common.workflows.workflow import WORKFLOWS, MultiTurnWorkflow, Task SPARSE_REWARD = False @@ -181,13 +181,17 @@ def validate_action(action, available_actions): class WebShopWorkflow(MultiTurnWorkflow): """A workflow for webshop task.""" - def __init__(self, model: ModelWrapper, **kwargs): - super().__init__(model) - self.system_prompt = kwargs.get("system_prompt", None) # Unuse here - self.task_desc: str = kwargs.get("task_desc", "0") - self.truth = kwargs.get("truth") # Unuse here - self.reward_fn = None # Unuse here - self.repeat_times = kwargs.get("repeat_times", 1) + def __init__( + self, + model: ModelWrapper, + task: Task, + ): + super().__init__( + model=model, + task=task, + ) + self.task_desc = task.task_desc or "0" + self.repeat_times = task.rollout_args.repeat_times self.max_env_steps = 15 def get_model_response(self, messages): diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py index 4fa3a7c2f5..905fe2e5b8 100644 --- a/trinity/common/workflows/workflow.py +++ b/trinity/common/workflows/workflow.py @@ -1,11 +1,15 @@ # -*- coding: utf-8 -*- """Base Workflow Class""" +from __future__ import annotations + from abc import ABC, abstractmethod -from typing import List +from dataclasses import asdict, dataclass, field +from typing import Any, List, Optional, Type, Union import torch +from trinity.common.config import FormatConfig, GenerationConfig from trinity.common.experience import Experience from trinity.common.models.model import ModelWrapper from trinity.common.rewards.reward_fn import MathRewardFn, RewardFn @@ -18,13 +22,53 @@ WORKFLOWS = Registry("workflows") +@dataclass 
+class Task:
+    """A Task class that defines a task and its associated reward function / workflow."""
+
+    workflow: Type[Workflow]
+    format_args: FormatConfig
+    rollout_args: GenerationConfig = field(default_factory=GenerationConfig)
+    is_eval: bool = False
+    reward_fn: Optional[Type[RewardFn]] = None
+    raw_task: Optional[dict] = None  # The raw data sample
+
+    def to_workflow(self, model: Any) -> Workflow:
+        """Convert the task to a workflow.
+
+        Args:
+            model (ModelWrapper): The rollout model for the workflow.
+
+        Returns:
+            Workflow: The generated workflow object.
+        """
+        return self.workflow(
+            model=model,
+            task=self,
+        )
+
+    @property
+    def task_desc(self) -> Union[str, None]:
+        prompt_key = self.format_args.prompt_key
+        return self.raw_task[prompt_key] if prompt_key in self.raw_task else None  # type: ignore
+
+    @property
+    def truth(self) -> Union[str, None]:
+        response_key = self.format_args.response_key
+        return self.raw_task[response_key] if response_key in self.raw_task else None  # type: ignore
+
+
 class Workflow(ABC):
     """The base workflow class.

     A workflow is a runnable object which generates a list of experiences.
     """

-    def __init__(self, model: ModelWrapper, **kwargs):
+    def __init__(
+        self,
+        model: ModelWrapper,
+        task: Task,
+    ):
         self.model = model

     @abstractmethod
@@ -37,8 +81,15 @@ class MultiTurnWorkflow(Workflow):
     The base workflow class for multi-turn tasks.
     """

-    def __init__(self, model: ModelWrapper, **kwargs):
-        super().__init__(model)
+    def __init__(
+        self,
+        model: ModelWrapper,
+        task: Task,
+    ):
+        super().__init__(
+            model=model,
+            task=task,
+        )

     @abstractmethod
     def run(self) -> List[Experience]:
@@ -81,34 +132,46 @@ class SimpleWorkflow(Workflow):
     def __init__(
         self,
         model: ModelWrapper,
-        **kwargs,
+        task: Task,
     ):
-        super().__init__(model)
-        self.system_prompt = kwargs.get("system_prompt", None)
-        self.reply_prefix = kwargs.get("reply_prefix", None)  # TODO: add reply_prefix
-        self.task_desc = kwargs.get("task_desc")
-        self.truth = kwargs.get("truth")
-        self.reward_fn = kwargs.get("reward_fn")
-        if isinstance(self.reward_fn, type) and issubclass(self.reward_fn, RewardFn):
-            self.reward_fn = self.reward_fn()
+        super().__init__(
+            model=model,
+            task=task,
+        )
+        self.format_args = task.format_args
+        self.system_prompt = task.format_args.system_prompt
+        self.reply_prefix = task.format_args.reply_prefix
+
+        self.raw_task = task.raw_task
+        self.task_desc = task.task_desc
+        self.truth = task.truth
+
+        reward_fn = task.reward_fn
+        if isinstance(reward_fn, type) and issubclass(reward_fn, RewardFn):
+            self.reward_fn: RewardFn = reward_fn()
         else:
             raise ValueError("`reward_fn` must be a subclass of `RewardFn`")
-        # Rollout n times
-        self.repeat_times = kwargs.get("repeat_times", 1)
-        self.is_eval = kwargs.get("is_eval", False)
+        # Rollout args
+        rollout_args = asdict(task.rollout_args)
+        rollout_args["n"] = rollout_args["repeat_times"]
+        self.rollout_args = rollout_args
+        self.is_eval = task.is_eval

-    def run(self) -> List[Experience]:
-        # TODO: Optimize the generate function
+    def format_messages(self):
         messages = []
         if self.system_prompt:
             messages.append({"role": "system", "content": self.system_prompt})
         messages.append({"role": "user", "content": self.task_desc})
         if self.reply_prefix:
             messages.append({"role": "assistant", "content": self.reply_prefix})
+        return messages
+
+    def run(self) -> List[Experience]:
+        # TODO: Optimize the generate function
+        messages = self.format_messages()
         logger.debug("start chat")
-        n = 1 if self.is_eval else self.repeat_times
-        responses = self.model.chat(messages, n=n)
+        responses = self.model.chat(messages, **self.rollout_args)
         for response in responses:
             reward = self.reward_fn(  # type: ignore [misc]
                 response=response.response_text,  # type: ignore [arg-type]
@@ -134,18 +197,16 @@ class MathWorkflow(SimpleWorkflow):
     def __init__(
         self,
         model: ModelWrapper,
-        **kwargs,
+        task: Task,
     ):
-        if kwargs.get("reward_fn", None) is None:
-            kwargs["reward_fn"] = MathRewardFn
-        if kwargs["reward_fn"] == MathRewardFn and kwargs.get("system_prompt", None) is None:
-            kwargs[
-                "system_prompt"
-            ] = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>.
+        if task.reward_fn is None:
+            task.reward_fn = MathRewardFn
+        if task.reward_fn == MathRewardFn and task.format_args.system_prompt is None:
+            task.format_args.system_prompt = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>.
 """
         super().__init__(
-            model,
-            **kwargs,
+            model=model,
+            task=task,
         )
diff --git a/trinity/explorer/runner_pool.py b/trinity/explorer/runner_pool.py
index 5805f71907..c7b3f39f1f 100644
--- a/trinity/explorer/runner_pool.py
+++ b/trinity/explorer/runner_pool.py
@@ -5,7 +5,7 @@
 import ray

 from trinity.common.config import Config
-from trinity.common.task import Task
+from trinity.common.workflows import Task
 from trinity.explorer.workflow_runner import Status, WorkflowRunner
 from trinity.utils.log import get_logger

diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py
index 86e4003df4..87e80aaf9b 100644
--- a/trinity/explorer/workflow_runner.py
+++ b/trinity/explorer/workflow_runner.py
@@ -11,10 +11,9 @@

 from trinity.buffer import get_buffer_writer
 from trinity.common.config import Config
-from trinity.common.constants import TaskType
 from trinity.common.experience import Experience
 from trinity.common.models.model import InferenceModel, ModelWrapper
-from trinity.common.task import Task
+from trinity.common.workflows import Task
 from trinity.utils.log import get_logger

@@ -48,7 +47,7 @@ def _run_task(self, task: Task) -> List[Experience]:
         """Init workflow from the task and run it."""
         if task.workflow is None:
             raise ValueError("Workflow is not set in the task.")
-        workflow = task.to_workflow(self.model_wrapper, self.config)
+        workflow = task.to_workflow(self.model_wrapper)
         return workflow.run()

     def run_task(self, task: Task) -> Status:
@@ -77,7 +76,7 @@ def run_task(self, task: Task) -> Status:
             if metrics:
                 for k, v in metrics.items():
                     metric[k] = sum(v) / len(v)  # type: ignore
-            if not task.task_type == TaskType.EVAL:
+            if not task.is_eval:
                 self.experience_buffer.write(exps)
             return Status(True, metric=metric)
         except Exception as e:
diff --git a/trinity/manager/config_manager.py b/trinity/manager/config_manager.py
index 1f55b3c702..1f1e48d61b 100644
--- a/trinity/manager/config_manager.py
+++ b/trinity/manager/config_manager.py
@@ -73,10 +73,12 @@ def _init_default_config(self):
             "taskset_prompt_key": "question",
             "taskset_response_key": "answer",
             # Eval Taskset Configs
-            # TODO
-            # Task Workflow Configs
+            "_eval_tasksets_num": 0,
+            # Explorer Input Configs
             "default_workflow_type": "math_workflow",
             "default_reward_fn_type": "math_reward",
+            "system_prompt": None,
+            "reply_prefix": None,
             # Experience Buffer Configs
             "_dpo_storage_type": StorageType.FILE.value,
             "_not_dpo_storage_type": StorageType.QUEUE.value,
@@ -197,6 +199,11 @@ def reset_session_state(self):

     def maintain_session_state(self):
         for key in self.default_config:
             st.session_state[key] = st.session_state[key]
+        eval_dataset_keys = ["name", "path", "subset_name", "split", "prompt_key", "response_key"]
+        for idx in range(st.session_state["_eval_tasksets_num"]):
+            for key in eval_dataset_keys:
+                full_key = f"eval_taskset_{idx}_{key}"
+                st.session_state[full_key] = st.session_state[full_key]

     def _set_project(self):
         st.text_input("Project", key="project")
@@ -311,11 +318,29 @@ def _set_taskset_path(self):
                 self.unfinished_fields.add("taskset_path")
                 st.warning("Please input taskset path.")

+    def _set_system_prompt(self):
+        st.text_area(
+            "System Prompt",
+            key="system_prompt",
+            placeholder="System prompt is used to guide the model behavior.",
+        )
+
+    def _set_reply_prefix(self):
+        st.text_area(
+            "Assistant Reply Prefix",
+            key="reply_prefix",
+            placeholder="""Assistant reply prefix is used to specify the initial content of the model's reply, """
+            """and a common setting is: \nLet me solve this step by step. """,
+        )
+
     def _set_taskset_args(self):
         if st.session_state["taskset_path"] and "://" not in st.session_state["taskset_path"]:
             subset_name_col, split_col = st.columns(2)
             subset_name_col.text_input(
-                "Subset Name :orange-badge[(Needs review)]", key="taskset_subset_name"
+                "Subset Name :orange-badge[(Needs review)]",
+                key="taskset_subset_name",
+                help="The subset name used for `datasets.load_dataset`, see "
+                "[here](https://huggingface.co/docs/datasets/v3.5.0/en/package_reference/loading_methods#datasets.load_dataset.name) for details.",
             )
             split_col.text_input("Train Split :orange-badge[(Needs review)]", key="taskset_split")
             prompt_key_col, response_key_col = st.columns(2)
@@ -326,6 +351,49 @@ def _set_taskset_args(self):
                 "Response Key :orange-badge[(Needs review)]", key="taskset_response_key"
             )

+    def _set_eval_taskset_idx(self, idx):
+        st.text_input(
+            "Taskset Name",
+            key=f"eval_taskset_{idx}_name",
+        )
+        st.text_input(
+            "Eval Taskset Path",
+            key=f"eval_taskset_{idx}_path",
+        )
+        if not st.session_state[f"eval_taskset_{idx}_path"].strip():
+            st.warning("Please input the taskset path, or it will be ignored.")
+        subset_name_col, split_col = st.columns(2)
+        subset_name_col.text_input(
+            "Subset Name :orange-badge[(Needs review)]",
+            key=f"eval_taskset_{idx}_subset_name",
+            help="The subset name used for `datasets.load_dataset`, see "
+            "[here](https://huggingface.co/docs/datasets/v3.5.0/en/package_reference/loading_methods#datasets.load_dataset.name) for details.",
+        )
+        split_col.text_input(
+            "Eval Split :orange-badge[(Needs review)]",
+            key=f"eval_taskset_{idx}_split",
+        )
+        prompt_key_col, response_key_col = st.columns(2)
+        prompt_key_col.text_input(
+            "Prompt Key :orange-badge[(Needs review)]",
+            key=f"eval_taskset_{idx}_prompt_key",
+        )
+        response_key_col.text_input(
+            "Response Key :orange-badge[(Needs review)]",
+            key=f"eval_taskset_{idx}_response_key",
+        )
+
+    def _set_eval_tasksets(self):
+        if st.button("Add Eval Taskset"):
+            st.session_state["_eval_tasksets_num"] += 1
+        if st.session_state["_eval_tasksets_num"] > 0:
+            tabs = st.tabs(
+                [f"Eval Taskset {i + 1}" for i in range(st.session_state["_eval_tasksets_num"])]
+            )
+            for idx, tab in enumerate(tabs):
+                with tab:
+                    self._set_eval_taskset_idx(idx)
+
     def _set_default_workflow_type(self):
         st.selectbox(
             "Default Workflow Type :orange-badge[(Needs review)]",
@@ -1131,6 +1199,10 @@ def _expert_buffer_part(self):
         self._set_configs_with_st_columns(["total_epochs", "train_batch_size"])
         self._check_train_batch_size()

+        self._set_configs_with_st_columns(["default_workflow_type", "default_reward_fn_type"])
+        self._set_system_prompt()
+        self._set_reply_prefix()
+
         if st.session_state["algorithm_type"] != AlgorithmType.DPO.value:
             with st.expander("Taskset Configs", expanded=True):
                 self._set_taskset_path()
@@ -1141,8 +1213,7 @@ def _expert_buffer_part(self):
                 self._set_dpo_dataset_kwargs()

         with st.expander("Eval Tasksets Configs", expanded=True):
-            # TODO:
-            pass
+            self._set_eval_tasksets()

         with st.expander("SFT Dataset Configs"):
             self._set_sft_warmup_dataset_path()
@@ -1153,8 +1224,6 @@ def _expert_buffer_part(self):
             self._set_storage_type()
             self._set_experience_buffer_path()

-        self._set_configs_with_st_columns(["default_workflow_type", "default_reward_fn_type"])
-
         self.buffer_advanced_tab = st.expander("Advanced Config")
         with self.buffer_advanced_tab:
             self._set_configs_with_st_columns(["buffer_max_retry_times", "max_retry_interval"])
@@ -1614,10 +1683,19 @@ def generate_config(self):
                         "prompt_key": st.session_state["taskset_prompt_key"],
                         "response_key": st.session_state["taskset_response_key"],
                     },
+                    "rollout_args": {
+                        "repeat_times": st.session_state["repeat_times"],
+                        "temperature": st.session_state["temperature"],
+                        "top_p": st.session_state["top_p"],
+                        "top_k": st.session_state["top_k"],
+                        "logprobs": st.session_state["logprobs"],
+                    },
                 },
                 "eval_tasksets": [],  # TODO: add eval tasksets
                 "default_workflow_type": st.session_state["default_workflow_type"],
                 "default_reward_fn_type": st.session_state["default_reward_fn_type"],
+                "system_prompt": st.session_state["system_prompt"],
+                "reply_prefix": st.session_state["reply_prefix"],
             },
             "trainer_input": {
                 "experience_buffer": {
@@ -1631,26 +1709,21 @@ def generate_config(self):
             "engine_type": st.session_state["engine_type"],
             "engine_num": st.session_state["engine_num"],
             "runner_num": st.session_state["runner_num"],
-            "repeat_times": st.session_state["repeat_times"],
             # "chat_template": None,  # TODO: add chat template
             "tensor_parallel_size": st.session_state["tensor_parallel_size"],
             "enable_prefix_caching": st.session_state["enable_prefix_caching"],
             "enforce_eager": st.session_state["enforce_eager"],
             "dtype": st.session_state["dtype"],
-            "temperature": st.session_state["temperature"],
-            "top_p": st.session_state["top_p"],  # TODO
-            "top_k": st.session_state["top_k"],  # TODO
             "seed": st.session_state["seed"],
-            "logprobs": st.session_state["logprobs"],
             "backend": st.session_state["backend"],
-            "use_ray": st.session_state["use_ray"],  # TODO
-            "gpu_memory_utilization": st.session_state["gpu_memory_utilization"],  # TODO
-            "enable_chunked_prefill": st.session_state["enable_chunked_prefill"],  # TODO
+            "use_ray": st.session_state["use_ray"],
+            "gpu_memory_utilization": st.session_state["gpu_memory_utilization"],
+            "enable_chunked_prefill": st.session_state["enable_chunked_prefill"],
             "use_v1": True,
             "max_pending_requests": st.session_state["max_pending_requests"],
             "max_waiting_steps": st.session_state["max_waiting_steps"],
-            "max_timeout": st.session_state["max_timeout"],  # TODO
-            "max_retry_times": st.session_state["explorer_max_retry_times"],  # TODO
+            "max_timeout": st.session_state["max_timeout"],
+            "max_retry_times": st.session_state["explorer_max_retry_times"],
         },
         "synchronizer": {
             "sync_method": st.session_state["sync_method"],
@@ -1671,6 +1744,18 @@ def generate_config(self):
             },
         }

+        for idx in range(st.session_state["_eval_tasksets_num"]):
+            if st.session_state[f"eval_taskset_{idx}_path"].strip():
+                config["buffer"]["explorer_input"]["eval_tasksets"].append(
+                    {
+                        "name": st.session_state[f"eval_taskset_{idx}_name"],
+                        "path": st.session_state[f"eval_taskset_{idx}_path"],
+                        "subset_name": st.session_state[f"eval_taskset_{idx}_subset_name"],
+                        "split": st.session_state[f"eval_taskset_{idx}_split"],
+                        "prompt_key": st.session_state[f"eval_taskset_{idx}_prompt_key"],
+                        "response_key": st.session_state[f"eval_taskset_{idx}_response_key"],
+                    }
+                )
         if st.session_state["algorithm_type"] == AlgorithmType.DPO.value:
             experience_buffer = config["buffer"]["trainer_input"]["experience_buffer"]
             experience_buffer["split"] = st.session_state["dpo_dataset_train_split"]