Commit bf5c134

Add chat_template_path and trust_remote_code (#379)
1 parent 0c58b5e commit bf5c134

File tree: 10 files changed, +111 -5 lines


docs/sphinx_doc/source/tutorial/trinity_configs.md

Lines changed: 4 additions & 0 deletions

@@ -157,6 +157,8 @@ Defines the model paths and token limits.
 model:
   model_path: ${oc.env:MODEL_PATH} # MODEL_PATH is an environment variable set in advance
   critic_model_path: ${model.model_path} # use the value of model.model_path
+  custom_chat_template: None
+  chat_template_path: None
   max_model_len: 20480
   max_prompt_tokens: 4096
   max_response_tokens: 16384
@@ -165,6 +167,8 @@ model:

 - `model_path`: Path to the model being trained.
 - `critic_model_path`: Optional path to a separate critic model. If empty, defaults to `model_path`.
+- `custom_chat_template`: Optional custom chat template given as a string. If not specified, the system uses the tokenizer's default chat template.
+- `chat_template_path`: Optional path to a chat template file (jinja2 format); overrides `custom_chat_template` if set. If not specified, the system uses the tokenizer's default chat template.
 - `max_model_len`: Maximum number of tokens in a sequence. It is recommended to set this value manually. If not specified, the system will attempt to set it to `max_prompt_tokens` + `max_response_tokens`. However, this requires both values to be already set; otherwise, an error will be raised.
 - `max_response_tokens`: Maximum number of tokens allowed in generated responses. Only for `chat` and `generate` methods in `InferenceModel`.
 - `max_prompt_tokens`: Maximum number of tokens allowed in prompts. Only for `chat` and `generate` methods in `InferenceModel`.
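
For context, a custom chat template string replaces the tokenizer's built-in default at render time; the sketch below shows the general mechanism with plain transformers, not Trinity's exact wiring. The model id and the toy template are placeholders.

# Minimal sketch: a custom template string replacing the tokenizer default.
# The model id and template here are placeholders, not Trinity's wiring.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
messages = [{"role": "user", "content": "Hello!"}]

# Default template shipped with the tokenizer:
default_text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
# Same call with a custom template string (what `custom_chat_template` holds):
custom_template = "{% for m in messages %}{{ m['role'] }}: {{ m['content'] }}\n{% endfor %}"
custom_text = tokenizer.apply_chat_template(
    messages, chat_template=custom_template, tokenize=False
)
print(default_text, custom_text, sep="\n---\n")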

docs/sphinx_doc/source_zh/tutorial/trinity_configs.md

Lines changed: 4 additions & 0 deletions

@@ -157,6 +157,8 @@ monitor:
 model:
   model_path: ${oc.env:MODEL_PATH} # MODEL_PATH is an environment variable set in advance
   critic_model_path: ${model.model_path} # use the value of model.model_path
+  custom_chat_template: None
+  chat_template_path: None
   max_model_len: 20480
   max_prompt_tokens: 4096
   max_response_tokens: 16384
@@ -165,6 +167,8 @@ model:

 - `model_path`: Path to the model being trained.
 - `critic_model_path`: Optional path to a separate critic model. If empty, defaults to `model_path`.
+- `custom_chat_template`: Optional custom chat template in string form. If not specified, the system uses the tokenizer's default chat template.
+- `chat_template_path`: Optional path to a chat template file, typically jinja2; overrides `custom_chat_template` if set. If not specified, the system uses the tokenizer's default chat template.
 - `max_model_len`: Maximum number of tokens in a single sequence supported by the model. If not specified, the system will attempt to set it to `max_prompt_tokens` + `max_response_tokens`; this requires both values to already be set, otherwise an error will be raised.
 - `max_prompt_tokens`: Maximum number of tokens allowed in prompts. Only effective for the `chat` and `generate` methods of `InferenceModel`.
 - `max_response_tokens`: Maximum number of tokens allowed in generated responses. Only effective for the `chat` and `generate` methods of `InferenceModel`.

examples/dapo_math/README.md

Lines changed: 1 addition & 1 deletion

@@ -1,5 +1,5 @@
 # DAPO on DAPO-MATH-17k dataset [WIP]

-This example shows the usage of DAPO on the [DAPO-MATH-17k](https://huggingface.co/datasets/open-r1/DAPO-Math-17k-Processed) dataset.
+Note that this example currently shows the usage of GRPO on the [DAPO-MATH-17k](https://huggingface.co/datasets/open-r1/DAPO-Math-17k-Processed) dataset; we plan to implement the DAPO algorithm soon.

 The config file is located in [`dapo.yaml`](dapo.yaml).

examples/dapo_math/dapo.yaml

Lines changed: 1 addition & 1 deletion

@@ -2,7 +2,7 @@ project: Trinity-RFT-example
 name: dapo
 checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
 model:
-  model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
+  model_path: ${oc.env:TRINITY_MODEL_PATH} # a larger model is suggested for this dataset
   max_response_tokens: 20480
   max_model_len: 21504
 algorithm:
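
The `${oc.env:...}` syntax in this YAML is OmegaConf's environment-variable resolver; an optional second argument is a fallback default. A minimal standalone sketch of how it resolves (values are illustrative):

# Minimal sketch of OmegaConf's ${oc.env:...} resolver used in the YAML above.
import os
from omegaconf import OmegaConf

os.environ["TRINITY_MODEL_PATH"] = "/models/my-model"  # illustrative value
cfg = OmegaConf.create(
    {
        "model_path": "${oc.env:TRINITY_MODEL_PATH}",
        # With a default after the comma, the variable may be unset:
        "checkpoint_root_dir": "${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}",
    }
)
print(cfg.model_path)           # /models/my-model
print(cfg.checkpoint_root_dir)  # ./checkpoints when the env var is unset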

tests/common/config_test.py

Lines changed: 13 additions & 0 deletions

@@ -168,6 +168,19 @@ def test_optimizer_config_propagation(self):
         self.assertEqual(config.trainer.trainer_config.critic.optim.lr_decay_style, "constant")
         self.assertEqual(config.trainer.trainer_config.critic.optim.clip_grad, 1.0)

+    def test_chat_template_path(self):
+        config = get_template_config()
+        config.model.chat_template_path = "tests/template/custom_chat_template.j2"
+        config.check_and_update()
+        self.assertIsNotNone(config.model.custom_chat_template)
+        self.assertEqual(
+            config.model.custom_chat_template,
+            config.buffer.explorer_input.tasksets[0].format.chat_template,
+        )
+        self.assertEqual(
+            config.model.custom_chat_template, config.explorer.rollout_model.chat_template
+        )
+
     def tearDown(self):
         if os.path.exists(CHECKPOINT_ROOT_DIR):
             shutil.rmtree(CHECKPOINT_ROOT_DIR)

tests/common/vllm_test.py

Lines changed: 1 addition & 1 deletion

@@ -110,9 +110,9 @@ def setUp(self):
         self.config = get_template_config()
         self.config.mode = "explore"
         self.config.model.model_path = get_model_path()
+        self.config.model.custom_chat_template = CHAT_TEMPLATE
         self.config.explorer.rollout_model.engine_num = self.engine_num
         self.config.explorer.rollout_model.tensor_parallel_size = self.tensor_parallel_size
-        self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
         self.config.algorithm.repeat_times = self.repeat_times
         self.config.explorer.rollout_model.enable_history = self.enable_history
         self.config.check_and_update()
tests/template/custom_chat_template.j2

Lines changed: 67 additions & 0 deletions

@@ -0,0 +1,67 @@
+{%- if tools %}
+    {{- '<|im_start|>system\n' }}
+    {%- if messages[0].role == 'system' %}
+        {{- messages[0].content + '\n\n' }}
+    {%- endif %}
+    {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
+    {%- for tool in tools %}
+        {{- "\n" }}
+        {{- tool | tojson }}
+    {%- endfor %}
+    {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
+{%- else %}
+    {%- if messages[0].role == 'system' %}
+        {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
+    {%- endif %}
+{%- endif %}
+{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
+{%- for message in messages[::-1] %}
+    {%- set index = (messages|length - 1) - loop.index0 %}
+    {%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
+        {%- set ns.multi_step_tool = false %}
+        {%- set ns.last_query_index = index %}
+    {%- endif %}
+{%- endfor %}
+{%- for message in messages %}
+    {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
+        {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
+    {%- elif message.role == "assistant" %}
+        {{- '<|im_start|>' + message.role + '\n' + message.content }}
+        {%- if message.tool_calls %}
+            {%- for tool_call in message.tool_calls %}
+                {%- if (loop.first and content) or (not loop.first) %}
+                    {{- '\n' }}
+                {%- endif %}
+                {%- if tool_call.function %}
+                    {%- set tool_call = tool_call.function %}
+                {%- endif %}
+                {{- '<tool_call>\n{"name": "' }}
+                {{- tool_call.name }}
+                {{- '", "arguments": ' }}
+                {%- if tool_call.arguments is string %}
+                    {{- tool_call.arguments }}
+                {%- else %}
+                    {{- tool_call.arguments | tojson }}
+                {%- endif %}
+                {{- '}\n</tool_call>' }}
+            {%- endfor %}
+        {%- endif %}
+        {{- '<|im_end|>\n' }}
+    {%- elif message.role == "tool" %}
+        {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
+            {{- '<|im_start|>user' }}
+        {%- endif %}
+        {{- '\n<tool_response>\n' }}
+        {{- message.content }}
+        {{- '\n</tool_response>' }}
+        {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
+            {{- '<|im_end|>\n' }}
+        {%- endif %}
+    {%- endif %}
+{%- endfor %}
+{%- if add_generation_prompt %}
+    {{- '<|im_start|>assistant\n' }}
+    {%- if enable_thinking is defined and enable_thinking is false %}
+        {{- '<think>\n\n</think>\n\n' }}
+    {%- endif %}
+{%- endif %}
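
This is a Qwen-style tool-calling template used as the fixture for `test_chat_template_path`. A minimal rendering sketch with plain jinja2 (in practice the tokenizer renders it via `apply_chat_template`; the path and messages below are illustrative):

# Minimal rendering sketch using plain jinja2; illustrative messages, no tools,
# so the template takes its plain-chat branch.
from jinja2 import Template

with open("tests/template/custom_chat_template.j2") as f:
    template = Template(f.read())

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is 2 + 2?"},
]
print(template.render(messages=messages, tools=None, add_generation_prompt=True))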

trinity/common/config.py

Lines changed: 14 additions & 0 deletions

@@ -417,6 +417,9 @@ class ModelConfig:
     critic_model_path: str = ""

     custom_chat_template: Optional[str] = None
+    chat_template_path: Optional[
+        str
+    ] = None  # path to the chat template file (e.g., jinja2); overrides `custom_chat_template` if set

     # rollout args
     temperature: float = 1.0
@@ -885,6 +888,7 @@ def _check_explorer_input(self) -> None:
             set_if_none(taskset.rollout_args, "top_k", self.model.top_k)
             set_if_none(taskset.rollout_args, "logprobs", self.model.logprobs)
             set_if_none(taskset.rollout_args, "max_tokens", self.model.max_response_tokens)
+            set_if_none(taskset.format, "chat_template", self.model.custom_chat_template)

         for idx, dataset in enumerate(explorer_input.eval_tasksets):
             if not dataset.path:
@@ -1079,6 +1083,11 @@ def _check_model(self) -> None:
         if not model.critic_model_path:
             model.critic_model_path = model.model_path

+        # load the chat template from file if only a path was given
+        if model.chat_template_path and model.custom_chat_template is None:
+            with open(model.chat_template_path, "r") as f:
+                model.custom_chat_template = f.read()
+
         # check max_model_len, max_prompt_tokens, max_response_tokens

         # if all three are set, check if they are valid
@@ -1207,6 +1216,11 @@ def check_and_update(self) -> Config:  # noqa: C901
         model_args = rollout_args + length_args + rope_args
         for args in ["model_path"] + model_args:
             setattr(self.explorer.rollout_model, args, getattr(self.model, args))
+        if (
+            self.explorer.rollout_model.chat_template is None
+            and self.model.custom_chat_template is not None
+        ):
+            self.explorer.rollout_model.chat_template = self.model.custom_chat_template
         for aux_model in self.explorer.auxiliary_models:
            if not aux_model.model_path:
                raise ValueError("auxiliary model's model_path is required.")
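
Taken together: `_check_model` inlines the file contents into `custom_chat_template`, and `check_and_update` then copies it to the rollout model and, via `_check_explorer_input`, to each taskset. A standalone sketch of the loading step; note that as written the file is read only when no inline template is set, so an inline `custom_chat_template` takes precedence at this layer:

# Standalone sketch of the template-resolution step in _check_model above.
# As written, the file is consulted only when no inline template was given.
from typing import Optional

def resolve_chat_template(
    custom_chat_template: Optional[str], chat_template_path: Optional[str]
) -> Optional[str]:
    if chat_template_path and custom_chat_template is None:
        with open(chat_template_path, "r") as f:
            return f.read()
    return custom_chat_template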

trinity/common/verl_config.py

Lines changed: 3 additions & 0 deletions

@@ -15,6 +15,7 @@
 @dataclass
 class Data:
     train_batch_size: int = 1024  # kept to pass RayPPOTrainer._validate_config
+    trust_remote_code: bool = False


 @dataclass
@@ -34,6 +35,7 @@ class ActorModel:
     custom_chat_template: Optional[str] = None
     enable_activation_offload: bool = False
     use_shm: bool = False
+    trust_remote_code: bool = False  # whether to allow loading models that ship custom (remote) code

     # lora configs
     lora_rank: int = 0  # The rank of the LoRA model, default to 0. If lora_rank > 0, LoRA module is enabled in trainer
@@ -223,6 +225,7 @@ class CriticModel:
     tokenizer_path: str = ""
     override_config: Dict[str, str] = field(default_factory=dict)
     external_lib: Optional[str] = None
+    trust_remote_code: bool = False  # whether to allow loading models that ship custom (remote) code
     enable_gradient_checkpointing: bool = True
     use_remove_padding: bool = True
     fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)

trinity/trainer/verl_trainer.py

Lines changed: 3 additions & 2 deletions

@@ -194,9 +194,10 @@ def __init__(
         local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)

         # instantiate tokenizer
-        tokenizer = hf_tokenizer(local_path)
+        trust_remote_code = config.data.get("trust_remote_code", False)
+        tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
         # processor for multimodal LLM, could be None
-        processor = hf_processor(local_path, use_fast=True)
+        processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)

         # define worker classes
         if config.actor_rollout_ref.actor.strategy in ["fsdp", "fsdp2"]:
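
`hf_tokenizer` and `hf_processor` are verl's thin wrappers around the Hugging Face loaders, so the flag ultimately reaches `from_pretrained`. A hedged sketch of the roughly equivalent plain-transformers calls (the model id is a placeholder):

# Sketch of the roughly equivalent plain-transformers calls; hf_tokenizer and
# hf_processor in verl wrap these loaders. The model id is only a placeholder.
from transformers import AutoProcessor, AutoTokenizer

local_path = "Qwen/Qwen2.5-1.5B-Instruct"  # placeholder
trust_remote_code = False  # set True only for models whose custom code you trust

tokenizer = AutoTokenizer.from_pretrained(local_path, trust_remote_code=trust_remote_code)
# For text-only models there is no multimodal processor; hf_processor returns
# None in that case, whereas plain AutoProcessor may fall back or raise.
processor = AutoProcessor.from_pretrained(
    local_path, trust_remote_code=trust_remote_code, use_fast=True
)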
