diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md
index 9eb9cd90fb..10e44d7ff1 100644
--- a/docs/sphinx_doc/source/tutorial/trinity_configs.md
+++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -157,6 +157,8 @@ Defines the model paths and token limits.
model:
model_path: ${oc.env:MODEL_PATH} # MODEL_PATH is an environment variable set in advance
critic_model_path: ${model.model_path} # use the value of model.model_path
+  custom_chat_template: null
+  chat_template_path: null
max_model_len: 20480
max_prompt_tokens: 4096
max_response_tokens: 16384
@@ -165,6 +167,8 @@ model:
- `model_path`: Path to the model being trained.
- `critic_model_path`: Optional path to a separate critic model. If empty, defaults to `model_path`.
+- `custom_chat_template`: Optional custom chat template provided as a string. If not specified, the system falls back to `chat_template_path`, and otherwise to the tokenizer's default chat template.
+- `chat_template_path`: Optional path to a chat template file (e.g., a Jinja2 template). It is read only when `custom_chat_template` is not set, and its contents are then loaded into `custom_chat_template`. If neither option is specified, the tokenizer's default chat template is used. See the example after this list.
- `max_model_len`: Maximum number of tokens in a sequence. It is recommended to set this value manually. If not specified, the system will attempt to set it to `max_prompt_tokens` + `max_response_tokens`. However, this requires both values to be already set; otherwise, an error will be raised.
- `max_response_tokens`: Maximum number of tokens allowed in generated responses. Only for `chat` and `generate` methods in `InferenceModel`.
- `max_prompt_tokens`: Maximum number of tokens allowed in prompts. Only for `chat` and `generate` methods in `InferenceModel`.
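+
+Below is a minimal sketch of pointing the model at a template file (the path is illustrative). During `check_and_update`, the file contents are loaded into `custom_chat_template` and then propagated to the explorer's rollout model and to any task set format that does not specify its own `chat_template`:
+
+```yaml
+model:
+  model_path: ${oc.env:MODEL_PATH}
+  chat_template_path: /path/to/my_chat_template.j2  # read only when custom_chat_template is not set
+```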
diff --git a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
index c1e1847254..a6404e9b58 100644
--- a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
+++ b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
@@ -157,6 +157,8 @@ monitor:
model:
model_path: ${oc.env:MODEL_PATH} # MODEL_PATH 是预先设置的环境变量
critic_model_path: ${model.model_path} # 使用 model.model_path 的值
+  custom_chat_template: null
+  chat_template_path: null
max_model_len: 20480
max_prompt_tokens: 4096
max_response_tokens: 16384
@@ -165,6 +167,8 @@ model:
- `model_path`: 被训练模型的路径。
- `critic_model_path`: 可选的独立 critic 模型路径。若为空,则默认为 `model_path`。
+- `custom_chat_template`: 可选的字符串形式的自定义 chat template。若未指定,系统会回退使用 `chat_template_path`;若两者均未指定,则使用 tokenizer 的默认 chat template。
+- `chat_template_path`: 可选的 chat template 文件路径(例如 Jinja2 模板文件)。仅在未设置 `custom_chat_template` 时才会读取该文件,并将其内容加载到 `custom_chat_template` 中;若两者均未指定,则使用 tokenizer 的默认 chat template。
- `max_model_len`: 表示模型所支持的单个序列最大 token 数。如未指定,系统会尝试将其设为 `max_prompt_tokens` + `max_response_tokens`。但前提是这两个值都必须已设置,否则将引发错误。
- `max_prompt_tokens`: 输入 prompt 中允许的最大 token 数。仅对 `InferenceModel` 中的 `chat` 和 `generate` 方法生效。
- `max_response_tokens`: 模型生成的回复中允许的最大 token 数。仅对 `InferenceModel` 中的 `chat` 和 `generate` 方法生效。
diff --git a/examples/dapo_math/README.md b/examples/dapo_math/README.md
index f82357e618..0192a4e5e0 100644
--- a/examples/dapo_math/README.md
+++ b/examples/dapo_math/README.md
@@ -1,5 +1,5 @@
# DAPO on DAPO-MATH-17k dataset [WIP]
-This example shows the usage of DAPO on the [DAPO-MATH-17k](https://huggingface.co/datasets/open-r1/DAPO-Math-17k-Processed) dataset.
+Note that this example currently only shows the usage of GRPO on the [DAPO-MATH-17k](https://huggingface.co/datasets/open-r1/DAPO-Math-17k-Processed) dataset. We plan to implement the DAPO algorithm soon.
The config file is located in [`dapo.yaml`](dapo.yaml).
diff --git a/examples/dapo_math/dapo.yaml b/examples/dapo_math/dapo.yaml
index 6418c22d56..2b75cf507c 100644
--- a/examples/dapo_math/dapo.yaml
+++ b/examples/dapo_math/dapo.yaml
@@ -2,7 +2,7 @@ project: Trinity-RFT-example
name: dapo
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
model:
- model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
+  model_path: ${oc.env:TRINITY_MODEL_PATH} # A larger model is recommended for this dataset
max_response_tokens: 20480
max_model_len: 21504
algorithm:
diff --git a/tests/common/config_test.py b/tests/common/config_test.py
index 37b9805ba3..d2ca350aef 100644
--- a/tests/common/config_test.py
+++ b/tests/common/config_test.py
@@ -154,6 +154,19 @@ def test_optimizer_config_propagation(self):
self.assertEqual(config.trainer.trainer_config.critic.optim.weight_decay, 0.01)
self.assertEqual(config.trainer.trainer_config.critic.optim.lr_decay_style, "constant")
+ def test_chat_template_path(self):
+ config = get_template_config()
+ config.model.chat_template_path = "tests/template/custom_chat_template.j2"
+ config.check_and_update()
+ self.assertIsNotNone(config.model.custom_chat_template)
+ self.assertEqual(
+ config.model.custom_chat_template,
+ config.buffer.explorer_input.tasksets[0].format.chat_template,
+ )
+ self.assertEqual(
+ config.model.custom_chat_template, config.explorer.rollout_model.chat_template
+ )
+
def tearDown(self):
if os.path.exists(CHECKPOINT_ROOT_DIR):
shutil.rmtree(CHECKPOINT_ROOT_DIR)
diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py
index f60d39ccf9..03c09d9821 100644
--- a/tests/common/vllm_test.py
+++ b/tests/common/vllm_test.py
@@ -109,9 +109,9 @@ def setUp(self):
self.config = get_template_config()
self.config.mode = "explore"
self.config.model.model_path = get_model_path()
+ self.config.model.custom_chat_template = CHAT_TEMPLATE
self.config.explorer.rollout_model.engine_num = self.engine_num
self.config.explorer.rollout_model.tensor_parallel_size = self.tensor_parallel_size
- self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
self.config.algorithm.repeat_times = self.repeat_times
self.config.explorer.rollout_model.enable_history = self.enable_history
self.config.check_and_update()
diff --git a/tests/template/custom_chat_template.j2 b/tests/template/custom_chat_template.j2
new file mode 100644
index 0000000000..1a0adedd94
--- /dev/null
+++ b/tests/template/custom_chat_template.j2
@@ -0,0 +1,67 @@
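+{#- Qwen-style chat template used in tests: renders system/user/assistant turns, tool definitions and calls, tool responses, and an empty think block when thinking is disabled. -#}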
+{%- if tools %}
+ {{- '<|im_start|>system\n' }}
+ {%- if messages[0].role == 'system' %}
+ {{- messages[0].content + '\n\n' }}
+ {%- endif %}
+    {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
+ {%- for tool in tools %}
+ {{- "\n" }}
+ {{- tool | tojson }}
+ {%- endfor %}
+    {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
+{%- else %}
+ {%- if messages[0].role == 'system' %}
+ {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
+ {%- endif %}
+{%- endif %}
+{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
+{%- for message in messages[::-1] %}
+ {%- set index = (messages|length - 1) - loop.index0 %}
+    {%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
+ {%- set ns.multi_step_tool = false %}
+ {%- set ns.last_query_index = index %}
+ {%- endif %}
+{%- endfor %}
+{%- for message in messages %}
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
+ {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
+ {%- elif message.role == "assistant" %}
+ {{- '<|im_start|>' + message.role + '\n' + message.content }}
+ {%- if message.tool_calls %}
+ {%- for tool_call in message.tool_calls %}
+                {%- if (loop.first and message.content) or (not loop.first) %}
+ {{- '\n' }}
+ {%- endif %}
+ {%- if tool_call.function %}
+ {%- set tool_call = tool_call.function %}
+ {%- endif %}
+                {{- '\n<tool_call>\n{"name": "' }}
+ {{- tool_call.name }}
+ {{- '", "arguments": ' }}
+ {%- if tool_call.arguments is string %}
+ {{- tool_call.arguments }}
+ {%- else %}
+ {{- tool_call.arguments | tojson }}
+ {%- endif %}
+                {{- '}\n</tool_call>' }}
+ {%- endfor %}
+ {%- endif %}
+ {{- '<|im_end|>\n' }}
+ {%- elif message.role == "tool" %}
+ {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
+ {{- '<|im_start|>user' }}
+ {%- endif %}
+        {{- '\n<tool_response>\n' }}
+ {{- message.content }}
+        {{- '\n</tool_response>' }}
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
+ {{- '<|im_end|>\n' }}
+ {%- endif %}
+ {%- endif %}
+{%- endfor %}
+{%- if add_generation_prompt %}
+ {{- '<|im_start|>assistant\n' }}
+ {%- if enable_thinking is defined and enable_thinking is false %}
+        {{- '<think>\n\n</think>\n\n' }}
+ {%- endif %}
+{%- endif %}
diff --git a/trinity/common/config.py b/trinity/common/config.py
index c722959b96..5ab6631432 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -412,6 +412,9 @@ class ModelConfig:
critic_model_path: str = ""
custom_chat_template: Optional[str] = None
+    # Path to a chat template file (e.g., a Jinja2 template); it is loaded into
+    # `custom_chat_template` only when the latter is not set.
+    chat_template_path: Optional[str] = None
# rollout args
temperature: float = 1.0
@@ -872,6 +875,7 @@ def _check_explorer_input(self) -> None:
set_if_none(taskset.rollout_args, "top_k", self.model.top_k)
set_if_none(taskset.rollout_args, "logprobs", self.model.logprobs)
set_if_none(taskset.rollout_args, "max_tokens", self.model.max_response_tokens)
+ set_if_none(taskset.format, "chat_template", self.model.custom_chat_template)
for idx, dataset in enumerate(explorer_input.eval_tasksets):
if not dataset.path:
@@ -1066,6 +1070,11 @@ def _check_model(self) -> None:
if not model.critic_model_path:
model.critic_model_path = model.model_path
+    # Load the chat template from the given file when no inline template is provided
+    if model.chat_template_path and model.custom_chat_template is None:
+        with open(model.chat_template_path, "r", encoding="utf-8") as f:
+            model.custom_chat_template = f.read()
+
# check max_model_len, max_prompt_tokens, max_response_tokens
# if all three are set, check if they are valid
@@ -1192,6 +1201,11 @@ def check_and_update(self) -> Config: # noqa: C901
]
for args in ["model_path"] + rollout_args + length_args:
setattr(self.explorer.rollout_model, args, getattr(self.model, args))
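+        # Propagate the custom chat template to the rollout model when it does not specify one explicitly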
+ if (
+ self.explorer.rollout_model.chat_template is None
+ and self.model.custom_chat_template is not None
+ ):
+ self.explorer.rollout_model.chat_template = self.model.custom_chat_template
for aux_model in self.explorer.auxiliary_models:
if not aux_model.model_path:
raise ValueError("auxiliary model's model_path is required.")
diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py
index 49a241393a..adff85cb7d 100644
--- a/trinity/common/verl_config.py
+++ b/trinity/common/verl_config.py
@@ -15,6 +15,7 @@
@dataclass
class Data:
train_batch_size: int = 1024 # kept to pass RayPPOTrainer._validate_config
+    trust_remote_code: bool = False  # Whether to trust remote code when loading the tokenizer/processor
@dataclass
@@ -34,6 +35,7 @@ class ActorModel:
custom_chat_template: Optional[str] = None
enable_activation_offload: bool = False
use_shm: bool = False
+    trust_remote_code: bool = False  # Whether to trust remote code when loading the model
# lora configs
lora_rank: int = 0 # The rank of the LoRA model, default to 0. If lora_rank > 0, LoRA module is enabled in trainer
@@ -219,6 +221,7 @@ class CriticModel:
tokenizer_path: str = ""
override_config: Dict[str, str] = field(default_factory=dict)
external_lib: Optional[str] = None
+    trust_remote_code: bool = False  # Whether to trust remote code when loading the model
enable_gradient_checkpointing: bool = True
use_remove_padding: bool = True
fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py
index bc879e6afc..a2c11ff49f 100644
--- a/trinity/trainer/verl_trainer.py
+++ b/trinity/trainer/verl_trainer.py
@@ -194,9 +194,10 @@ def __init__(
local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)
# instantiate tokenizer
- tokenizer = hf_tokenizer(local_path)
+ trust_remote_code = config.data.get("trust_remote_code", False)
+ tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
# processor for multimodal LLM, could be None
- processor = hf_processor(local_path, use_fast=True)
+ processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)
# define worker classes
if config.actor_rollout_ref.actor.strategy in ["fsdp", "fsdp2"]: