4 changes: 4 additions & 0 deletions docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -157,6 +157,8 @@ Defines the model paths and token limits.
model:
model_path: ${oc.env:MODEL_PATH} # MODEL_PATH is an environment variable set in advance
critic_model_path: ${model.model_path} # use the value of model.model_path
custom_chat_template: null
chat_template_path: null
max_model_len: 20480
max_prompt_tokens: 4096
max_response_tokens: 16384
@@ -165,6 +167,8 @@ model:

- `model_path`: Path to the model being trained.
- `critic_model_path`: Optional path to a separate critic model. If empty, defaults to `model_path`.
- `custom_chat_template`: Optional custom chat template provided as a string. If not specified, the system uses the tokenizer's default chat template.
- `chat_template_path`: Optional path to a chat template file in Jinja2 format; if `custom_chat_template` is not set, the file's contents are loaded into it. If neither is specified, the system uses the tokenizer's default chat template (see the sketch after this list).
- `max_model_len`: Maximum number of tokens in a sequence. It is recommended to set this value manually. If not specified, the system will attempt to set it to `max_prompt_tokens` + `max_response_tokens`. However, this requires both values to be already set; otherwise, an error will be raised.
- `max_response_tokens`: Maximum number of tokens allowed in generated responses. Only for `chat` and `generate` methods in `InferenceModel`.
- `max_prompt_tokens`: Maximum number of tokens allowed in prompts. Only for `chat` and `generate` methods in `InferenceModel`.
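The end-to-end behavior is easiest to see in code. The sketch below mirrors the `test_chat_template_path` test added in `tests/common/config_test.py`: after `check_and_update()`, the file at `chat_template_path` is loaded into `custom_chat_template` and propagated to the rollout model and taskset formats. How a fully populated `Config` is obtained is omitted here (the test uses a `get_template_config()` helper).

```python
# Minimal sketch, mirroring tests/common/config_test.py::test_chat_template_path.
from trinity.common.config import Config


def apply_chat_template_file(config: Config) -> Config:
    # Point the model config at a Jinja2 chat template file.
    config.model.chat_template_path = "tests/template/custom_chat_template.j2"
    config.check_and_update()
    # The file contents are now available as an inline template ...
    assert config.model.custom_chat_template is not None
    # ... and propagated to the rollout model and every taskset's format.
    assert config.explorer.rollout_model.chat_template == config.model.custom_chat_template
    assert (
        config.buffer.explorer_input.tasksets[0].format.chat_template
        == config.model.custom_chat_template
    )
    return config
```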
4 changes: 4 additions & 0 deletions docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
@@ -157,6 +157,8 @@ monitor:
model:
model_path: ${oc.env:MODEL_PATH} # MODEL_PATH is an environment variable set in advance
critic_model_path: ${model.model_path} # use the value of model.model_path
custom_chat_template: null
chat_template_path: null
max_model_len: 20480
max_prompt_tokens: 4096
max_response_tokens: 16384
@@ -165,6 +167,8 @@ model:

- `model_path`: Path to the model being trained.
- `critic_model_path`: Optional path to a separate critic model. If empty, it defaults to `model_path`.
- `custom_chat_template`: Optional custom chat template provided as a string. If not specified, the system uses the tokenizer's default chat template.
- `chat_template_path`: Optional path to a chat template file, typically in Jinja2 format; if `custom_chat_template` is not set, the file's contents are loaded into it. If neither is specified, the system uses the tokenizer's default chat template (see the sketch after this list).
- `max_model_len`: Maximum number of tokens in a single sequence supported by the model. If not specified, the system will try to set it to `max_prompt_tokens` + `max_response_tokens`; this requires both values to already be set, otherwise an error is raised.
- `max_prompt_tokens`: Maximum number of tokens allowed in the input prompt. Only applies to the `chat` and `generate` methods in `InferenceModel`.
- `max_response_tokens`: Maximum number of tokens allowed in the generated response. Only applies to the `chat` and `generate` methods in `InferenceModel`.
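The fallback order follows the check added to `trinity/common/config.py` in this diff: the file at `chat_template_path` is read only when `custom_chat_template` has not been set inline. A minimal sketch (the model path is illustrative):

```python
# Minimal sketch of the fallback implemented in trinity/common/config.py:
# the template file is read only when no inline template was provided.
from trinity.common.config import ModelConfig

model = ModelConfig(model_path="Qwen/Qwen2.5-1.5B-Instruct")  # illustrative path
model.chat_template_path = "tests/template/custom_chat_template.j2"

if model.chat_template_path and model.custom_chat_template is None:
    with open(model.chat_template_path, "r", encoding="utf-8") as f:
        model.custom_chat_template = f.read()  # used instead of the tokenizer default
```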
2 changes: 1 addition & 1 deletion examples/dapo_math/README.md
@@ -1,5 +1,5 @@
# DAPO on DAPO-MATH-17k dataset [WIP]

This example shows the usage of DAPO on the [DAPO-MATH-17k](https://huggingface.co/datasets/open-r1/DAPO-Math-17k-Processed) dataset.
Note that this example currently only demonstrates GRPO on the [DAPO-MATH-17k](https://huggingface.co/datasets/open-r1/DAPO-Math-17k-Processed) dataset; we plan to implement the DAPO algorithm soon.

The config file is located in [`dapo.yaml`](dapo.yaml).
2 changes: 1 addition & 1 deletion examples/dapo_math/dapo.yaml
@@ -2,7 +2,7 @@ project: Trinity-RFT-example
name: dapo
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
model:
model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
model_path: ${oc.env:TRINITY_MODEL_PATH} # a larger model is recommended for this dataset
max_response_tokens: 20480
max_model_len: 21504
algorithm:
13 changes: 13 additions & 0 deletions tests/common/config_test.py
@@ -154,6 +154,19 @@ def test_optimizer_config_propagation(self):
self.assertEqual(config.trainer.trainer_config.critic.optim.weight_decay, 0.01)
self.assertEqual(config.trainer.trainer_config.critic.optim.lr_decay_style, "constant")

def test_chat_template_path(self):
config = get_template_config()
config.model.chat_template_path = "tests/template/custom_chat_template.j2"
config.check_and_update()
self.assertIsNotNone(config.model.custom_chat_template)
self.assertEqual(
config.model.custom_chat_template,
config.buffer.explorer_input.tasksets[0].format.chat_template,
)
self.assertEqual(
config.model.custom_chat_template, config.explorer.rollout_model.chat_template
)

def tearDown(self):
if os.path.exists(CHECKPOINT_ROOT_DIR):
shutil.rmtree(CHECKPOINT_ROOT_DIR)
2 changes: 1 addition & 1 deletion tests/common/vllm_test.py
@@ -109,9 +109,9 @@ def setUp(self):
self.config = get_template_config()
self.config.mode = "explore"
self.config.model.model_path = get_model_path()
self.config.model.custom_chat_template = CHAT_TEMPLATE
self.config.explorer.rollout_model.engine_num = self.engine_num
self.config.explorer.rollout_model.tensor_parallel_size = self.tensor_parallel_size
self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
self.config.algorithm.repeat_times = self.repeat_times
self.config.explorer.rollout_model.enable_history = self.enable_history
self.config.check_and_update()
67 changes: 67 additions & 0 deletions tests/template/custom_chat_template.j2
@@ -0,0 +1,67 @@
{%- if tools %}
{{- '<|im_start|>system\n' }}
{%- if messages[0].role == 'system' %}
{{- messages[0].content + '\n\n' }}
{%- endif %}
{{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
{%- for tool in tools %}
{{- "\n" }}
{{- tool | tojson }}
{%- endfor %}
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
{%- else %}
{%- if messages[0].role == 'system' %}
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
{%- for message in messages[::-1] %}
{%- set index = (messages|length - 1) - loop.index0 %}
{%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
{%- set ns.multi_step_tool = false %}
{%- set ns.last_query_index = index %}
{%- endif %}
{%- endfor %}
{%- for message in messages %}
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
{{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
{%- elif message.role == "assistant" %}
{{- '<|im_start|>' + message.role + '\n' + message.content }}
{%- if message.tool_calls %}
{%- for tool_call in message.tool_calls %}
{%- if (loop.first and content) or (not loop.first) %}
{{- '\n' }}
{%- endif %}
{%- if tool_call.function %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '<tool_call>\n{"name": "' }}
{{- tool_call.name }}
{{- '", "arguments": ' }}
{%- if tool_call.arguments is string %}
{{- tool_call.arguments }}
{%- else %}
{{- tool_call.arguments | tojson }}
{%- endif %}
{{- '}\n</tool_call>' }}
{%- endfor %}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- elif message.role == "tool" %}
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
{{- '<|im_start|>user' }}
{%- endif %}
{{- '\n<tool_response>\n' }}
{{- message.content }}
{{- '\n</tool_response>' }}
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
{{- '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- if enable_thinking is defined and enable_thinking is false %}
{{- '<think>\n\n</think>\n\n' }}
{%- endif %}
{%- endif %}
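For reference, this template can be rendered outside of Trinity with a Hugging Face tokenizer; the sketch below is illustrative only (the model id is taken from `examples/dapo_math/dapo.yaml`, and any chat-capable tokenizer would work).

```python
# Illustrative only: render the custom template instead of the tokenizer's default.
from transformers import AutoTokenizer

with open("tests/template/custom_chat_template.j2", "r", encoding="utf-8") as f:
    chat_template = f.read()

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")  # example model id
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is 2 + 2?"},
]
prompt = tokenizer.apply_chat_template(
    messages,
    chat_template=chat_template,  # use the file contents, not the built-in template
    tokenize=False,
    add_generation_prompt=True,
)
print(prompt)
```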
14 changes: 14 additions & 0 deletions trinity/common/config.py
@@ -412,6 +412,9 @@ class ModelConfig:
critic_model_path: str = ""

custom_chat_template: Optional[str] = None
chat_template_path: Optional[
str
] = None # path to a chat template file (e.g., Jinja2); loaded into `custom_chat_template` when that field is not set

# rollout args
temperature: float = 1.0
@@ -872,6 +875,7 @@ def _check_explorer_input(self) -> None:
set_if_none(taskset.rollout_args, "top_k", self.model.top_k)
set_if_none(taskset.rollout_args, "logprobs", self.model.logprobs)
set_if_none(taskset.rollout_args, "max_tokens", self.model.max_response_tokens)
set_if_none(taskset.format, "chat_template", self.model.custom_chat_template)

for idx, dataset in enumerate(explorer_input.eval_tasksets):
if not dataset.path:
@@ -1066,6 +1070,11 @@ def _check_model(self) -> None:
if not model.critic_model_path:
model.critic_model_path = model.model_path

# chat template: read from file only when no inline template was provided
if model.chat_template_path and model.custom_chat_template is None:
with open(model.chat_template_path, "r") as f:
model.custom_chat_template = f.read()

# check max_model_len, max_prompt_tokens, max_response_tokens

# if all three are set, check if they are valid
@@ -1192,6 +1201,11 @@ def check_and_update(self) -> Config: # noqa: C901
]
for args in ["model_path"] + rollout_args + length_args:
setattr(self.explorer.rollout_model, args, getattr(self.model, args))
if (
self.explorer.rollout_model.chat_template is None
and self.model.custom_chat_template is not None
):
self.explorer.rollout_model.chat_template = self.model.custom_chat_template
for aux_model in self.explorer.auxiliary_models:
if not aux_model.model_path:
raise ValueError("auxiliary model's model_path is required.")
3 changes: 3 additions & 0 deletions trinity/common/verl_config.py
@@ -15,6 +15,7 @@
@dataclass
class Data:
train_batch_size: int = 1024 # kept to pass RayPPOTrainer._validate_config
trust_remote_code: bool = False # whether to trust custom code from the model repo when loading the tokenizer/processor


@dataclass
@@ -34,6 +35,7 @@ class ActorModel:
custom_chat_template: Optional[str] = None
enable_activation_offload: bool = False
use_shm: bool = False
trust_remote_code: bool = False # whether to trust custom code shipped with the model (Hugging Face trust_remote_code)

# lora configs
lora_rank: int = 0 # The rank of the LoRA model, default to 0. If lora_rank > 0, LoRA module is enabled in trainer
@@ -219,6 +221,7 @@ class CriticModel:
tokenizer_path: str = ""
override_config: Dict[str, str] = field(default_factory=dict)
external_lib: Optional[str] = None
trust_remote_code: bool = False # whether to trust custom code shipped with the model (Hugging Face trust_remote_code)
enable_gradient_checkpointing: bool = True
use_remove_padding: bool = True
fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
5 changes: 3 additions & 2 deletions trinity/trainer/verl_trainer.py
@@ -194,9 +194,10 @@ def __init__(
local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)

# instantiate tokenizer
tokenizer = hf_tokenizer(local_path)
trust_remote_code = config.data.get("trust_remote_code", False)
tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
# processor for multimodal LLM, could be None
processor = hf_processor(local_path, use_fast=True)
processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)

# define worker classes
if config.actor_rollout_ref.actor.strategy in ["fsdp", "fsdp2"]: