**docs/sphinx_doc/source/tutorial/trinity_configs.md** (1 addition, 1 deletion)

@@ -188,7 +188,7 @@ model:
- `min_response_tokens`: Minimum number of tokens allowed in generated responses. Only for `chat` and `generate` methods in `InferenceModel`. Default is `1`. It must be less than `max_response_tokens`.
- `enable_prompt_truncation`: Whether to truncate the prompt. Default is `true`. If set to `true`, the prompt is truncated to `max_prompt_tokens` tokens; if set to `false`, the prompt is not truncated, and the prompt length plus the response length may exceed `max_model_len`. This option has no effect in OpenAI API mode.
- `repetition_penalty`: Repetition penalty factor. Default is `1.0`.
- `lora_configs`: Optional LoRA configuration. If not specified, defaults to `null`. Currently, only one LoRA configuration is supported, and this configuration will not be applied if `tinker` is enabled. An illustrative snippet follows this list.
- `name`: Name of the LoRA. Default is `None`.
- `path`: Path to the LoRA. Default is `None`.
- `base_model_name`: Name of the base model for LoRA. If not specified, defaults to `None`.
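
Taken together, an illustrative `model` section using these options might look like the sketch below. The key names come from the bullets above; the paths and values are placeholder assumptions, not documented defaults.

```yaml
model:
  model_path: /path/to/base-model     # placeholder path (assumption)
  max_prompt_tokens: 2048             # illustrative value
  max_response_tokens: 1024           # illustrative value
  min_response_tokens: 1              # must stay below max_response_tokens
  enable_prompt_truncation: true      # prompt is truncated to max_prompt_tokens
  repetition_penalty: 1.0
  lora_configs:                       # ignored when tinker is enabled
    - name: my_lora                   # hypothetical adapter name
      path: /path/to/lora-adapter     # placeholder path (assumption)
      base_model_name: my-base-model  # hypothetical base model name
```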

**examples/tinker/README.md** (207 additions, 6 deletions)

@@ -28,7 +28,7 @@ model:
### 3. Configuration Parameters Explained

- **`tinker`**: Tinker-specific configuration section. **Important**: When Tinker is enabled, any LoRA configuration settings will be ignored. A configuration sketch follows this parameter list.
- **`enable`**: Whether to activate the Tinker backend. Default: `false`
- **`base_model`**: Path to the base model for Tinker. If not specified (`null`), it defaults to the `model_path` defined elsewhere in your config
- **`rank`**: The LoRA rank that controls the size of the adaptation matrices. Default: `32`
@@ -37,10 +37,211 @@ model:
- **`train_attn`**: Whether to train the attention layers. Default: `true`
- **`train_unembed`**: Whether to train the unembedding (output) layer. Default: `true`
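
As a minimal sketch of this section (assuming, as the hunk context `model:` suggests, that `tinker` sits under the `model` key; all values are illustrative rather than prescriptive):

```yaml
model:
  model_path: meta-llama/Llama-3.2-3B  # assumed placeholder; reused as the Tinker base model when base_model is null
  tinker:
    enable: true          # activate the Tinker backend
    base_model: null      # falls back to model_path above
    rank: 32              # LoRA rank of the adaptation matrices
    train_attn: true      # train the attention layers
    train_unembed: true   # train the unembedding (output) layer
```

With `enable: true`, any `lora_configs` entries elsewhere in the file are ignored, as noted above.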

## Usage

Once configured, Trinity works with the Tinker backend just like it does with the standard veRL backend. Start training with:

```bash
trinity run --config tinker.yaml # Replace with your actual config file path
```

### Important Limitations of the Tinker Backend

1. **Entropy loss** is not consistent with that of the veRL backend.
2. **Algorithms requiring `compute_advantage_in_trainer=true` are NOT supported** (see the sketch after this list), including:
   - `PPOAlgorithm`
   - `ReinforcePlusPlusAlgorithm`
   - `RLOOAlgorithm`
   - `OnPolicyDistillAlgorithm`
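
As a hedged illustration of the second limitation, the `algorithm` section of a Tinker run would need to select a method whose advantages are not computed in the trainer. The snippet below assumes a group-based method such as GRPO qualifies; treat it as a sketch, not a verified recommendation.

```yaml
algorithm:
  algorithm_type: grpo   # assumption: group-based advantages, no compute_advantage_in_trainer=true needed
  repeat_times: 8        # illustrative number of rollouts per task
```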

> 💡 A complete example configuration file is available at [`tinker.yaml`](tinker.yaml).

## Results on the Llama-3.2-3B Model

We trained the **Llama-3.2-3B** model on the **GSM8K** dataset using both the **Tinker** and **veRL** backends. Below are the full configuration files used in our experiments.

<details><summary>Click to expand: Tinker Backend Configuration</summary>

```yaml
# … (the beginning of the config is not shown in this excerpt)
custom_chat_template: "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {{- \"<|eot_id|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n"
tinker:
  enable: true
  base_model: meta-llama/Llama-3.2-3B
cluster:
  node_num: 1
  gpu_per_node: 8
buffer:
  batch_size: 96
  total_epochs: 1
  explorer_input:
    taskset:
      name: taskset
      storage_type: file
      path: openai/gsm8k
      split: train
      subset_name: main
      format:
        prompt_key: question
        response_key: answer
      rollout_args:
        temperature: 1.0
        logprobs: 0
    eval_tasksets: []
    default_workflow_type: math_workflow
  trainer_input:
    experience_buffer:
      name: experience_buffer
      storage_type: queue
      replay_buffer:
        enable: false
explorer:
  runner_per_model: 16
  rollout_model:
    engine_num: 4
    seed: 42
  auxiliary_models: []
  eval_interval: 1000
trainer:
  save_interval: 100
  enable_preview: true
  grad_clip: 1.0
  max_token_len_per_gpu: 16384
monitor:
  monitor_type: wandb
synchronizer:
  sync_method: checkpoint
  sync_style: fixed
  sync_interval: 1
  sync_offset: 1
  sync_timeout: 1200
```

</details>

<details><summary>Click to expand: veRL Backend Configuration (LoRA)</summary>

```yaml
# … (the beginning of the config is not shown in this excerpt)
custom_chat_template: "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {{- \"<|eot_id|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n"
lora_configs:
  - name: lora
    lora_rank: 32
    lora_alpha: 32
cluster:
  node_num: 1
  gpu_per_node: 8
buffer:
  batch_size: 96
  total_epochs: 1
  explorer_input:
    taskset:
      name: taskset
      storage_type: file
      path: openai/gsm8k
      split: train
      subset_name: main
      format:
        prompt_key: question
        response_key: answer
      rollout_args:
        temperature: 1.0
        logprobs: 0
    eval_tasksets: []
    default_workflow_type: math_workflow
  trainer_input:
    experience_buffer:
      name: experience_buffer
      storage_type: queue
      replay_buffer:
        enable: false
        priority_fn: linear_decay
        reuse_cooldown_time: null
        priority_fn_args:
          decay: 2.0
explorer:
  runner_per_model: 16
  rollout_model:
    engine_num: 4
    tensor_parallel_size: 1
    enforce_eager: false
    enable_prefix_caching: false
    enable_chunked_prefill: false
    gpu_memory_utilization: 0.9
    dtype: bfloat16
    seed: 42
    enable_thinking: false
    enable_history: false
    enable_openai_api: false
    enable_auto_tool_choice: false
    tool_call_parser: null
    reasoning_parser: null
  auxiliary_models: []
  eval_interval: 1000
trainer:
  trainer_type: verl
  save_interval: 100
  enable_preview: true
  grad_clip: 1.0
  max_token_len_per_gpu: 16384
monitor:
  monitor_type: wandb
synchronizer:
  sync_method: checkpoint
  sync_style: fixed
  sync_interval: 1
  sync_offset: 1
  sync_timeout: 1200
```

</details>

### Observations

Since **Llama-3.2-3B** is a base (non-instruct-tuned) model, it has limited ability to follow formatting instructions. Additionally, we trained for only **one epoch**. As a result, both backends achieved final rewards just slightly above **0.1**.

However, the training curves clearly show an **upward trend in reward**, indicating successful learning. The results are visualized below:

