
Commit f029cbd

update document
1 parent 9a4927f commit f029cbd

5 files changed: +212 −12 lines
Binary file (783 KB), diff not rendered.

docs/sphinx_doc/source/tutorial/trinity_configs.md

Lines changed: 1 addition & 1 deletion
@@ -188,7 +188,7 @@ model:
 - `min_response_tokens`: Minimum number of tokens allowed in generated responses. Only for the `chat` and `generate` methods in `InferenceModel`. Default is `1`. It must be less than `max_response_tokens`.
 - `enable_prompt_truncation`: Whether to truncate the prompt. Default is `true`. If set to `true`, the prompt is truncated to `max_prompt_tokens` tokens; if set to `false`, the prompt is not truncated and the combined prompt and response length may exceed `max_model_len`. This option has no effect in OpenAI API mode.
 - `repetition_penalty`: Repetition penalty factor. Default is `1.0`.
-- `lora_configs`: Optional LoRA configuration. If not specified, defaults to `null`. Currently, only one LoRA configuration is supported.
+- `lora_configs`: Optional LoRA configuration. If not specified, defaults to `null`. Currently, only one LoRA configuration is supported, and it will not be applied if `tinker` is enabled.
   - `name`: Name of the LoRA. Default is `None`.
   - `path`: Path to the LoRA. Default is `None`.
   - `base_model_name`: Name of the base model for LoRA. If not specified, defaults to `None`.
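
For reference, a minimal sketch of what the `lora_configs` entry documented above could look like under `model:`. The adapter name and path are hypothetical placeholders; only the field names and the single-entry restriction come from the documentation.

```yaml
model:
  model_path: meta-llama/Llama-3.2-3B
  lora_configs:                       # only a single LoRA config is currently supported
    - name: my_lora                   # hypothetical adapter name
      path: /path/to/lora_adapter     # hypothetical local adapter path
      base_model_name: meta-llama/Llama-3.2-3B
```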

docs/sphinx_doc/source_zh/tutorial/trinity_configs.md

Lines changed: 2 additions & 2 deletions
@@ -178,7 +178,7 @@ model:
     train_unembed: true
 ```

-- `model_path`: Path to the model being trained.
+- `model_path`: Path to the model being trained. If `tinker` is enabled, this is the path to a local tokenizer.
 - `critic_model_path`: Optional path to a standalone critic model. If empty, defaults to `model_path`.
 - `custom_chat_template`: Optional custom chat template in string form. If not specified, the tokenizer's default chat template is used.
 - `chat_template_path`: Optional path to a chat template file, typically jinja2; if set, it overrides `custom_chat_template`. If not specified, the tokenizer's default chat template is used.
@@ -188,7 +188,7 @@ model:
 - `min_response_tokens`: Minimum number of tokens allowed in generated responses. Only effective for the `chat` and `generate` methods in `InferenceModel`.
 - `enable_prompt_truncation`: Whether to truncate the prompt. Default is `true`. If set to `true`, the prompt is truncated to `max_prompt_tokens` tokens; if set to `false`, the prompt is not truncated and the combined prompt and response length may exceed `max_model_len`. Has no effect in OpenAI API mode.
 - `repetition_penalty`: Repetition penalty factor. Default is `1.0`.
-- `lora_configs`: Optional LoRA configuration. If not specified, defaults to `null`. Currently, only one LoRA configuration is supported.
+- `lora_configs`: Optional LoRA configuration. If not specified, defaults to `null`. Currently, only one LoRA configuration is supported, and it is not used when `tinker` is enabled.
   - `name`: Name of the LoRA. Default is `None`.
   - `path`: Path to the LoRA. Default is `None`.
   - `base_model_name`: Name of the base model the LoRA is built on. If not specified, defaults to `None`.

examples/tinker/README.md

Lines changed: 207 additions & 6 deletions
@@ -28,7 +28,7 @@ model:

 ### 3. Configuration Parameters Explained

-- **`tinker`**: Optional Tinker-specific configuration section. **Important**: When Tinker is enabled, any LoRA configuration settings will be ignored.
+- **`tinker`**: Tinker-specific configuration section. **Important**: When Tinker is enabled, any LoRA configuration settings will be ignored.
   - **`enable`**: Whether to activate the Tinker backend. Default: `false`
   - **`base_model`**: Path to the base model for Tinker. If not specified (`null`), it defaults to the `model_path` defined elsewhere in your config
   - **`rank`**: The LoRA rank that controls the size of the adaptation matrices. Default: `32`
@@ -37,10 +37,211 @@ model:
   - **`train_attn`**: Whether to train the attention layers. Default: `true`
   - **`train_unembed`**: Whether to train the unembedding (output) layer. Default: `true`

-## Usage Notes

-Once configured, Trinity works with the Tinker backend just like it does with the standard veRL training backend, with two important limitations:
-1. **Entropy loss** is not consistent compared to veRL backends
-2. Algorithms that require **`compute_advantage_in_trainer=true`** are **not supported**
+## Usage

-The complete configuration file can be found at [`tinker.yaml`](tinker.yaml).
+Once configured, Trinity works with the Tinker backend just like it does with the standard veRL backend. Start training with:
+
+```bash
+trinity run --config tinker.yaml # Replace with your actual config file path
+```
+
+### Important Limitations of the Tinker Backend
+
+1. **Entropy loss** is not consistent with the veRL backend.
+2. **Algorithms requiring `compute_advantage_in_trainer=true` are NOT supported**, including:
+   - `PPOAlgorithm`
+   - `ReinforcePlusPlusAlgorithm`
+   - `RLOOAlgorithm`
+   - `OnPolicyDistillAlgorithm`
+
+> 💡 A complete example configuration file is available at [`tinker.yaml`](tinker.yaml).
+
+
+## Results on the Llama-3.2-3B Model
+
+We trained the **Llama-3.2-3B** model on the **GSM8K** dataset using both the **Tinker** and **veRL** backends. Below are the full configuration files used in our experiments.
+
+
+<details><summary>Click to expand: Tinker Backend Configuration</summary>
+
+```yaml
+mode: both
+project: Trinity-RFT-gsm8k
+group: alignment-tinker
+name: tinker-llama3.2-3B-off1
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+  algorithm_type: grpo
+  repeat_times: 8
+  sample_strategy: default
+  kl_loss_fn_args:
+    kl_coef: 0.0
+  optimizer:
+    lr: 1.0e-05
+    lr_warmup_steps_ratio: 0.0
+    warmup_style: constant
+data_processor: {}
+model:
+  model_path: meta-llama/Llama-3.2-3B
+  max_prompt_tokens: 1024
+  max_response_tokens: 2048
custom_chat_template: "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {{- \"<|eot_id|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n"
+  tinker:
+    enable: true
+    base_model: meta-llama/Llama-3.2-3B
+cluster:
+  node_num: 1
+  gpu_per_node: 8
+buffer:
+  batch_size: 96
+  total_epochs: 1
+  explorer_input:
+    taskset:
+      name: taskset
+      storage_type: file
+      path: openai/gsm8k
+      split: train
+      subset_name: main
+      format:
+        prompt_key: question
+        response_key: answer
+      rollout_args:
+        temperature: 1.0
+        logprobs: 0
+    eval_tasksets: []
+    default_workflow_type: math_workflow
+  trainer_input:
+    experience_buffer:
+      name: experience_buffer
+      storage_type: queue
+      replay_buffer:
+        enable: false
+explorer:
+  runner_per_model: 16
+  rollout_model:
+    engine_num: 4
+    seed: 42
+  auxiliary_models: []
+  eval_interval: 1000
+trainer:
+  save_interval: 100
+  enable_preview: true
+  grad_clip: 1.0
+  max_token_len_per_gpu: 16384
+monitor:
+  monitor_type: wandb
+synchronizer:
+  sync_method: checkpoint
+  sync_style: fixed
+  sync_interval: 1
+  sync_offset: 1
+  sync_timeout: 1200
+```
+
+</details>
+
+
+<details><summary>Click to expand: veRL Backend Configuration (LoRA)</summary>
+
+```yaml
+mode: both
+project: Trinity-RFT-gsm8k
+group: alignment-tinker
+name: verl-llama3.2-3B-lora-off1
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+  algorithm_type: grpo
+  repeat_times: 8
+  sample_strategy: default
+  kl_loss_fn_args:
+    kl_coef: 0.0
+  optimizer:
+    lr: 1.0e-05
+    lr_warmup_steps_ratio: 0.0
+    warmup_style: constant
+data_processor: {}
+model:
+  model_path: meta-llama/Llama-3.2-3B
+  max_prompt_tokens: 1024
+  max_response_tokens: 2048
custom_chat_template: "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {{- \"<|eot_id|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n"
+  lora_configs:
+    - name: lora
+      lora_rank: 32
+      lora_alpha: 32
+cluster:
+  node_num: 1
+  gpu_per_node: 8
+buffer:
+  batch_size: 96
+  total_epochs: 1
+  explorer_input:
+    taskset:
+      name: taskset
+      storage_type: file
+      path: openai/gsm8k
+      split: train
+      subset_name: main
+      format:
+        prompt_key: question
+        response_key: answer
+      rollout_args:
+        temperature: 1.0
+        logprobs: 0
+    eval_tasksets: []
+    default_workflow_type: math_workflow
+  trainer_input:
+    experience_buffer:
+      name: experience_buffer
+      storage_type: queue
+      replay_buffer:
+        enable: false
+        priority_fn: linear_decay
+        reuse_cooldown_time: null
+        priority_fn_args:
+          decay: 2.0
+explorer:
+  runner_per_model: 16
+  rollout_model:
+    engine_num: 4
+    tensor_parallel_size: 1
+    enforce_eager: false
+    enable_prefix_caching: false
+    enable_chunked_prefill: false
+    gpu_memory_utilization: 0.9
+    dtype: bfloat16
+    seed: 42
+    enable_thinking: false
+    enable_history: false
+    enable_openai_api: false
+    enable_auto_tool_choice: false
+    tool_call_parser: null
+    reasoning_parser: null
+  auxiliary_models: []
+  eval_interval: 1000
+trainer:
+  trainer_type: verl
+  save_interval: 100
+  enable_preview: true
+  grad_clip: 1.0
+  max_token_len_per_gpu: 16384
+monitor:
+  monitor_type: wandb
+synchronizer:
+  sync_method: checkpoint
+  sync_style: fixed
+  sync_interval: 1
+  sync_offset: 1
+  sync_timeout: 1200
+```
+
+</details>
+
+### Observations
+
+Since **Llama-3.2-3B** is a base (non-instruct-tuned) model, it has limited ability to follow formatting instructions. Additionally, we trained for only **one epoch**. As a result, both backends achieved final rewards just slightly above **0.1**.
+
+However, the training curves clearly show an **upward trend in reward**, indicating successful learning. The results are visualized below:
+
+![Training Rewards on GSM8K](../../docs/sphinx_doc/assets/tinker-gsm8k.png)
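
To summarize the parameter list above, enabling the Tinker backend only needs a small `model.tinker` block. The following sketch keeps every documented default except `enable` and `base_model` (values taken from the Tinker configuration shown above); treat it as illustrative rather than a complete config:

```yaml
model:
  model_path: meta-llama/Llama-3.2-3B
  tinker:
    enable: true                          # activate the Tinker backend (default: false)
    base_model: meta-llama/Llama-3.2-3B   # falls back to model_path when null
    rank: 32                              # LoRA rank of the adaptation matrices (default: 32)
    train_attn: true                      # train attention layers (default: true)
    train_unembed: true                   # train the unembedding (output) layer (default: true)
```

Remember that any `lora_configs` settings are ignored once `tinker.enable` is `true`.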

examples/tinker/tinker.yaml

Lines changed: 2 additions & 3 deletions
@@ -52,7 +52,6 @@ explorer:
   auxiliary_models: []
   eval_interval: 1000
 trainer:
-  trainer_type: verl
   save_interval: 100
   enable_preview: true
   grad_clip: 1.0
@@ -62,7 +61,7 @@ monitor:
 synchronizer:
   sync_method: memory
   sync_style: fixed
-  sync_interval: 2
+  sync_interval: 1
   sync_timeout: 1200
 log:
-  level: INFO
+  level: INFO
