Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions docs/sphinx_doc/source/tutorial/trinity_configs.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ buffer:
format:
prompt_key: 'question'
response_key: 'answer'
rollout_args:
repeat_times: 1
temperature: 1.0
logprobs: 0
eval_tasksets: []
default_workflow_type: 'math_workflow'
default_reward_fn_type: 'countdown_reward'
Expand All @@ -123,6 +127,9 @@ buffer:
- `buffer.explorer_input.taskset.path`: The path to the taskset.
- `buffer.explorer_input.taskset.split`: The split name of the taskset used for training. Default is `train`.
- `buffer.explorer_input.taskset.format`: The format of the taskset. It includes `prompt_key`, `response_key`, `workflow_key` and `reward_fn_key`.
- `buffer.explorer_input.taskset.rollout_args.repeat_times`: The number of times to repeat each task, used for GRPO-like algorithms. Default is `1`.
- `buffer.explorer_input.taskset.rollout_args.temperature`: The temperature used in vLLM. Default is `1.0`.
- `buffer.explorer_input.taskset.rollout_args.logprobs`: The logprobs used in vLLM. Default is `0`.
- `buffer.explorer_input.eval_tasksets`: The configuration of the eval tasksets. It is a list of tasksets which will be used for evaluation. And it is empty by default.
- `buffer.explorer_input.default_workflow_type`: The default workflow type for `taskset` and `eval_tasksets`.
- `buffer.explorer_input.default_reward_fn_type`: The default reward function type for `taskset` and `eval_tasksets`.
Expand All @@ -145,10 +152,7 @@ explorer:
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
temperature: 1.0
seed: 42
logprobs: 0
repeat_times: 5
use_ray: false
backend: 'nccl'
max_pending_requests: 32
Expand All @@ -162,10 +166,7 @@ explorer:
- `explorer.enable_prefix_caching`: Whether to enable prefix caching. Default is `False`.
- `explorer.enforce_eager`: Whether to enforce eager mode. Default is `True`.
- `explorer.dtype`: The data type used in vLLM. Default is `bfloat16`.
- `explorer.temperature`: The temperature used in vLLM. Default is `1.0`.
- `explorer.seed`: The seed used in vLLM. Default is `42`.
- `explorer.logprobs`: The logprobs used in vLLM. Default is `0`.
- `explorer.repeat_times`: The number of times to repeat each task, used for GRPO-like algorithms. Default is `5`.
- `explorer.use_ray`: Whether to use Ray. Default is `False`.
- `explorer.backend`: The backend used in vLLM. Default is `nccl`.
- `explorer.max_pending_requests`: The maximum number of pending requests. Default is `32`.
Expand Down
7 changes: 4 additions & 3 deletions examples/async_gsm8k/explorer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ buffer:
format:
prompt_key: 'question'
response_key: 'answer'
rollout_args:
repeat_times: 8
temperature: 1.0
logprobs: 0
default_workflow_type: 'math_workflow'
trainer_input:
experience_buffer:
Expand All @@ -37,10 +41,7 @@ explorer:
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
temperature: 1.0
seed: 42
logprobs: 0
repeat_times: 8
use_ray: false
backend: 'nccl'
max_pending_requests: 32
Expand Down
7 changes: 4 additions & 3 deletions examples/async_gsm8k/trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ buffer:
format:
prompt_key: 'question'
response_key: 'answer'
rollout_args:
repeat_times: 8
temperature: 1.0
logprobs: 0
default_workflow_type: 'math_workflow'
trainer_input:
experience_buffer:
Expand All @@ -36,10 +40,7 @@ explorer:
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
temperature: 1.0
seed: 42
logprobs: 0
repeat_times: 8
use_ray: false
backend: 'nccl'
max_pending_requests: 32
Expand Down
16 changes: 0 additions & 16 deletions examples/dpo_humanlike/dpo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,6 @@ buffer:
prompt_key: prompt
chosen_key: chosen
rejected_key: rejected
explorer:
engine_type: vllm_async
engine_num: 0
runner_num: 32
tensor_parallel_size: 1
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
temperature: 1.0
seed: 42
logprobs: 0
repeat_times: 1 # NOTE
use_ray: false
backend: 'nccl'
max_pending_requests: 32
max_waiting_steps: 4
synchronizer:
sync_method: 'checkpoint'
sync_interval: 30
Expand Down
7 changes: 4 additions & 3 deletions examples/grpo_alfworld/alfworld.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ buffer:
path: 'scripts/data_prepare/alfworld_data'
format:
prompt_key: 'game_file'
rollout_args:
repeat_times: 8
temperature: 1.0
logprobs: 0
default_workflow_type: 'alfworld_workflow'
trainer_input:
experience_buffer:
Expand All @@ -33,10 +37,7 @@ explorer:
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
temperature: 1.0
seed: 42
logprobs: 0
repeat_times: 8
use_ray: false
backend: 'nccl'
max_pending_requests: 32
Expand Down
7 changes: 4 additions & 3 deletions examples/grpo_gsm8k/gsm8k.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ buffer:
format:
prompt_key: 'question'
response_key: 'answer'
rollout_args:
repeat_times: 8
temperature: 1.0
logprobs: 0
eval_tasksets:
- name: gsm8k-eval
storage_type: file
Expand Down Expand Up @@ -65,10 +69,7 @@ explorer:
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
temperature: 1.0
seed: 42
logprobs: 0
repeat_times: 8
use_ray: false
backend: 'nccl'
max_pending_requests: 32
Expand Down
7 changes: 4 additions & 3 deletions examples/grpo_math/math.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ buffer:
format:
prompt_key: 'question'
response_key: 'gt_answer'
rollout_args:
repeat_times: 8
temperature: 1.0
logprobs: 0
default_workflow_type: 'math_workflow'
trainer_input:
experience_buffer:
Expand All @@ -35,10 +39,7 @@ explorer:
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
temperature: 1.0
seed: 42
logprobs: 0
repeat_times: 8
use_ray: false
backend: 'nccl'
max_pending_requests: 32
Expand Down
7 changes: 4 additions & 3 deletions examples/grpo_sciworld/sciworld.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ buffer:
path: 'scripts/data_prepare/sciworld_data'
format:
prompt_key: 'game_file'
rollout_args:
repeat_times: 8
temperature: 1.0
logprobs: 0
default_workflow_type: 'sciworld_workflow'
trainer_input:
experience_buffer:
Expand All @@ -33,10 +37,7 @@ explorer:
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
temperature: 1.0
seed: 42
logprobs: 0
repeat_times: 8
use_ray: false
backend: 'nccl'
max_pending_requests: 32
Expand Down
7 changes: 4 additions & 3 deletions examples/grpo_webshop/webshop.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ buffer:
path: 'scripts/data_prepare/webshop_data'
format:
prompt_key: 'task_id'
rollout_args:
repeat_times: 8
temperature: 1.0
logprobs: 0
default_workflow_type: 'webshop_workflow'
trainer_input:
experience_buffer:
Expand All @@ -33,10 +37,7 @@ explorer:
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
temperature: 1.0
seed: 42
logprobs: 0
repeat_times: 8
use_ray: false
backend: 'nccl'
max_pending_requests: 32
Expand Down
7 changes: 4 additions & 3 deletions examples/opmd_gsm8k/opmd_gsm8k.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ buffer:
format:
prompt_key: 'question'
response_key: 'answer'
rollout_args:
repeat_times: 8
temperature: 1.0
logprobs: 0
default_workflow_type: 'math_workflow'
trainer_input:
experience_buffer:
Expand All @@ -34,10 +38,7 @@ explorer:
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
temperature: 1.0
seed: 42
logprobs: 0
repeat_times: 8
use_ray: false
backend: 'nccl'
max_pending_requests: 32
Expand Down
7 changes: 4 additions & 3 deletions examples/ppo_countdown/countdown.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ buffer:
format:
prompt_key: 'question'
response_key: 'answer'
rollout_args:
repeat_times: 5
temperature: 1.0
logprobs: 0
default_workflow_type: 'math_workflow'
default_reward_fn_type: 'countdown_reward'
trainer_input:
Expand All @@ -36,10 +40,7 @@ explorer:
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
temperature: 1.0
seed: 42
logprobs: 0
repeat_times: 5
use_ray: false
backend: 'nccl'
max_pending_requests: 32
Expand Down
15 changes: 8 additions & 7 deletions tests/common/vllm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,9 @@ def get_model_path() -> str:
class BaseTestModelWrapper:
def test_generate(self):
prompts = ["Hello, world!", "Hello, my name is"]
results = self.model_wrapper.generate(prompts)
self.assertEqual(len(results), len(prompts) * self.config.explorer.repeat_times)
repeat_times = self.config.buffer.explorer_input.taskset.rollout_args.repeat_times
results = self.model_wrapper.generate(prompts, n=repeat_times, temperature=1.0)
self.assertEqual(len(results), len(prompts) * repeat_times)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What's the weather like today?"},
Expand All @@ -96,8 +97,8 @@ def test_generate(self):
},
{"role": "user", "content": "OK, thanks!"},
]
results = self.model_wrapper.chat(messages)
self.assertEqual(len(results), self.config.explorer.repeat_times)
results = self.model_wrapper.chat(messages, n=repeat_times, temperature=1.0)
self.assertEqual(len(results), repeat_times)
for result in results:
input_logprobs = result.logprobs[: result.prompt_length]
output_logprobs = result.logprobs[result.prompt_length :]
Expand Down Expand Up @@ -135,7 +136,7 @@ def setUp(self):
self.config.explorer.engine_type = "vllm"
self.config.explorer.tensor_parallel_size = 1
self.config.explorer.engine_num = 2
self.config.explorer.repeat_times = 2
self.config.buffer.explorer_input.taskset.rollout_args.repeat_times = 2
self.config.explorer.use_v1 = False
self.config.explorer.chat_template = CHAT_TEMPLATE
self.engines = create_rollout_models(self.config)
Expand All @@ -149,7 +150,7 @@ def setUp(self):
self.config.explorer.engine_type = "vllm_async"
self.config.explorer.engine_num = 2
self.config.explorer.tensor_parallel_size = 1
self.config.explorer.repeat_times = 2
self.config.buffer.explorer_input.taskset.rollout_args.repeat_times = 2
self.config.explorer.use_v1 = False
self.config.explorer.chat_template = CHAT_TEMPLATE
self.engines = create_rollout_models(self.config)
Expand All @@ -176,7 +177,7 @@ def setUp(self):
self.config.explorer.engine_type = "vllm_async"
self.config.explorer.engine_num = 2
self.config.explorer.tensor_parallel_size = 2
self.config.explorer.repeat_times = 2
self.config.buffer.explorer_input.taskset.rollout_args.repeat_times = 2
self.config.explorer.use_v1 = True
self.config.explorer.chat_template = CHAT_TEMPLATE
self.engines = create_rollout_models(self.config)
Expand Down
2 changes: 1 addition & 1 deletion tests/explorer/explorer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def setUp(self):
self.config.global_config.batch_size = 4
self.config.model.model_path = get_model_path()
self.config.explorer.engine_type = "vllm_async"
self.config.explorer.repeat_times = 2
self.config.buffer.explorer_input.taskset.rollout_args.repeat_times = 2
self.config.monitor.monitor_type = MonitorType.TENSORBOARD
self.config.monitor.project = "Trinity-unittest"
self.config.model.checkpoint_path = get_checkpoint_path()
Expand Down
Loading