diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md
index 7a388390d8..430f872489 100644
--- a/docs/sphinx_doc/source/tutorial/trinity_configs.md
+++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -105,6 +105,10 @@ buffer:
format:
prompt_key: 'question'
response_key: 'answer'
+ rollout_args:
+ repeat_times: 1
+ temperature: 1.0
+ logprobs: 0
eval_tasksets: []
default_workflow_type: 'math_workflow'
default_reward_fn_type: 'countdown_reward'
@@ -123,6 +127,9 @@ buffer:
- `buffer.explorer_input.taskset.path`: The path to the taskset.
- `buffer.explorer_input.taskset.split`: The split name of the taskset used for training. Default is `train`.
- `buffer.explorer_input.taskset.format`: The format of the taskset. It includes `prompt_key`, `response_key`, `workflow_key` and `reward_fn_key`.
+- `buffer.explorer_input.taskset.rollout_args.repeat_times`: The number of times to repeat each task, used for GRPO-like algorithms. Default is `1`.
+- `buffer.explorer_input.taskset.rollout_args.temperature`: The sampling temperature used by vLLM. Default is `1.0`.
+- `buffer.explorer_input.taskset.rollout_args.logprobs`: The number of log probabilities returned by vLLM per token (vLLM returns `logprobs + 1` elements). Default is `0`.
- `buffer.explorer_input.eval_tasksets`: The configuration of the eval tasksets. It is a list of tasksets which will be used for evaluation. And it is empty by default.
- `buffer.explorer_input.default_workflow_type`: The default workflow type for `taskset` and `eval_tasksets`.
- `buffer.explorer_input.default_reward_fn_type`: The default reward function type for `taskset` and `eval_tasksets`.
@@ -145,10 +152,7 @@ explorer:
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
- temperature: 1.0
seed: 42
- logprobs: 0
- repeat_times: 5
use_ray: false
backend: 'nccl'
max_pending_requests: 32
@@ -162,10 +166,7 @@ explorer:
- `explorer.enable_prefix_caching`: Whether to enable prefix caching. Default is `False`.
- `explorer.enforce_eager`: Whether to enforce eager mode. Default is `True`.
- `explorer.dtype`: The data type used in vLLM. Default is `bfloat16`.
-- `explorer.temperature`: The temperature used in vLLM. Default is `1.0`.
- `explorer.seed`: The seed used in vLLM. Default is `42`.
-- `explorer.logprobs`: The logprobs used in vLLM. Default is `0`.
-- `explorer.repeat_times`: The number of times to repeat each task, used for GRPO-like algorithms. Default is `5`.
- `explorer.use_ray`: Whether to use Ray. Default is `False`.
- `explorer.backend`: The backend used in vLLM. Default is `nccl`.
- `explorer.max_pending_requests`: The maximum number of pending requests. Default is `32`.
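For orientation, a minimal sketch of how a taskset's `rollout_args` become per-call sampling kwargs; it assumes the `GenerationConfig` dataclass this patch adds in `trinity/common/config.py` and mirrors the `asdict`-based conversion in `SimpleWorkflow`, with the `model.chat` call site left hypothetical:

```python
# Sketch only: mirrors SimpleWorkflow's asdict-based conversion of
# rollout_args into sampling kwargs; GenerationConfig is the dataclass
# added in trinity/common/config.py.
from dataclasses import asdict, dataclass


@dataclass
class GenerationConfig:
    repeat_times: int = 1   # rollouts per task (GRPO-like algorithms)
    temperature: float = 1.0
    top_p: float = 1.0
    top_k: int = -1
    logprobs: int = 0       # vLLM returns `logprobs + 1` elements


rollout_args = asdict(GenerationConfig(repeat_times=8, temperature=1.0))
rollout_args["n"] = rollout_args["repeat_times"]  # vLLM's name for samples per prompt
# responses = model.chat(messages, **rollout_args)  # hypothetical call site
```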
diff --git a/examples/async_gsm8k/explorer.yaml b/examples/async_gsm8k/explorer.yaml
index 65dc61ce94..673da76a59 100644
--- a/examples/async_gsm8k/explorer.yaml
+++ b/examples/async_gsm8k/explorer.yaml
@@ -23,6 +23,10 @@ buffer:
format:
prompt_key: 'question'
response_key: 'answer'
+ rollout_args:
+ repeat_times: 8
+ temperature: 1.0
+ logprobs: 0
default_workflow_type: 'math_workflow'
trainer_input:
experience_buffer:
@@ -37,10 +41,7 @@ explorer:
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
- temperature: 1.0
seed: 42
- logprobs: 0
- repeat_times: 8
use_ray: false
backend: 'nccl'
max_pending_requests: 32
diff --git a/examples/async_gsm8k/trainer.yaml b/examples/async_gsm8k/trainer.yaml
index 85a9afbf11..df193c3f37 100644
--- a/examples/async_gsm8k/trainer.yaml
+++ b/examples/async_gsm8k/trainer.yaml
@@ -22,6 +22,10 @@ buffer:
format:
prompt_key: 'question'
response_key: 'answer'
+ rollout_args:
+ repeat_times: 8
+ temperature: 1.0
+ logprobs: 0
default_workflow_type: 'math_workflow'
trainer_input:
experience_buffer:
@@ -36,10 +40,7 @@ explorer:
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
- temperature: 1.0
seed: 42
- logprobs: 0
- repeat_times: 8
use_ray: false
backend: 'nccl'
max_pending_requests: 32
diff --git a/examples/dpo_humanlike/dpo.yaml b/examples/dpo_humanlike/dpo.yaml
index 1e812201a0..de459f9230 100644
--- a/examples/dpo_humanlike/dpo.yaml
+++ b/examples/dpo_humanlike/dpo.yaml
@@ -23,22 +23,6 @@ buffer:
prompt_key: prompt
chosen_key: chosen
rejected_key: rejected
-explorer:
- engine_type: vllm_async
- engine_num: 0
- runner_num: 32
- tensor_parallel_size: 1
- enable_prefix_caching: false
- enforce_eager: true
- dtype: bfloat16
- temperature: 1.0
- seed: 42
- logprobs: 0
- repeat_times: 1 # NOTE
- use_ray: false
- backend: 'nccl'
- max_pending_requests: 32
- max_waiting_steps: 4
synchronizer:
sync_method: 'checkpoint'
sync_interval: 30
diff --git a/examples/grpo_alfworld/alfworld.yaml b/examples/grpo_alfworld/alfworld.yaml
index b083d7874d..18dc2595e6 100644
--- a/examples/grpo_alfworld/alfworld.yaml
+++ b/examples/grpo_alfworld/alfworld.yaml
@@ -19,6 +19,10 @@ buffer:
path: 'scripts/data_prepare/alfworld_data'
format:
prompt_key: 'game_file'
+ rollout_args:
+ repeat_times: 8
+ temperature: 1.0
+ logprobs: 0
default_workflow_type: 'alfworld_workflow'
trainer_input:
experience_buffer:
@@ -33,10 +37,7 @@ explorer:
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
- temperature: 1.0
seed: 42
- logprobs: 0
- repeat_times: 8
use_ray: false
backend: 'nccl'
max_pending_requests: 32
diff --git a/examples/grpo_gsm8k/gsm8k.yaml b/examples/grpo_gsm8k/gsm8k.yaml
index 7e25bfaaee..748a5ac0e5 100644
--- a/examples/grpo_gsm8k/gsm8k.yaml
+++ b/examples/grpo_gsm8k/gsm8k.yaml
@@ -36,6 +36,10 @@ buffer:
format:
prompt_key: 'question'
response_key: 'answer'
+ rollout_args:
+ repeat_times: 8
+ temperature: 1.0
+ logprobs: 0
eval_tasksets:
- name: gsm8k-eval
storage_type: file
@@ -65,10 +69,7 @@ explorer:
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
- temperature: 1.0
seed: 42
- logprobs: 0
- repeat_times: 8
use_ray: false
backend: 'nccl'
max_pending_requests: 32
diff --git a/examples/grpo_math/math.yaml b/examples/grpo_math/math.yaml
index 06f88fe818..b22291b09a 100644
--- a/examples/grpo_math/math.yaml
+++ b/examples/grpo_math/math.yaml
@@ -21,6 +21,10 @@ buffer:
format:
prompt_key: 'question'
response_key: 'gt_answer'
+ rollout_args:
+ repeat_times: 8
+ temperature: 1.0
+ logprobs: 0
default_workflow_type: 'math_workflow'
trainer_input:
experience_buffer:
@@ -35,10 +39,7 @@ explorer:
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
- temperature: 1.0
seed: 42
- logprobs: 0
- repeat_times: 8
use_ray: false
backend: 'nccl'
max_pending_requests: 32
diff --git a/examples/grpo_sciworld/sciworld.yaml b/examples/grpo_sciworld/sciworld.yaml
index ffe30d44f0..6ba88d51e2 100644
--- a/examples/grpo_sciworld/sciworld.yaml
+++ b/examples/grpo_sciworld/sciworld.yaml
@@ -19,6 +19,10 @@ buffer:
path: 'scripts/data_prepare/sciworld_data'
format:
prompt_key: 'game_file'
+ rollout_args:
+ repeat_times: 8
+ temperature: 1.0
+ logprobs: 0
default_workflow_type: 'sciworld_workflow'
trainer_input:
experience_buffer:
@@ -33,10 +37,7 @@ explorer:
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
- temperature: 1.0
seed: 42
- logprobs: 0
- repeat_times: 8
use_ray: false
backend: 'nccl'
max_pending_requests: 32
diff --git a/examples/grpo_webshop/webshop.yaml b/examples/grpo_webshop/webshop.yaml
index 451495433f..d5b59d67b0 100644
--- a/examples/grpo_webshop/webshop.yaml
+++ b/examples/grpo_webshop/webshop.yaml
@@ -19,6 +19,10 @@ buffer:
path: 'scripts/data_prepare/webshop_data'
format:
prompt_key: 'task_id'
+ rollout_args:
+ repeat_times: 8
+ temperature: 1.0
+ logprobs: 0
default_workflow_type: 'webshop_workflow'
trainer_input:
experience_buffer:
@@ -33,10 +37,7 @@ explorer:
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
- temperature: 1.0
seed: 42
- logprobs: 0
- repeat_times: 8
use_ray: false
backend: 'nccl'
max_pending_requests: 32
diff --git a/examples/opmd_gsm8k/opmd_gsm8k.yaml b/examples/opmd_gsm8k/opmd_gsm8k.yaml
index dcfeee47db..4739400f1a 100644
--- a/examples/opmd_gsm8k/opmd_gsm8k.yaml
+++ b/examples/opmd_gsm8k/opmd_gsm8k.yaml
@@ -20,6 +20,10 @@ buffer:
format:
prompt_key: 'question'
response_key: 'answer'
+ rollout_args:
+ repeat_times: 8
+ temperature: 1.0
+ logprobs: 0
default_workflow_type: 'math_workflow'
trainer_input:
experience_buffer:
@@ -34,10 +38,7 @@ explorer:
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
- temperature: 1.0
seed: 42
- logprobs: 0
- repeat_times: 8
use_ray: false
backend: 'nccl'
max_pending_requests: 32
diff --git a/examples/ppo_countdown/countdown.yaml b/examples/ppo_countdown/countdown.yaml
index 941c0ef97b..f7ad9c4362 100644
--- a/examples/ppo_countdown/countdown.yaml
+++ b/examples/ppo_countdown/countdown.yaml
@@ -21,6 +21,10 @@ buffer:
format:
prompt_key: 'question'
response_key: 'answer'
+ rollout_args:
+ repeat_times: 5
+ temperature: 1.0
+ logprobs: 0
default_workflow_type: 'math_workflow'
default_reward_fn_type: 'countdown_reward'
trainer_input:
@@ -36,10 +40,7 @@ explorer:
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
- temperature: 1.0
seed: 42
- logprobs: 0
- repeat_times: 5
use_ray: false
backend: 'nccl'
max_pending_requests: 32
diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py
index 9b1d5f9997..090cb6f2fd 100644
--- a/tests/common/vllm_test.py
+++ b/tests/common/vllm_test.py
@@ -85,8 +85,9 @@ def get_model_path() -> str:
class BaseTestModelWrapper:
def test_generate(self):
prompts = ["Hello, world!", "Hello, my name is"]
- results = self.model_wrapper.generate(prompts)
- self.assertEqual(len(results), len(prompts) * self.config.explorer.repeat_times)
+ repeat_times = self.config.buffer.explorer_input.taskset.rollout_args.repeat_times
+ results = self.model_wrapper.generate(prompts, n=repeat_times, temperature=1.0)
+ self.assertEqual(len(results), len(prompts) * repeat_times)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What's the weather like today?"},
@@ -96,8 +97,8 @@ def test_generate(self):
},
{"role": "user", "content": "OK, thanks!"},
]
- results = self.model_wrapper.chat(messages)
- self.assertEqual(len(results), self.config.explorer.repeat_times)
+ results = self.model_wrapper.chat(messages, n=repeat_times, temperature=1.0)
+ self.assertEqual(len(results), repeat_times)
for result in results:
input_logprobs = result.logprobs[: result.prompt_length]
output_logprobs = result.logprobs[result.prompt_length :]
@@ -135,7 +136,7 @@ def setUp(self):
self.config.explorer.engine_type = "vllm"
self.config.explorer.tensor_parallel_size = 1
self.config.explorer.engine_num = 2
- self.config.explorer.repeat_times = 2
+ self.config.buffer.explorer_input.taskset.rollout_args.repeat_times = 2
self.config.explorer.use_v1 = False
self.config.explorer.chat_template = CHAT_TEMPLATE
self.engines = create_rollout_models(self.config)
@@ -149,7 +150,7 @@ def setUp(self):
self.config.explorer.engine_type = "vllm_async"
self.config.explorer.engine_num = 2
self.config.explorer.tensor_parallel_size = 1
- self.config.explorer.repeat_times = 2
+ self.config.buffer.explorer_input.taskset.rollout_args.repeat_times = 2
self.config.explorer.use_v1 = False
self.config.explorer.chat_template = CHAT_TEMPLATE
self.engines = create_rollout_models(self.config)
@@ -176,7 +177,7 @@ def setUp(self):
self.config.explorer.engine_type = "vllm_async"
self.config.explorer.engine_num = 2
self.config.explorer.tensor_parallel_size = 2
- self.config.explorer.repeat_times = 2
+ self.config.buffer.explorer_input.taskset.rollout_args.repeat_times = 2
self.config.explorer.use_v1 = True
self.config.explorer.chat_template = CHAT_TEMPLATE
self.engines = create_rollout_models(self.config)
diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py
index 013aa29f64..f4570571a1 100644
--- a/tests/explorer/explorer_test.py
+++ b/tests/explorer/explorer_test.py
@@ -22,7 +22,7 @@ def setUp(self):
self.config.global_config.batch_size = 4
self.config.model.model_path = get_model_path()
self.config.explorer.engine_type = "vllm_async"
- self.config.explorer.repeat_times = 2
+ self.config.buffer.explorer_input.taskset.rollout_args.repeat_times = 2
self.config.monitor.monitor_type = MonitorType.TENSORBOARD
self.config.monitor.project = "Trinity-unittest"
self.config.model.checkpoint_path = get_checkpoint_path()
diff --git a/tests/explorer/runner_pool_test.py b/tests/explorer/runner_pool_test.py
index 49f6f3d924..5d63fcdc63 100644
--- a/tests/explorer/runner_pool_test.py
+++ b/tests/explorer/runner_pool_test.py
@@ -6,12 +6,13 @@
import ray
import torch
+from tests.tools import get_unittest_dataset_config
from trinity.buffer.reader.queue_reader import QueueReader
from trinity.common.config import StorageConfig, load_config
from trinity.common.constants import AlgorithmType, StorageType
from trinity.common.experience import Experience
from trinity.common.models.model import InferenceModel
-from trinity.common.task import Task
+from trinity.common.workflows import Task
from trinity.common.workflows.workflow import WORKFLOWS, Workflow
from trinity.explorer.runner_pool import RunnerPool
@@ -20,9 +21,9 @@
@WORKFLOWS.register_module("dummy_workflow")
class DummyWorkflow(Workflow):
- def __init__(self, model, **kwargs):
- super().__init__(model)
- self.error_type = kwargs.get("task_desc")
+ def __init__(self, model, task):
+ super().__init__(model, task)
+ self.error_type = task.task_desc
self.seconds = None
if "timeout" in self.error_type:
self.seconds = int(self.error_type.split("_")[-1])
@@ -81,30 +82,61 @@ def setUp(self):
def test_runner_pool(self):
pool = RunnerPool(self.config, [DummyModel.remote(), DummyModel.remote()])
+ taskset_config = get_unittest_dataset_config("countdown")
tasks = [
Task(
- task_desc="timeout_100",
workflow=DummyWorkflow,
+ format_args=taskset_config.format,
+ rollout_args=taskset_config.rollout_args,
+ is_eval=False,
+ raw_task={
+ taskset_config.format.prompt_key: "timeout_100",
+ },
),
Task(
- task_desc="exception",
workflow=DummyWorkflow,
+ format_args=taskset_config.format,
+ rollout_args=taskset_config.rollout_args,
+ is_eval=False,
+ raw_task={
+ taskset_config.format.prompt_key: "exception",
+ },
),
Task(
- task_desc="timeout_2",
workflow=DummyWorkflow,
+ format_args=taskset_config.format,
+ rollout_args=taskset_config.rollout_args,
+ is_eval=False,
+ raw_task={
+ taskset_config.format.prompt_key: "timeout_2",
+ },
),
Task(
- task_desc="success",
workflow=DummyWorkflow,
+ format_args=taskset_config.format,
+ rollout_args=taskset_config.rollout_args,
+ is_eval=False,
+ raw_task={
+ taskset_config.format.prompt_key: "success",
+ },
),
Task(
- task_desc="timeout_101",
workflow=DummyWorkflow,
+ format_args=taskset_config.format,
+ rollout_args=taskset_config.rollout_args,
+ is_eval=False,
+ raw_task={
+ taskset_config.format.prompt_key: "timeout_101",
+ },
),
Task(
- task_desc="exit",
workflow=DummyWorkflow,
+ format_args=taskset_config.format,
+ rollout_args=taskset_config.rollout_args,
+ is_eval=False,
+ raw_task={
+ taskset_config.format.prompt_key: "exit",
+ },
),
]
diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py
index e27c75a95d..3c00733e54 100644
--- a/tests/explorer/workflow_test.py
+++ b/tests/explorer/workflow_test.py
@@ -4,7 +4,9 @@
from dataclasses import dataclass
from unittest.mock import MagicMock
+from tests.tools import get_unittest_dataset_config
from trinity.common.workflows import MathWorkflow
+from trinity.common.workflows.workflow import Task
@dataclass
@@ -27,9 +29,19 @@ def test_math_workflow(self) -> None:
MockResponse("\nOnly thinking\n"),
MockResponse("ThinkingAnswer is not end1"),
]
- workflow = MathWorkflow(model=model, task_desc="1+1=", truth="2")
+ taskset_config = get_unittest_dataset_config("countdown")
+ task = Task(
+ workflow=MathWorkflow,
+ format_args=taskset_config.format,
+ rollout_args=taskset_config.rollout_args,
+ is_eval=False,
+ raw_task={
+ taskset_config.format.prompt_key: "1+1=",
+ taskset_config.format.response_key: "2",
+ },
+ )
+ workflow = task.to_workflow(model=model)
experiences = workflow.run()
- print(experiences)
self.assertEqual(len(experiences), 9)
self.assertEqual(experiences[0].reward, 0.9)
self.assertEqual(experiences[1].reward, -0.1)
@@ -51,7 +63,18 @@ def test_math_fraction_workflow(self) -> None:
MockResponse(r"\boxed{\frac{1} {10}}"),
MockResponse(r"The answer is \boxed{\frac{40}{400}}"),
]
- workflow = MathWorkflow(model=model, task_desc=r"\frac{40}{400}", truth=r"\frac{40}{400}")
+ taskset_config = get_unittest_dataset_config("countdown")
+ task = Task(
+ workflow=MathWorkflow,
+ format_args=taskset_config.format,
+ rollout_args=taskset_config.rollout_args,
+ is_eval=False,
+ raw_task={
+ taskset_config.format.prompt_key: r"\frac{40}{400}",
+ taskset_config.format.response_key: r"\frac{40}{400}",
+ },
+ )
+ workflow = task.to_workflow(model=model)
experiences = workflow.run()
self.assertEqual(len(experiences), 6)
self.assertEqual(experiences[0].reward, 0.9)
@@ -68,11 +91,18 @@ def test_math_complex_workflow(self) -> None:
r"$\boxed{\dfrac{108 + 31\sqrt{5}}{216}} \quad \text{and} \quad \boxed{\dfrac{108 - 31\sqrt{5}}{216}}$"
),
]
- workflow = MathWorkflow(
- model=model,
- task_desc="",
- truth=r"$x_{1}=\frac{1}{2}+\frac{31\sqrt{5}}{216},\quadx_{2}=\frac{1}{2}-\frac{31\sqrt{5}}{216}$",
+ taskset_config = get_unittest_dataset_config("countdown")
+ task = Task(
+ workflow=MathWorkflow,
+ format_args=taskset_config.format,
+ rollout_args=taskset_config.rollout_args,
+ is_eval=False,
+ raw_task={
+ taskset_config.format.prompt_key: "",
+ taskset_config.format.response_key: r"$x_{1}=\frac{1}{2}+\frac{31\sqrt{5}}{216},\quadx_{2}=\frac{1}{2}-\frac{31\sqrt{5}}{216}$",
+ },
)
+ workflow = task.to_workflow(model=model)
experiences = workflow.run()
self.assertEqual(len(experiences), 1)
self.assertEqual(experiences[0].reward, 0.9)
@@ -84,11 +114,18 @@ def test_gsm8k_workflow(self) -> None:
MockResponse(" 36.0 "),
MockResponse("Kim's total points are 6 + 30 = 36 "),
]
- workflow = MathWorkflow(
- model=model,
- task_desc="",
- truth=r"36",
+ taskset_config = get_unittest_dataset_config("countdown")
+ task = Task(
+ workflow=MathWorkflow,
+ format_args=taskset_config.format,
+ rollout_args=taskset_config.rollout_args,
+ is_eval=False,
+ raw_task={
+ taskset_config.format.prompt_key: "",
+ taskset_config.format.response_key: r"36",
+ },
)
+ workflow = task.to_workflow(model=model)
experiences = workflow.run()
# self.assertEqual(len(experiences), 1)
self.assertEqual(experiences[0].reward, 1.1)
diff --git a/tests/template/config.yaml b/tests/template/config.yaml
index aee283bccb..0cf3ba3cf9 100644
--- a/tests/template/config.yaml
+++ b/tests/template/config.yaml
@@ -18,7 +18,7 @@ buffer:
taskset:
name: taskset
storage_type: file
- path: ''
+ path: 'placeholder'
split: 'train'
default_workflow_type: ''
default_reward_fn_type: ''
@@ -26,14 +26,11 @@ explorer:
engine_type: vllm_async
engine_num: 2
runner_num: 4
- repeat_times: 1
tensor_parallel_size: 1
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
- temperature: 1.0
seed: 42
- logprobs: 0
backend: nccl
use_ray: false
use_v1: true
diff --git a/tests/test_data/template.yaml b/tests/test_data/template.yaml
index 2ac9addf28..21fe407842 100644
--- a/tests/test_data/template.yaml
+++ b/tests/test_data/template.yaml
@@ -25,6 +25,4 @@ explorer:
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
- temperature: 1.0
seed: 42
- logprobs: 0
diff --git a/tests/tools.py b/tests/tools.py
index be415fba75..a650638a0d 100644
--- a/tests/tools.py
+++ b/tests/tools.py
@@ -6,7 +6,13 @@
import ray
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
-from trinity.common.config import Config, FormatConfig, StorageConfig, load_config
+from trinity.common.config import (
+ Config,
+ FormatConfig,
+ GenerationConfig,
+ StorageConfig,
+ load_config,
+)
def get_template_config() -> Config:
@@ -45,6 +51,11 @@ def get_unittest_dataset_config(
prompt_key="question",
response_key="answer",
),
+ rollout_args=GenerationConfig(
+ repeat_times=1,
+ temperature=1.0,
+ logprobs=0,
+ ),
default_workflow_type="math_workflow",
default_reward_fn_type="countdown_reward",
)
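A quick usage sketch of the updated test helper; the asserted values are the `GenerationConfig` defaults wired in above:

```python
# Sketch: rollout args now live on the taskset's StorageConfig returned
# by the helper, not on config.explorer.
from tests.tools import get_unittest_dataset_config

taskset = get_unittest_dataset_config("countdown")
assert taskset.rollout_args.repeat_times == 1
assert taskset.rollout_args.temperature == 1.0
assert taskset.rollout_args.logprobs == 0
```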
diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index 62f4abf745..b0ee7cc089 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -26,7 +26,7 @@ def setUp(self):
self.config.global_config.batch_size = 4
self.config.model.model_path = get_model_path()
self.config.explorer.engine_type = "vllm_async"
- self.config.explorer.repeat_times = 3
+ self.config.buffer.explorer_input.taskset.rollout_args.repeat_times = 3
self.config.explorer.use_v1 = False
self.config.monitor.name = f"trainer-{datetime.now().strftime('%Y%m%d%H%M%S')}"
self.config.monitor.monitor_type = MonitorType.TENSORBOARD
diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py
index 0ca202b2fa..316d3ae297 100644
--- a/trinity/buffer/reader/file_reader.py
+++ b/trinity/buffer/reader/file_reader.py
@@ -11,8 +11,7 @@
from trinity.common.constants import AlgorithmType, PromptType, ReadStrategy, TaskType
from trinity.common.experience import Experience
from trinity.common.rewards import REWARD_FUNCTIONS
-from trinity.common.task import Task
-from trinity.common.workflows import WORKFLOWS
+from trinity.common.workflows import WORKFLOWS, Task
from trinity.utils.registry import Registry
FILE_READERS = Registry("file_readers")
@@ -173,6 +172,7 @@ def read(self, strategy: Optional[ReadStrategy] = None) -> List:
@FILE_READERS.register_module("rollout")
class RolloutDataReader(BufferReader):
def __init__(self, meta: StorageConfig, config: BufferConfig):
+ self.meta = meta
self.name = meta.name
self.split = meta.split
subset_name = meta.subset_name
@@ -206,8 +206,6 @@ def read(self, strategy: Optional[ReadStrategy] = None):
if self.index >= len(self.dataset) * self.total_epochs:
raise StopIteration
sample = self.dataset[self.index % len(self.dataset)]
- task_desc = sample[self.prompt_key] if self.prompt_key in sample else None
- truth = sample[self.response_key] if self.response_key in sample else None
workflow_class = (
WORKFLOWS.get(sample[self.workflow_key])
if self.workflow_key in sample
@@ -220,12 +218,12 @@ def read(self, strategy: Optional[ReadStrategy] = None):
)
assert workflow_class is not None, "`default_reward_fn_type` or `workflow_key` is required"
task = Task(
- task_desc=task_desc,
- truth=truth,
workflow=workflow_class,
+ format_args=self.meta.format,
+ rollout_args=self.meta.rollout_args,
+ is_eval=self.meta.task_type == TaskType.EVAL,
reward_fn=reward_fn,
- raw=sample,
- task_type=self.task_type,
+ raw_task=sample,
)
self.index += 1
if self.task_type == TaskType.EVAL and self.index == len(self.dataset):
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 9c9ea72017..2e8e830007 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -28,7 +28,10 @@ class FormatConfig:
prompt_key: str = "prompt"
response_key: str = "response"
messages_key: str = "message"
- chat_template: str = ""
+ chat_template: str = "" # deprecated
+
+ system_prompt: Optional[str] = None
+ reply_prefix: Optional[str] = None
# for sample-level task controlling
reward_fn_key: str = ""
@@ -47,6 +50,17 @@ class FormatConfig:
label_key: str = ""
+@dataclass
+class GenerationConfig:
+    # repeat each task for `repeat_times` times (for GRPO-like algorithms)
+ repeat_times: int = 1
+
+ temperature: float = 1.0
+ top_p: float = 1.0
+ top_k: int = -1
+    logprobs: int = 0  # vLLM returns `logprobs + 1` elements
+
+
@dataclass
class StorageConfig:
"""Storage config."""
@@ -63,13 +77,11 @@ class StorageConfig:
index: int = 0
# used for algorithm_type is None
- task_type: TaskType = TaskType.EXPLORE
+ task_type: TaskType = TaskType.EXPLORE # automatically set
default_workflow_type: Optional[str] = None
default_reward_fn_type: Optional[str] = None
total_epochs: int = 1 # automatically set
- # used for algorithm_type is None and TaskType.EVAL
- eval_repeat_times: int = 1 # TODO
- eval_temperature: float = 0.1 # TODO
+ rollout_args: GenerationConfig = field(default_factory=GenerationConfig)
@dataclass
@@ -138,8 +150,11 @@ class ExplorerInput:
taskset: StorageConfig = field(default_factory=StorageConfig)
eval_tasksets: List[StorageConfig] = field(default_factory=list)
+ # The following args provide default values for the corresponding args in `taskset` and `eval_tasksets`
default_workflow_type: Optional[str] = None
default_reward_fn_type: Optional[str] = None
+ system_prompt: Optional[str] = None
+ reply_prefix: Optional[str] = None
@dataclass
@@ -180,9 +195,6 @@ class ExplorerConfig:
# For async engine (vllm_async), it can be larger than `engine_num`, e.g. 16 * `engine_num`
runner_num: int = 1
- # repeat each task for `repeat_times` times (for GPRO-like algorithms)
- repeat_times: int = 1
-
# for rollout tokneize
chat_template: Optional[str] = None
@@ -191,11 +203,7 @@ class ExplorerConfig:
enable_prefix_caching: bool = False
enforce_eager: bool = True
dtype: str = "bfloat16"
- temperature: float = 0.0
- top_p: float = 1.0
- top_k: int = -1
seed: int = 42
- logprobs: int = 0 # vLLM return `logprobs + 1` elements
backend: str = "nccl"
use_ray: bool = False
gpu_memory_utilization: float = 0.9
@@ -327,10 +335,12 @@ def _check_interval(self) -> None:
def _check_buffer(self) -> None: # noqa: C901
# check explorer_input
- if self.mode != "train" and self.buffer.explorer_input.taskset.path is None:
+ if self.mode != "train" and not self.buffer.explorer_input.taskset.path:
raise ValueError(
"`buffer.explorer_input.taskset.path` is required, please set it to the path of the taskset."
)
+ if not self.buffer.explorer_input.taskset.name:
+ self.buffer.explorer_input.taskset.name = "taskset"
self.buffer.explorer_input.taskset.task_type = TaskType.EXPLORE
self.buffer.explorer_input.taskset.total_epochs = self.global_config.total_epochs
if self.buffer.explorer_input.taskset.default_workflow_type is None:
@@ -341,13 +351,33 @@ def _check_buffer(self) -> None: # noqa: C901
self.buffer.explorer_input.taskset.default_reward_fn_type = (
self.buffer.explorer_input.default_reward_fn_type
)
+ if self.buffer.explorer_input.taskset.format.system_prompt is None:
+ self.buffer.explorer_input.taskset.format.system_prompt = (
+ self.buffer.explorer_input.system_prompt
+ )
+ if self.buffer.explorer_input.taskset.format.reply_prefix is None:
+ self.buffer.explorer_input.taskset.format.reply_prefix = (
+ self.buffer.explorer_input.reply_prefix
+ )
- for dataset in self.buffer.explorer_input.eval_tasksets:
+ remained_tasksets = []
+ for idx, dataset in enumerate(self.buffer.explorer_input.eval_tasksets):
+ if not dataset.path:
+                logger.warning(f"Eval dataset [{dataset}]'s path is not configured. Skipping.")
+ continue
dataset.task_type = TaskType.EVAL
+ if not dataset.name:
+ dataset.name = f"eval_taskset_{idx}"
if dataset.default_workflow_type is None:
dataset.default_workflow_type = self.buffer.explorer_input.default_workflow_type
if dataset.default_reward_fn_type is None:
dataset.default_reward_fn_type = self.buffer.explorer_input.default_reward_fn_type
+ if dataset.format.system_prompt is None:
+ dataset.format.system_prompt = self.buffer.explorer_input.system_prompt
+ if dataset.format.reply_prefix is None:
+ dataset.format.reply_prefix = self.buffer.explorer_input.reply_prefix
+ remained_tasksets.append(dataset)
+ self.buffer.explorer_input.eval_tasksets = remained_tasksets
# check trainer_input.experience_buffer
if self.mode == "both":
@@ -387,7 +417,10 @@ def _check_buffer(self) -> None: # noqa: C901
self.buffer.trainer_input.sft_warmup_dataset.algorithm_type = AlgorithmType.SFT
# set read_batch_size / pad_token_id / tokenizer_path
- self.buffer.read_batch_size = self.global_config.batch_size * self.explorer.repeat_times
+ self.buffer.read_batch_size = (
+ self.global_config.batch_size
+ * self.buffer.explorer_input.taskset.rollout_args.repeat_times
+ )
if self.buffer.pad_token_id is None:
from transformers import AutoTokenizer
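The `read_batch_size` change above is plain arithmetic; a worked example with illustrative numbers:

```python
# Each task is rolled out repeat_times times, so the trainer-side buffer
# must read batch_size * repeat_times experiences per step.
batch_size = 4     # global_config.batch_size (illustrative)
repeat_times = 8   # buffer.explorer_input.taskset.rollout_args.repeat_times
read_batch_size = batch_size * repeat_times
assert read_batch_size == 32
```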
diff --git a/trinity/common/models/vllm_async_model.py b/trinity/common/models/vllm_async_model.py
index 2f9249e0e9..fc39c3c303 100644
--- a/trinity/common/models/vllm_async_model.py
+++ b/trinity/common/models/vllm_async_model.py
@@ -47,6 +47,7 @@ def __init__(
self.use_v1 = config.explorer.use_v1
if config.explorer.tensor_parallel_size != 1:
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+ os.environ["VLLM_RAY_BUNDLE_INDICES"] = config.explorer.bundle_indices
if not vllm.envs.is_set("VLLM_USE_V1"):
self.logger.info(f"Using vLLM v{int(config.explorer.use_v1)} engine")
os.environ["VLLM_USE_V1"] = str(int(config.explorer.use_v1))
@@ -55,15 +56,15 @@ def __init__(
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
self.default_sampling_params = vllm.SamplingParams(
- n=config.explorer.repeat_times,
- temperature=config.explorer.temperature,
+ n=1,
+ temperature=0.0,
max_tokens=config.model.max_response_tokens,
min_tokens=1,
truncate_prompt_tokens=config.model.max_prompt_tokens,
skip_special_tokens=True,
include_stop_str_in_output=False,
output_kind=RequestOutputKind.FINAL_ONLY,
- logprobs=config.explorer.logprobs,
+ logprobs=0,
)
self.enable_thinking = config.model.enable_thinking
self.request_id = 0
diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py
index c4baf567a6..1e3efe6d22 100644
--- a/trinity/common/models/vllm_model.py
+++ b/trinity/common/models/vllm_model.py
@@ -33,6 +33,9 @@ class vLLMRolloutModel(InferenceModel):
def __init__(self, config: Config, **kwargs):
self.logger = get_logger(__name__)
self.config = config
+ if config.explorer.tensor_parallel_size != 1:
+ os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+ os.environ["VLLM_RAY_BUNDLE_INDICES"] = config.explorer.bundle_indices
if not vllm.envs.is_set("VLLM_USE_V1"):
self.logger.info(f"Using vLLM v{int(config.explorer.use_v1)} engine")
os.environ["VLLM_USE_V1"] = str(int(config.explorer.use_v1))
@@ -41,14 +44,14 @@ def __init__(self, config: Config, **kwargs):
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
self.default_sampling_params = SamplingParams(
- n=config.explorer.repeat_times,
- temperature=config.explorer.temperature,
+ n=1,
+ temperature=0.0,
max_tokens=config.model.max_response_tokens,
min_tokens=1,
truncate_prompt_tokens=config.model.max_prompt_tokens,
skip_special_tokens=True,
include_stop_str_in_output=False,
- logprobs=config.explorer.logprobs,
+ logprobs=0,
)
self.llm = LLM(
# TODO: check checkpoint path
@@ -151,7 +154,7 @@ def generate(self, prompts: List[str], **kwargs) -> List:
Example:
- >>> # config.explorer.repeat_times == 2 or kwargs["repeat_times"] == 2
+ >>> # config.buffer.explorer_input.taskset.rollout_args.repeat_times == 2 or kwargs["repeat_times"] == 2
>>>
>>> prompts = [
>>> "Hello, world!",
@@ -172,7 +175,7 @@ def generate(self, prompts: List[str], **kwargs) -> List:
)
experiences = []
for output in outputs:
- for i in range(self.config.explorer.repeat_times):
+ for i in range(self.config.buffer.explorer_input.taskset.rollout_args.repeat_times):
experiences.append(
Experience(
tokens=torch.cat(
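With the engine defaults now neutral (`n=1`, `temperature=0.0`, `logprobs=0`), callers pass rollout args per request, as the updated tests do with `generate(prompts, n=repeat_times, temperature=1.0)`. A hypothetical merge helper (not the in-tree code), assuming vLLM's `SamplingParams`:

```python
# Hypothetical helper: per-request rollout args override the neutral
# engine defaults; names here are illustrative, not from the patch.
from vllm import SamplingParams

NEUTRAL_DEFAULTS = {"n": 1, "temperature": 0.0, "logprobs": 0}


def make_sampling_params(**overrides) -> SamplingParams:
    return SamplingParams(**{**NEUTRAL_DEFAULTS, **overrides})


params = make_sampling_params(n=8, temperature=1.0)
assert params.n == 8 and params.temperature == 1.0
```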
diff --git a/trinity/common/task.py b/trinity/common/task.py
deleted file mode 100644
index b0582ab936..0000000000
--- a/trinity/common/task.py
+++ /dev/null
@@ -1,51 +0,0 @@
-# -*- coding: utf-8 -*-
-"""Task Class."""
-from __future__ import annotations
-
-from dataclasses import dataclass
-from typing import Any, Optional, Type
-
-from trinity.common.config import Config
-from trinity.common.constants import TaskType
-from trinity.common.rewards.reward_fn import RewardFn
-from trinity.common.workflows.workflow import Workflow
-from trinity.utils.log import get_logger
-
-logger = get_logger(__name__)
-
-
-@dataclass
-class Task:
- """A Task class that defines a task and its associated reward function / workflow."""
-
- task_desc: str
- workflow: Type[Workflow]
- reward_fn: Optional[Type[RewardFn]] = None
- truth: Optional[str] = None
- raw: Optional[dict] = None # The raw data sample
- task_type: Optional[TaskType] = None
-
- def to_workflow(self, model: Any, config: Config) -> Workflow:
- """Convert the task to a workflow.
-
- Args:
- model (ModelWrapper): The rollout model for the workflow.
- config (Config): The global configuration.
-
- Returns:
- Workflow: The generated workflow object.
- """
- if self.task_type == TaskType.EVAL:
- repeat_times = 1
- else:
- repeat_times = config.explorer.repeat_times
- return self.workflow(
- model=model,
- task_desc=self.task_desc,
- truth=self.truth,
- reward_fn=self.reward_fn,
- raw=self.raw,
- repeat_times=repeat_times,
- config=config,
- is_eval=self.task_type == TaskType.EVAL,
- )
diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py
index d44440c826..944037ac2e 100644
--- a/trinity/common/verl_config.py
+++ b/trinity/common/verl_config.py
@@ -301,10 +301,14 @@ def synchronize_config(self, config: Config) -> None:
self.trainer.default_local_dir = config.model.checkpoint_path
self.trainer.sft_warmup_steps = config.trainer.sft_warmup_steps
self.actor_rollout_ref.actor.ppo_mini_batch_size = config.global_config.batch_size
- self.actor_rollout_ref.rollout.temperature = config.explorer.temperature
- self.actor_rollout_ref.rollout.n = config.explorer.repeat_times
+ self.actor_rollout_ref.rollout.temperature = (
+ config.buffer.explorer_input.taskset.rollout_args.temperature
+ )
+ self.actor_rollout_ref.rollout.n = (
+ config.buffer.explorer_input.taskset.rollout_args.repeat_times
+ )
self.critic.ppo_mini_batch_size = config.global_config.batch_size
- self.critic.rollout_n = config.explorer.repeat_times
+ self.critic.rollout_n = self.actor_rollout_ref.rollout.n
self.actor_rollout_ref.actor.algorithm_type = config.trainer.algorithm_type
if config.trainer.algorithm_type == AlgorithmType.PPO:
diff --git a/trinity/common/workflows/__init__.py b/trinity/common/workflows/__init__.py
index a8bcd886a2..92bf29a64e 100644
--- a/trinity/common/workflows/__init__.py
+++ b/trinity/common/workflows/__init__.py
@@ -3,9 +3,10 @@
from .envs.alfworld.alfworld_workflow import AlfworldWorkflow
from .envs.sciworld.sciworld_workflow import SciWorldWorkflow
from .envs.webshop.webshop_workflow import WebShopWorkflow
-from .workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow
+from .workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow, Task
__all__ = [
+ "Task",
"WORKFLOWS",
"SimpleWorkflow",
"MathWorkflow",
diff --git a/trinity/common/workflows/envs/alfworld/alfworld_workflow.py b/trinity/common/workflows/envs/alfworld/alfworld_workflow.py
index a7e45e4a7f..9f8b389858 100644
--- a/trinity/common/workflows/envs/alfworld/alfworld_workflow.py
+++ b/trinity/common/workflows/envs/alfworld/alfworld_workflow.py
@@ -3,7 +3,7 @@
from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
-from trinity.common.workflows.workflow import WORKFLOWS, MultiTurnWorkflow
+from trinity.common.workflows.workflow import WORKFLOWS, MultiTurnWorkflow, Task
EXAMPLE_PROMPT = """
Observation:
@@ -96,13 +96,17 @@ def parse_action(response):
class AlfworldWorkflow(MultiTurnWorkflow):
"""A workflow for alfworld task."""
- def __init__(self, model: ModelWrapper, **kwargs):
- super().__init__(model)
- self.system_prompt = kwargs.get("system_prompt", None) # Unuse here
- self.task_desc: str = kwargs.get("task_desc")
- self.truth = kwargs.get("truth") # Unuse here
- self.reward_fn = None # Unuse here
- self.repeat_times = kwargs.get("repeat_times", 1)
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ )
+ self.task_desc = task.task_desc or "0"
+ self.repeat_times = task.rollout_args.repeat_times
self.max_env_steps = 30
def get_model_response(self, messages):
diff --git a/trinity/common/workflows/envs/sciworld/sciworld_workflow.py b/trinity/common/workflows/envs/sciworld/sciworld_workflow.py
index 06a0748f4c..b3669f01d0 100644
--- a/trinity/common/workflows/envs/sciworld/sciworld_workflow.py
+++ b/trinity/common/workflows/envs/sciworld/sciworld_workflow.py
@@ -4,7 +4,7 @@
from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
-from trinity.common.workflows.workflow import WORKFLOWS, MultiTurnWorkflow
+from trinity.common.workflows.workflow import WORKFLOWS, MultiTurnWorkflow, Task
SCIWORLD_SYSTEM_PROMPT = """
You are an agent, you job is to do some scientific experiment in a virtual test-based environments.
@@ -59,13 +59,17 @@ def parse_action(response):
class SciWorldWorkflow(MultiTurnWorkflow):
"""A workflow for sciworld task."""
- def __init__(self, model: ModelWrapper, **kwargs):
- super().__init__(model)
- self.system_prompt = kwargs.get("system_prompt", None) # Unuse here
- self.task_desc: str = kwargs.get("task_desc")
- self.truth = kwargs.get("truth") # Unuse here
- self.reward_fn = None # Unuse here
- self.repeat_times = kwargs.get("repeat_times", 1)
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ )
+ self.task_desc = task.task_desc or "0"
+ self.repeat_times = task.rollout_args.repeat_times
self.max_env_steps = 30 # should be less than 100
def get_model_response(self, messages):
diff --git a/trinity/common/workflows/envs/webshop/webshop_workflow.py b/trinity/common/workflows/envs/webshop/webshop_workflow.py
index 741fcd4b05..55035fd7b4 100644
--- a/trinity/common/workflows/envs/webshop/webshop_workflow.py
+++ b/trinity/common/workflows/envs/webshop/webshop_workflow.py
@@ -3,7 +3,7 @@
from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
-from trinity.common.workflows.workflow import WORKFLOWS, MultiTurnWorkflow
+from trinity.common.workflows.workflow import WORKFLOWS, MultiTurnWorkflow, Task
SPARSE_REWARD = False
@@ -181,13 +181,17 @@ def validate_action(action, available_actions):
class WebShopWorkflow(MultiTurnWorkflow):
"""A workflow for webshop task."""
- def __init__(self, model: ModelWrapper, **kwargs):
- super().__init__(model)
- self.system_prompt = kwargs.get("system_prompt", None) # Unuse here
- self.task_desc: str = kwargs.get("task_desc", "0")
- self.truth = kwargs.get("truth") # Unuse here
- self.reward_fn = None # Unuse here
- self.repeat_times = kwargs.get("repeat_times", 1)
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ )
+ self.task_desc = task.task_desc or "0"
+ self.repeat_times = task.rollout_args.repeat_times
self.max_env_steps = 15
def get_model_response(self, messages):
diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py
index 4fa3a7c2f5..905fe2e5b8 100644
--- a/trinity/common/workflows/workflow.py
+++ b/trinity/common/workflows/workflow.py
@@ -1,11 +1,15 @@
# -*- coding: utf-8 -*-
"""Base Workflow Class"""
+from __future__ import annotations
+
from abc import ABC, abstractmethod
-from typing import List
+from dataclasses import asdict, dataclass, field
+from typing import Any, List, Optional, Type, Union
import torch
+from trinity.common.config import FormatConfig, GenerationConfig
from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.common.rewards.reward_fn import MathRewardFn, RewardFn
@@ -18,13 +22,53 @@
WORKFLOWS = Registry("workflows")
+@dataclass
+class Task:
+ """A Task class that defines a task and its associated reward function / workflow."""
+
+ workflow: Type[Workflow]
+ format_args: FormatConfig
+ rollout_args: GenerationConfig = field(default_factory=GenerationConfig)
+ is_eval: bool = False
+ reward_fn: Optional[Type[RewardFn]] = None
+ raw_task: Optional[dict] = None # The raw data sample
+
+ def to_workflow(self, model: Any) -> Workflow:
+ """Convert the task to a workflow.
+
+ Args:
+ model (ModelWrapper): The rollout model for the workflow.
+
+ Returns:
+ Workflow: The generated workflow object.
+ """
+ return self.workflow(
+ model=model,
+ task=self,
+ )
+
+ @property
+ def task_desc(self) -> Union[str, None]:
+ prompt_key = self.format_args.prompt_key
+ return self.raw_task[prompt_key] if prompt_key in self.raw_task else None # type: ignore
+
+ @property
+ def truth(self) -> Union[str, None]:
+ response_key = self.format_args.response_key
+ return self.raw_task[response_key] if response_key in self.raw_task else None # type: ignore
+
+
class Workflow(ABC):
"""The base workflow class.
A workflow is a runnable object which generates a list of experiences.
"""
- def __init__(self, model: ModelWrapper, **kwargs):
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ ):
self.model = model
@abstractmethod
@@ -37,8 +81,15 @@ class MultiTurnWorkflow(Workflow):
The base workflow class for multi-turn tasks.
"""
- def __init__(self, model: ModelWrapper, **kwargs):
- super().__init__(model)
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ )
@abstractmethod
def run(self) -> List[Experience]:
@@ -81,34 +132,46 @@ class SimpleWorkflow(Workflow):
def __init__(
self,
model: ModelWrapper,
- **kwargs,
+ task: Task,
):
- super().__init__(model)
- self.system_prompt = kwargs.get("system_prompt", None)
- self.reply_prefix = kwargs.get("reply_prefix", None) # TODO: add reply_prefix
- self.task_desc = kwargs.get("task_desc")
- self.truth = kwargs.get("truth")
- self.reward_fn = kwargs.get("reward_fn")
- if isinstance(self.reward_fn, type) and issubclass(self.reward_fn, RewardFn):
- self.reward_fn = self.reward_fn()
+ super().__init__(
+ model=model,
+ task=task,
+ )
+ self.format_args = task.format_args
+ self.system_prompt = task.format_args.system_prompt
+ self.reply_prefix = task.format_args.reply_prefix
+
+ self.raw_task = task.raw_task
+ self.task_desc = task.task_desc
+ self.truth = task.truth
+
+ reward_fn = task.reward_fn
+ if isinstance(reward_fn, type) and issubclass(reward_fn, RewardFn):
+ self.reward_fn: RewardFn = reward_fn()
else:
raise ValueError("`reward_fn` must be a subclass of `RewardFn`")
- # Rollout n times
- self.repeat_times = kwargs.get("repeat_times", 1)
- self.is_eval = kwargs.get("is_eval", False)
+ # Rollout args
+ rollout_args = asdict(task.rollout_args)
+ rollout_args["n"] = rollout_args["repeat_times"]
+ self.rollout_args = rollout_args
+ self.is_eval = task.is_eval
- def run(self) -> List[Experience]:
- # TODO: Optimize the generate function
+ def format_messages(self):
messages = []
if self.system_prompt:
messages.append({"role": "system", "content": self.system_prompt})
messages.append({"role": "user", "content": self.task_desc})
if self.reply_prefix:
messages.append({"role": "assistant", "content": self.reply_prefix})
+ return messages
+
+ def run(self) -> List[Experience]:
+ # TODO: Optimize the generate function
+ messages = self.format_messages()
logger.debug("start chat")
- n = 1 if self.is_eval else self.repeat_times
- responses = self.model.chat(messages, n=n)
+ responses = self.model.chat(messages, **self.rollout_args)
for response in responses:
reward = self.reward_fn( # type: ignore [misc]
response=response.response_text, # type: ignore [arg-type]
@@ -134,18 +197,16 @@ class MathWorkflow(SimpleWorkflow):
def __init__(
self,
model: ModelWrapper,
- **kwargs,
+ task: Task,
):
- if kwargs.get("reward_fn", None) is None:
- kwargs["reward_fn"] = MathRewardFn
- if kwargs["reward_fn"] == MathRewardFn and kwargs.get("system_prompt", None) is None:
- kwargs[
- "system_prompt"
- ] = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e.,
+ if task.reward_fn is None:
+ task.reward_fn = MathRewardFn
+ if task.reward_fn == MathRewardFn and task.format_args.system_prompt is None:
+            task.format_args.system_prompt = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e.,
<think> reasoning process here </think>
<answer> answer here </answer>.
"""
super().__init__(
- model,
- **kwargs,
+ model=model,
+ task=task,
)
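End to end, a `Task` now carries the raw sample plus `format_args`/`rollout_args` and instantiates its own workflow; a sketch mirroring `tests/explorer/workflow_test.py`, with a `MagicMock` standing in for the rollout `ModelWrapper`:

```python
# Sketch mirroring tests/explorer/workflow_test.py; the mock model's
# chat() is stubbed to return no responses, so run() yields no experiences.
from unittest.mock import MagicMock

from trinity.common.config import FormatConfig, GenerationConfig
from trinity.common.workflows import MathWorkflow, Task

model = MagicMock()
model.chat.return_value = []  # a real ModelWrapper returns sampled responses

task = Task(
    workflow=MathWorkflow,
    format_args=FormatConfig(prompt_key="question", response_key="answer"),
    rollout_args=GenerationConfig(repeat_times=2, temperature=1.0),
    is_eval=False,
    raw_task={"question": "1+1=", "answer": "2"},
)
workflow = task.to_workflow(model=model)
experiences = workflow.run()  # [] with the stubbed chat
```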
diff --git a/trinity/explorer/runner_pool.py b/trinity/explorer/runner_pool.py
index 5805f71907..c7b3f39f1f 100644
--- a/trinity/explorer/runner_pool.py
+++ b/trinity/explorer/runner_pool.py
@@ -5,7 +5,7 @@
import ray
from trinity.common.config import Config
-from trinity.common.task import Task
+from trinity.common.workflows import Task
from trinity.explorer.workflow_runner import Status, WorkflowRunner
from trinity.utils.log import get_logger
diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py
index 86e4003df4..87e80aaf9b 100644
--- a/trinity/explorer/workflow_runner.py
+++ b/trinity/explorer/workflow_runner.py
@@ -11,10 +11,9 @@
from trinity.buffer import get_buffer_writer
from trinity.common.config import Config
-from trinity.common.constants import TaskType
from trinity.common.experience import Experience
from trinity.common.models.model import InferenceModel, ModelWrapper
-from trinity.common.task import Task
+from trinity.common.workflows import Task
from trinity.utils.log import get_logger
@@ -48,7 +47,7 @@ def _run_task(self, task: Task) -> List[Experience]:
"""Init workflow from the task and run it."""
if task.workflow is None:
raise ValueError("Workflow is not set in the task.")
- workflow = task.to_workflow(self.model_wrapper, self.config)
+ workflow = task.to_workflow(self.model_wrapper)
return workflow.run()
def run_task(self, task: Task) -> Status:
@@ -77,7 +76,7 @@ def run_task(self, task: Task) -> Status:
if metrics:
for k, v in metrics.items():
metric[k] = sum(v) / len(v) # type: ignore
- if not task.task_type == TaskType.EVAL:
+ if not task.is_eval:
self.experience_buffer.write(exps)
return Status(True, metric=metric)
except Exception as e:
diff --git a/trinity/manager/config_manager.py b/trinity/manager/config_manager.py
index 1f55b3c702..1f1e48d61b 100644
--- a/trinity/manager/config_manager.py
+++ b/trinity/manager/config_manager.py
@@ -73,10 +73,12 @@ def _init_default_config(self):
"taskset_prompt_key": "question",
"taskset_response_key": "answer",
# Eval Taskset Configs
- # TODO
- # Task Workflow Configs
+ "_eval_tasksets_num": 0,
+ # Explorer Input Configs
"default_workflow_type": "math_workflow",
"default_reward_fn_type": "math_reward",
+ "system_prompt": None,
+ "reply_prefix": None,
# Experience Buffer Configs
"_dpo_storage_type": StorageType.FILE.value,
"_not_dpo_storage_type": StorageType.QUEUE.value,
@@ -197,6 +199,11 @@ def reset_session_state(self):
def maintain_session_state(self):
for key in self.default_config:
st.session_state[key] = st.session_state[key]
+        eval_dataset_keys = ["name", "path", "subset_name", "split", "prompt_key", "response_key"]
+ for idx in range(st.session_state["_eval_tasksets_num"]):
+            for key in eval_dataset_keys:
+ full_key = f"eval_taskset_{idx}_{key}"
+ st.session_state[full_key] = st.session_state[full_key]
def _set_project(self):
st.text_input("Project", key="project")
@@ -311,11 +318,29 @@ def _set_taskset_path(self):
self.unfinished_fields.add("taskset_path")
st.warning("Please input taskset path.")
+ def _set_system_prompt(self):
+ st.text_area(
+ "System Prompt",
+ key="system_prompt",
+ placeholder="System prompt is used to guide the model behavior.",
+ )
+
+ def _set_reply_prefix(self):
+ st.text_area(
+ "Assistant Reply Prefix",
+ key="reply_prefix",
+ placeholder="""Assistant reply prefix is used to specify the initial content of model reply, """
+ """and a common setting is: \nLet me solve this step by step. """,
+ )
+
def _set_taskset_args(self):
if st.session_state["taskset_path"] and "://" not in st.session_state["taskset_path"]:
subset_name_col, split_col = st.columns(2)
subset_name_col.text_input(
- "Subset Name :orange-badge[(Needs review)]", key="taskset_subset_name"
+ "Subset Name :orange-badge[(Needs review)]",
+ key="taskset_subset_name",
+ help="The subset name used for `datasets.load_datasets`, see "
+ "[here](https://huggingface.co/docs/datasets/v3.5.0/en/package_reference/loading_methods#datasets.load_dataset.name) for details.",
)
split_col.text_input("Train Split :orange-badge[(Needs review)]", key="taskset_split")
prompt_key_col, response_key_col = st.columns(2)
@@ -326,6 +351,49 @@ def _set_taskset_args(self):
"Response Key :orange-badge[(Needs review)]", key="taskset_response_key"
)
+ def _set_eval_taskset_idx(self, idx):
+ st.text_input(
+ "Taskset Name",
+ key=f"eval_taskset_{idx}_name",
+ )
+ st.text_input(
+ "Eval Taskset Path",
+ key=f"eval_taskset_{idx}_path",
+ )
+ if not st.session_state[f"eval_taskset_{idx}_path"].strip():
+ st.warning("Please input the taskset path, or it will be ignored.")
+ subset_name_col, split_col = st.columns(2)
+ subset_name_col.text_input(
+ "Subset Name :orange-badge[(Needs review)]",
+ key=f"eval_taskset_{idx}_subset_name",
+ help="The subset name used for `datasets.load_datasets`, see "
+ "[here](https://huggingface.co/docs/datasets/v3.5.0/en/package_reference/loading_methods#datasets.load_dataset.name) for details.",
+ )
+ split_col.text_input(
+ "Eval Split :orange-badge[(Needs review)]",
+ key=f"eval_taskset_{idx}_split",
+ )
+ prompt_key_col, response_key_col = st.columns(2)
+ prompt_key_col.text_input(
+ "Prompt Key :orange-badge[(Needs review)]",
+ key=f"eval_taskset_{idx}_prompt_key",
+ )
+ response_key_col.text_input(
+ "Response Key :orange-badge[(Needs review)]",
+ key=f"eval_taskset_{idx}_response_key",
+ )
+
+ def _set_eval_tasksets(self):
+ if st.button("Add Eval Taskset"):
+ st.session_state["_eval_tasksets_num"] += 1
+ if st.session_state["_eval_tasksets_num"] > 0:
+ tabs = st.tabs(
+ [f"Eval Taskset {i + 1}" for i in range(st.session_state["_eval_tasksets_num"])]
+ )
+ for idx, tab in enumerate(tabs):
+ with tab:
+ self._set_eval_taskset_idx(idx)
+
def _set_default_workflow_type(self):
st.selectbox(
"Default Workflow Type :orange-badge[(Needs review)]",
@@ -1131,6 +1199,10 @@ def _expert_buffer_part(self):
self._set_configs_with_st_columns(["total_epochs", "train_batch_size"])
self._check_train_batch_size()
+ self._set_configs_with_st_columns(["default_workflow_type", "default_reward_fn_type"])
+ self._set_system_prompt()
+ self._set_reply_prefix()
+
if st.session_state["algorithm_type"] != AlgorithmType.DPO.value:
with st.expander("Taskset Configs", expanded=True):
self._set_taskset_path()
@@ -1141,8 +1213,7 @@ def _expert_buffer_part(self):
self._set_dpo_dataset_kwargs()
with st.expander("Eval Tasksets Configs", expanded=True):
- # TODO:
- pass
+ self._set_eval_tasksets()
with st.expander("SFT Dataset Configs"):
self._set_sft_warmup_dataset_path()
@@ -1153,8 +1224,6 @@ def _expert_buffer_part(self):
self._set_storage_type()
self._set_experience_buffer_path()
- self._set_configs_with_st_columns(["default_workflow_type", "default_reward_fn_type"])
-
self.buffer_advanced_tab = st.expander("Advanced Config")
with self.buffer_advanced_tab:
self._set_configs_with_st_columns(["buffer_max_retry_times", "max_retry_interval"])
@@ -1614,10 +1683,19 @@ def generate_config(self):
"prompt_key": st.session_state["taskset_prompt_key"],
"response_key": st.session_state["taskset_response_key"],
},
+ "rollout_args": {
+ "repeat_times": st.session_state["repeat_times"],
+ "temperature": st.session_state["temperature"],
+ "top_p": st.session_state["top_p"],
+ "top_k": st.session_state["top_k"],
+ "logprobs": st.session_state["logprobs"],
+ },
},
"eval_tasksets": [], # TODO: add eval tasksets
"default_workflow_type": st.session_state["default_workflow_type"],
"default_reward_fn_type": st.session_state["default_reward_fn_type"],
+ "system_prompt": st.session_state["system_prompt"],
+ "reply_prefix": st.session_state["reply_prefix"],
},
"trainer_input": {
"experience_buffer": {
@@ -1631,26 +1709,21 @@ def generate_config(self):
"engine_type": st.session_state["engine_type"],
"engine_num": st.session_state["engine_num"],
"runner_num": st.session_state["runner_num"],
- "repeat_times": st.session_state["repeat_times"],
# "chat_template": None, # TODO: add chat template
"tensor_parallel_size": st.session_state["tensor_parallel_size"],
"enable_prefix_caching": st.session_state["enable_prefix_caching"],
"enforce_eager": st.session_state["enforce_eager"],
"dtype": st.session_state["dtype"],
- "temperature": st.session_state["temperature"],
- "top_p": st.session_state["top_p"], # TODO
- "top_k": st.session_state["top_k"], # TODO
"seed": st.session_state["seed"],
- "logprobs": st.session_state["logprobs"],
"backend": st.session_state["backend"],
- "use_ray": st.session_state["use_ray"], # TODO
- "gpu_memory_utilization": st.session_state["gpu_memory_utilization"], # TODO
- "enable_chunked_prefill": st.session_state["enable_chunked_prefill"], # TODO
+ "use_ray": st.session_state["use_ray"],
+ "gpu_memory_utilization": st.session_state["gpu_memory_utilization"],
+ "enable_chunked_prefill": st.session_state["enable_chunked_prefill"],
"use_v1": True,
"max_pending_requests": st.session_state["max_pending_requests"],
"max_waiting_steps": st.session_state["max_waiting_steps"],
- "max_timeout": st.session_state["max_timeout"], # TODO
- "max_retry_times": st.session_state["explorer_max_retry_times"], # TODO
+ "max_timeout": st.session_state["max_timeout"],
+ "max_retry_times": st.session_state["explorer_max_retry_times"],
},
"synchronizer": {
"sync_method": st.session_state["sync_method"],
@@ -1671,6 +1744,18 @@ def generate_config(self):
},
}
+ for idx in range(st.session_state["_eval_tasksets_num"]):
+ if st.session_state[f"eval_taskset_{idx}_path"].strip():
+ config["buffer"]["explorer_input"]["eval_tasksets"].append(
+ {
+ "name": st.session_state[f"eval_taskset_{idx}_name"],
+ "path": st.session_state[f"eval_taskset_{idx}_path"],
+ "subset_name": st.session_state[f"eval_taskset_{idx}_subset_name"],
+ "split": st.session_state[f"eval_taskset_{idx}_split"],
+ "prompt_key": st.session_state[f"eval_taskset_{idx}_prompt_key"],
+ "response_key": st.session_state[f"eval_taskset_{idx}_response_key"],
+ }
+ )
if st.session_state["algorithm_type"] == AlgorithmType.DPO.value:
experience_buffer = config["buffer"]["trainer_input"]["experience_buffer"]
experience_buffer["split"] = st.session_state["dpo_dataset_train_split"]