
Commit 9a4927f

Add document and apply reviews
1 parent 10dcadf commit 9a4927f


10 files changed: +191 additions, -30 deletions


docs/sphinx_doc/source/tutorial/trinity_configs.md

Lines changed: 32 additions & 1 deletion
@@ -164,9 +164,21 @@ model:
   max_response_tokens: 16384
   min_response_tokens: 1
   enable_prompt_truncation: true
+  repetition_penalty: 1.0
+  lora_configs: null
+  rope_scaling: null
+  rope_theta: null
+  tinker:
+    enable: false
+    base_model: null
+    rank: 32
+    seed: null
+    train_mlp: true
+    train_attn: true
+    train_unembed: true
 ```
 
-- `model_path`: Path to the model being trained.
+- `model_path`: Path to the model being trained. If `tinker` is enabled, this is the path to the local tokenizer.
 - `critic_model_path`: Optional path to a separate critic model. If empty, defaults to `model_path`.
 - `custom_chat_template`: Optional custom chat template in string format. If not specified, the system will use the default chat template from the tokenizer.
 - `chat_template_path`: Optional path to the chat template file in jinja2 format; overrides `custom_chat_template` if set. If not specified, the system will use the default chat template from the tokenizer.
@@ -175,6 +187,25 @@ model:
 - `max_prompt_tokens`: Maximum number of tokens allowed in prompts. Only for `chat` and `generate` methods in `InferenceModel`.
 - `min_response_tokens`: Minimum number of tokens allowed in generated responses. Only for `chat` and `generate` methods in `InferenceModel`. Default is `1`. It must be less than `max_response_tokens`.
 - `enable_prompt_truncation`: Whether to truncate the prompt. Default is `true`. If set to `true`, the prompt will be truncated to `max_prompt_tokens` tokens; if set to `false`, the prompt will not be truncated and there is a risk that the prompt length plus response length exceeds `max_model_len`. This feature does not work in OpenAI API mode.
+- `repetition_penalty`: Repetition penalty factor. Default is `1.0`.
+- `lora_configs`: Optional LoRA configuration. If not specified, defaults to `null`. Currently, only one LoRA configuration is supported.
+  - `name`: Name of the LoRA adapter. Default is `None`.
+  - `path`: Path to the LoRA adapter. Default is `None`.
+  - `base_model_name`: Name of the base model for the LoRA adapter. If not specified, defaults to `None`.
+  - `lora_rank`: Rank of the LoRA adapter. Default is `32`.
+  - `lora_alpha`: Alpha value of the LoRA adapter. Default is `32`.
+  - `lora_dtype`: Data type of the LoRA adapter. Default is `auto`.
+  - `target_modules`: List of target modules for LoRA. Default is `all-linear`.
+- `rope_scaling`: Optional RoPE scaling configuration in JSON format. If not specified, defaults to `null`.
+- `rope_theta`: Optional RoPE theta value. If not specified, defaults to `null`.
+- `tinker`: Optional Tinker configuration. Note: the LoRA configuration will be ignored if Tinker is enabled.
+  - `enable`: Whether to enable Tinker. Default is `false`.
+  - `base_model`: Path to the base model for Tinker. If not specified, defaults to `model_path`.
+  - `rank`: LoRA rank controlling the size of the adaptation matrices. Default is `32`.
+  - `seed`: Random seed for Tinker. If not specified, defaults to `null`.
+  - `train_mlp`: Whether to train the MLP layers. Default is `true`.
+  - `train_attn`: Whether to train the attention layers. Default is `true`.
+  - `train_unembed`: Whether to train the unembedding layer. Default is `true`.
 
 ```{tip}
 If you are using the OpenAI API provided by Explorer, only `max_model_len` will take effect, and the values of `max_response_tokens`, `max_prompt_tokens`, and `min_response_tokens` will be ignored. When `max_tokens` is not independently specified, each API call will generate up to `max_model_len - prompt_length` tokens. Therefore, please ensure that the prompt length is less than `max_model_len` when using the API.
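
To see how the `lora_configs` fields documented above fit together, here is a minimal sketch of a filled-in entry (not part of this commit): the adapter name and path are hypothetical placeholders, the one-element list form is an assumption about the schema, and the remaining values simply restate the documented defaults.

```yaml
model:
  lora_configs:
    - name: example_lora              # hypothetical adapter name
      path: /path/to/example_lora     # hypothetical local adapter path
      base_model_name: null
      lora_rank: 32
      lora_alpha: 32
      lora_dtype: auto
      target_modules: all-linear
```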

docs/sphinx_doc/source_zh/tutorial/trinity_configs.md

Lines changed: 31 additions & 0 deletions
@@ -164,6 +164,18 @@ model:
   max_response_tokens: 16384
   min_response_tokens: 1
   enable_prompt_truncation: true
+  repetition_penalty: 1.0
+  lora_configs: null
+  rope_scaling: null
+  rope_theta: null
+  tinker:
+    enable: false
+    base_model: null
+    rank: 32
+    seed: null
+    train_mlp: true
+    train_attn: true
+    train_unembed: true
 ```
 
 - `model_path`: Path to the model being trained.
@@ -175,6 +187,25 @@ model:
 - `max_response_tokens`: Maximum number of tokens allowed in model-generated responses. Only effective for the `chat` and `generate` methods in `InferenceModel`.
 - `min_response_tokens`: Minimum number of tokens allowed in model-generated responses. Only effective for the `chat` and `generate` methods in `InferenceModel`.
 - `enable_prompt_truncation`: Whether to truncate the prompt. Default is `true`. If set to `true`, the prompt is truncated to `max_prompt_tokens` tokens; if set to `false`, the prompt is not truncated, with the risk that the combined prompt and response length exceeds `max_model_len`. Not effective in OpenAI API mode.
+- `repetition_penalty`: Repetition penalty factor. Default is `1.0`.
+- `lora_configs`: Optional LoRA configuration. If not specified, defaults to `null`. Currently, only one LoRA configuration is supported.
+  - `name`: Name of the LoRA adapter. Default is `None`.
+  - `path`: Path to the LoRA adapter. Default is `None`.
+  - `base_model_name`: Name of the base model the LoRA adapter is based on. If not specified, defaults to `None`.
+  - `lora_rank`: Rank of the LoRA adapter. Default is `32`.
+  - `lora_alpha`: Alpha value of the LoRA adapter. Default is `32`.
+  - `lora_dtype`: Data type of the LoRA adapter. Default is `auto`.
+  - `target_modules`: List of target modules for LoRA. Default is `all-linear`.
+- `rope_scaling`: Optional RoPE scaling configuration in JSON format. If not specified, defaults to `null`.
+- `rope_theta`: Optional RoPE theta value. If not specified, defaults to `null`.
+- `tinker`: Optional Tinker configuration. Note: if Tinker is enabled, the LoRA configuration is ignored.
+  - `enable`: Whether to enable Tinker. Default is `false`.
+  - `base_model`: Path to the base model used by Tinker. If not specified, defaults to `model_path`.
+  - `rank`: LoRA rank controlling the size of the adaptation matrices. Default is `32`.
+  - `seed`: Random seed used by Tinker. If not specified, defaults to `null`.
+  - `train_mlp`: Whether to train the MLP layers. Default is `true`.
+  - `train_attn`: Whether to train the attention layers. Default is `true`.
+  - `train_unembed`: Whether to train the unembedding layer. Default is `true`.
 
 ```{tip}
 If you are using the OpenAI API provided by Explorer, only `max_model_len` takes effect; the values of `max_response_tokens`, `max_prompt_tokens`, and `min_response_tokens` are ignored. When `max_tokens` is not independently specified, each API call generates at most `max_model_len - prompt_length` tokens, so please ensure that the prompt length is less than `max_model_len` when using the API.

examples/tinker/README.md

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+# Trinity with Tinker Backend
+
+This example demonstrates how to use Trinity with the [Tinker](https://thinkingmachines.ai/tinker/) backend, which enables model training on devices without GPUs.
+
+## Setup Instructions
+
+### 1. API Key Configuration
+Before starting Ray, you must set the `TRINITY_API_KEY` environment variable to your Tinker API key to enable proper access to Tinker's API:
+
+```bash
+export TRINITY_API_KEY=your_tinker_api_key
+```
+
+### 2. Configuration File
+Configure the Tinker backend in your YAML configuration file by setting the `model.tinker` parameters as shown below:
+
+```yaml
+model:
+  tinker:
+    enable: true
+    base_model: null
+    rank: 32
+    seed: null
+    train_mlp: true
+    train_attn: true
+    train_unembed: true
+```
+
+### 3. Configuration Parameters Explained
+
+- **`tinker`**: Optional Tinker-specific configuration section. **Important**: When Tinker is enabled, any LoRA configuration settings will be ignored.
+  - **`enable`**: Whether to activate the Tinker backend. Default: `false`
+  - **`base_model`**: Path to the base model for Tinker. If not specified (`null`), it defaults to the `model_path` defined elsewhere in your config
+  - **`rank`**: The LoRA rank that controls the size of the adaptation matrices. Default: `32`
+  - **`seed`**: Random seed for reproducible Tinker operations. If not specified (`null`), no specific seed is set
+  - **`train_mlp`**: Whether to train the MLP (feed-forward) layers. Default: `true`
+  - **`train_attn`**: Whether to train the attention layers. Default: `true`
+  - **`train_unembed`**: Whether to train the unembedding (output) layer. Default: `true`
+
+## Usage Notes
+
+Once configured, Trinity works with the Tinker backend just like it does with the standard veRL training backend, with two important limitations:
+1. **Entropy loss** is not consistent with the veRL backend
+2. Algorithms that require **`compute_advantage_in_trainer=true`** are **not supported**
+
+The complete configuration file can be found at [`tinker.yaml`](tinker.yaml).
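
As a usage sketch (not part of this commit), assuming the usual Trinity-RFT command-line entry point, the example might be launched roughly as follows; the exact CLI invocation is an assumption here, so check the Trinity-RFT documentation for the authoritative command.

```bash
# Assumed invocation (hypothetical; verify against the Trinity-RFT docs)
export TRINITY_API_KEY=your_tinker_api_key
trinity run --config examples/tinker/tinker.yaml
```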

examples/tinker/tinker.yaml

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+mode: both
+project: Trinity-RFT-gsm8k
+name: tinker-Qwen3-4B
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+  algorithm_type: grpo
+  repeat_times: 8
+  sample_strategy: default
+  kl_loss_fn_args:
+    kl_coef: 0.0
+  optimizer:
+    lr: 1.0e-05
+    lr_warmup_steps_ratio: 0.0
+    warmup_style: constant
+data_processor: {}
+model:
+  model_path: Qwen/Qwen3-4B-Instruct-2507
+  max_prompt_tokens: 1024
+  max_response_tokens: 2048
+  tinker:
+    enable: true
+    base_model: Qwen/Qwen3-4B-Instruct-2507
+buffer:
+  batch_size: 96
+  total_epochs: 1
+  explorer_input:
+    taskset:
+      name: taskset
+      storage_type: file
+      path: openai/gsm8k
+      split: train
+      subset_name: main
+      format:
+        prompt_key: question
+        response_key: answer
+      rollout_args:
+        temperature: 1.0
+        logprobs: 0
+    eval_tasksets: []
+    default_workflow_type: math_workflow
+  trainer_input:
+    experience_buffer:
+      name: experience_buffer
+      storage_type: queue
+      replay_buffer:
+        enable: false
+explorer:
+  runner_per_model: 8
+  rollout_model:
+    engine_num: 4
+    seed: 42
+  auxiliary_models: []
+  eval_interval: 1000
+trainer:
+  trainer_type: verl
+  save_interval: 100
+  enable_preview: true
+  grad_clip: 1.0
+  max_token_len_per_gpu: 16384
+monitor:
+  monitor_type: tensorboard
+synchronizer:
+  sync_method: memory
+  sync_style: fixed
+  sync_interval: 2
+  sync_timeout: 1200
+log:
+  level: INFO

trinity/common/config.py

Lines changed: 1 addition & 2 deletions
@@ -1530,12 +1530,11 @@ def check_and_update(self) -> Config: # noqa: C901
                     f"Invalid trainer.save_hf_checkpoint: {self.trainer.save_hf_checkpoint}, "
                     "must be one of 'last', 'always', or 'never'."
                 )
+            self.trainer.trainer_config.synchronize_config(self)
         elif self.trainer.trainer_type == "tinker":
             self.trainer.trainer_config = None
         else:
             raise ValueError(f"Invalid trainer type: {self.trainer_type}")
-        if self.trainer.trainer_config:
-            self.trainer.trainer_config.synchronize_config(self)

         # check service
         if self.service.data_juicer is not None:

trinity/manager/synchronizer.py

Lines changed: 4 additions & 2 deletions
@@ -83,9 +83,12 @@ async def _find_latest_state_dict(self) -> None:
             await self._find_verl_latest_state_dict()
         elif self.config.trainer.trainer_type == "tinker":
             await self._find_tinker_latest_state_dict()
+        else:
+            self.logger.warning(
+                "Synchronizer does not support this trainer type. Please use `verl` or `tinker`."
+            )

     async def _find_verl_latest_state_dict(self) -> None:
-        assert self.config.trainer.trainer_type == "verl"
         default_local_dir = self.config.checkpoint_job_dir
         local_latest_state_dict_iteration = os.path.join(
             default_local_dir, "latest_state_dict_iteration.txt"
@@ -119,7 +122,6 @@ async def _find_verl_latest_state_dict(self) -> None:
             await asyncio.sleep(1)

     async def _find_tinker_latest_state_dict(self) -> None:
-        assert self.config.trainer.trainer_type == "tinker"
         default_local_dir = self.config.checkpoint_job_dir
         local_latest_state_dict_iteration = os.path.join(
             default_local_dir, "latest_state_dict_iteration.txt"

trinity/trainer/tinker/utils.py

Lines changed: 1 addition & 16 deletions
@@ -1,27 +1,12 @@
 from logging import Logger
-from typing import Any, List, Tuple, Union
+from typing import Any, List, Tuple

 import torch
 from tinker import types

 from trinity.common.experience import Experience, split_dpo_experience_to_single_turn


-def pad_to_length(
-    tensor: torch.tensor, length: int, pad_value: Union[int, float] = 0
-) -> torch.tensor:
-    pad_value = torch.tensor(pad_value, dtype=tensor.dtype)
-    assert len(tensor) <= length, f"Tensor length {len(tensor)} is longer than length {length}."
-    if len(tensor) == length:
-        return tensor
-    return torch.concat(
-        [
-            torch.full((length - len(tensor),), pad_value),
-            tensor,
-        ]
-    )
-
-
 def to_tinker_input(
     experiences: List[Experience], logger: Logger
 ) -> Tuple[List[types.Datum], List[types.ModelInput], List[dict]]:

trinity/trainer/tinker_trainer.py

Lines changed: 4 additions & 4 deletions
@@ -277,8 +277,8 @@ def save_checkpoint(self, block_until_saved: bool = False, save_as_hf: bool = Fa
             f"global_step_{self.train_step_num}",
         )
         os.makedirs(local_path, exist_ok=True)
-        remote_path_file = os.path.join(local_path, "remote_checkpoint_path.txt")
-        with open(remote_path_file, "w") as f:
+        remote_checkpoint_path = os.path.join(local_path, "remote_checkpoint_path.txt")
+        with open(remote_checkpoint_path, "w") as f:
             f.write(self.latest_remote_checkpoint_path)

         with open(self.local_latest_checkpointed_iteration, "w") as f:
@@ -311,8 +311,8 @@ def save_state_dict(self) -> None:
             f"global_step_{self.train_step_num}",
         )
         os.makedirs(local_path, exist_ok=True)
-        remote_path_file = os.path.join(local_path, "remote_sampler_path.txt")
-        with open(remote_path_file, "w") as f:
+        remote_sampler_path = os.path.join(local_path, "remote_sampler_path.txt")
+        with open(remote_sampler_path, "w") as f:
             f.write(self.latest_remote_sampler_path)

         with open(self.local_latest_state_dict_iteration, "w") as f:

trinity/trainer/verl/utils.py

Lines changed: 3 additions & 4 deletions
@@ -102,7 +102,7 @@ def to_data_proto(
     return DataProto.from_single_dict(batch_dict)


-def compute_data_metrics(batch: DataProto, use_critic: bool = False) -> dict:
+def compute_data_metrics(batch: DataProto) -> dict:
     """
     Computes various metrics from a batch of data for PPO training.
     Modified from verl.trainer.ppo.metric_utils.compute_data_metrics
@@ -113,16 +113,15 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = False) -> dict:

     Args:
         batch: A DataProto object containing batch data with token-level scores, rewards, advantages, etc.
-        use_critic: Whether to include critic-specific metrics. Defaults to True.

     Returns:
         A dictionary of metrics including:
             - critic/score/mean, max, min: Statistics about sequence scores
             - critic/rewards/mean, max, min: Statistics about sequence rewards
             - critic/advantages/mean, max, min: Statistics about advantages
             - critic/returns/mean, max, min: Statistics about returns
-            - critic/values/mean, max, min: Statistics about critic values (if use_critic=True)
-            - critic/vf_explained_var: Explained variance of the value function (if use_critic=True)
+            - critic/values/mean, max, min: Statistics about critic values
+            - critic/vf_explained_var: Explained variance of the value function
             - response_length/mean, max, min, clip_ratio: Statistics about response lengths
             - prompt_length/mean, max, min, clip_ratio: Statistics about prompt lengths
     """

trinity/trainer/verl_trainer.py

Lines changed: 1 addition & 1 deletion
@@ -472,7 +472,7 @@ async def train_step(self, batch_exps: List[Experience]) -> Dict: # noqa C901
         metrics.update(actor_output_metrics)

         # collect metrics
-        metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
+        metrics.update(compute_data_metrics(batch=batch))
         timing_metrics = compute_timing_metrics(batch=batch, timing_raw=timing_raw)
         metrics.update({k.replace("timing_s/", "time/"): v for k, v in timing_metrics.items()})
         n_gpus = self.resource_pool_manager.get_n_gpus()
