diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md
index 9eb9cd90fb..10e44d7ff1 100644
--- a/docs/sphinx_doc/source/tutorial/trinity_configs.md
+++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -157,6 +157,8 @@ Defines the model paths and token limits.
 model:
   model_path: ${oc.env:MODEL_PATH}  # MODEL_PATH is an environment variable set in advance
   critic_model_path: ${model.model_path}  # use the value of model.model_path
+  custom_chat_template: null
+  chat_template_path: null
   max_model_len: 20480
   max_prompt_tokens: 4096
   max_response_tokens: 16384
@@ -165,6 +167,8 @@ model:
 
 - `model_path`: Path to the model being trained.
 - `critic_model_path`: Optional path to a separate critic model. If empty, defaults to `model_path`.
+- `custom_chat_template`: Optional custom chat template provided as a string. If not specified, the system will use the default chat template from the tokenizer.
+- `chat_template_path`: Optional path to a chat template file (e.g., a Jinja2 file). It is read into `custom_chat_template` when that field is not set; if neither is specified, the system will use the default chat template from the tokenizer.
 - `max_model_len`: Maximum number of tokens in a sequence. It is recommended to set this value manually. If not specified, the system will attempt to set it to `max_prompt_tokens` + `max_response_tokens`. However, this requires both values to be already set; otherwise, an error will be raised.
 - `max_response_tokens`: Maximum number of tokens allowed in generated responses. Only for `chat` and `generate` methods in `InferenceModel`.
 - `max_prompt_tokens`: Maximum number of tokens allowed in prompts. Only for `chat` and `generate` methods in `InferenceModel`.
diff --git a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
index c1e1847254..a6404e9b58 100644
--- a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
+++ b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
@@ -157,6 +157,8 @@ monitor:
 model:
   model_path: ${oc.env:MODEL_PATH}  # MODEL_PATH 是预先设置的环境变量
   critic_model_path: ${model.model_path}  # 使用 model.model_path 的值
+  custom_chat_template: null
+  chat_template_path: null
   max_model_len: 20480
   max_prompt_tokens: 4096
   max_response_tokens: 16384
@@ -165,6 +167,8 @@ model:
 
 - `model_path`: 被训练模型的路径。
 - `critic_model_path`: 可选的独立 critic 模型路径。若为空，则默认为 `model_path`。
+- `custom_chat_template`: 可选的自定义 chat template（字符串形式）。若未指定，系统会使用 tokenizer 的默认 chat template。
+- `chat_template_path`: 可选的 chat template 文件路径（例如 Jinja2 文件）。当 `custom_chat_template` 未设置时，系统会读取该文件内容作为 chat template；若两者均未指定，系统会使用 tokenizer 的默认 chat template。
 - `max_model_len`: 表示模型所支持的单个序列最大 token 数。如未指定，系统会尝试将其设为 `max_prompt_tokens` + `max_response_tokens`。但前提是这两个值都必须已设置，否则将引发错误。
 - `max_prompt_tokens`: 输入 prompt 中允许的最大 token 数。仅对 `InferenceModel` 中的 `chat` 和 `generate` 方法生效。
 - `max_response_tokens`: 模型生成的回复中允许的最大 token 数。仅对 `InferenceModel` 中的 `chat` 和 `generate` 方法生效。
diff --git a/examples/dapo_math/README.md b/examples/dapo_math/README.md
index f82357e618..0192a4e5e0 100644
--- a/examples/dapo_math/README.md
+++ b/examples/dapo_math/README.md
@@ -1,5 +1,5 @@
 # DAPO on DAPO-MATH-17k dataset [WIP]
 
-This example shows the usage of DAPO on the [DAPO-MATH-17k](https://huggingface.co/datasets/open-r1/DAPO-Math-17k-Processed) dataset.
+Note that this example currently only shows the usage of GRPO on the [DAPO-MATH-17k](https://huggingface.co/datasets/open-r1/DAPO-Math-17k-Processed) dataset. We plan to implement the DAPO algorithm soon.
 
 The config file is located in [`dapo.yaml`](dapo.yaml).
diff --git a/examples/dapo_math/dapo.yaml b/examples/dapo_math/dapo.yaml
index 6418c22d56..2b75cf507c 100644
--- a/examples/dapo_math/dapo.yaml
+++ b/examples/dapo_math/dapo.yaml
@@ -2,7 +2,7 @@ project: Trinity-RFT-example
 name: dapo
 checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
 model:
-  model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
+  model_path: ${oc.env:TRINITY_MODEL_PATH}  # A larger model is recommended for this dataset
   max_response_tokens: 20480
   max_model_len: 21504
 algorithm:
diff --git a/tests/common/config_test.py b/tests/common/config_test.py
index 37b9805ba3..d2ca350aef 100644
--- a/tests/common/config_test.py
+++ b/tests/common/config_test.py
@@ -154,6 +154,19 @@ def test_optimizer_config_propagation(self):
         self.assertEqual(config.trainer.trainer_config.critic.optim.weight_decay, 0.01)
         self.assertEqual(config.trainer.trainer_config.critic.optim.lr_decay_style, "constant")
 
+    def test_chat_template_path(self):
+        config = get_template_config()
+        config.model.chat_template_path = "tests/template/custom_chat_template.j2"
+        config.check_and_update()
+        self.assertIsNotNone(config.model.custom_chat_template)
+        self.assertEqual(
+            config.model.custom_chat_template,
+            config.buffer.explorer_input.tasksets[0].format.chat_template,
+        )
+        self.assertEqual(
+            config.model.custom_chat_template, config.explorer.rollout_model.chat_template
+        )
+
     def tearDown(self):
         if os.path.exists(CHECKPOINT_ROOT_DIR):
             shutil.rmtree(CHECKPOINT_ROOT_DIR)
diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py
index f60d39ccf9..03c09d9821 100644
--- a/tests/common/vllm_test.py
+++ b/tests/common/vllm_test.py
@@ -109,9 +109,9 @@ def setUp(self):
         self.config = get_template_config()
         self.config.mode = "explore"
         self.config.model.model_path = get_model_path()
+        self.config.model.custom_chat_template = CHAT_TEMPLATE
         self.config.explorer.rollout_model.engine_num = self.engine_num
         self.config.explorer.rollout_model.tensor_parallel_size = self.tensor_parallel_size
-        self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
         self.config.algorithm.repeat_times = self.repeat_times
         self.config.explorer.rollout_model.enable_history = self.enable_history
         self.config.check_and_update()
diff --git a/tests/template/custom_chat_template.j2 b/tests/template/custom_chat_template.j2
new file mode 100644
index 0000000000..1a0adedd94
--- /dev/null
+++ b/tests/template/custom_chat_template.j2
@@ -0,0 +1,67 @@
+{%- if tools %}
+    {{- '<|im_start|>system\n' }}
+    {%- if messages[0].role == 'system' %}
+        {{- messages[0].content + '\n\n' }}
+    {%- endif %}
+    {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
+    {%- for tool in tools %}
+        {{- "\n" }}
+        {{- tool | tojson }}
+    {%- endfor %}
+    {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
+{%- else %}
+    {%- if messages[0].role == 'system' %}
+        {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
+    {%- endif %}
+{%- endif %}
+{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
+{%- for message in messages[::-1] %}
+    {%- set index = (messages|length - 1) - loop.index0 %}
+    {%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
+        {%- set ns.multi_step_tool = false %}
+        {%- set ns.last_query_index = index %}
+    {%- endif %}
+{%- endfor %}
+{%- for message in messages %}
+    {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
+        {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
+    {%- elif message.role == "assistant" %}
+        {{- '<|im_start|>' + message.role + '\n' + message.content }}
+        {%- if message.tool_calls %}
+            {%- for tool_call in message.tool_calls %}
+                {%- if (loop.first and content) or (not loop.first) %}
+                    {{- '\n' }}
+                {%- endif %}
+                {%- if tool_call.function %}
+                    {%- set tool_call = tool_call.function %}
+                {%- endif %}
+                {{- '<tool_call>\n{"name": "' }}
+                {{- tool_call.name }}
+                {{- '", "arguments": ' }}
+                {%- if tool_call.arguments is string %}
+                    {{- tool_call.arguments }}
+                {%- else %}
+                    {{- tool_call.arguments | tojson }}
+                {%- endif %}
+                {{- '}\n</tool_call>' }}
+            {%- endfor %}
+        {%- endif %}
+        {{- '<|im_end|>\n' }}
+    {%- elif message.role == "tool" %}
+        {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
+            {{- '<|im_start|>user' }}
+        {%- endif %}
+        {{- '\n<tool_response>\n' }}
+        {{- message.content }}
+        {{- '\n</tool_response>' }}
+        {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
+            {{- '<|im_end|>\n' }}
+        {%- endif %}
+    {%- endif %}
+{%- endfor %}
+{%- if add_generation_prompt %}
+    {{- '<|im_start|>assistant\n' }}
+    {%- if enable_thinking is defined and enable_thinking is false %}
+        {{- '<think>\n\n</think>\n\n' }}
+    {%- endif %}
+{%- endif %}
diff --git a/trinity/common/config.py b/trinity/common/config.py
index c722959b96..5ab6631432 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -412,6 +412,9 @@ class ModelConfig:
     critic_model_path: str = ""
     custom_chat_template: Optional[str] = None
+    chat_template_path: Optional[
+        str
+    ] = None  # path to a chat template file (e.g., a Jinja2 file); read into `custom_chat_template` when that field is not set
 
     # rollout args
     temperature: float = 1.0
@@ -872,6 +875,7 @@ def _check_explorer_input(self) -> None:
             set_if_none(taskset.rollout_args, "top_k", self.model.top_k)
             set_if_none(taskset.rollout_args, "logprobs", self.model.logprobs)
             set_if_none(taskset.rollout_args, "max_tokens", self.model.max_response_tokens)
+            set_if_none(taskset.format, "chat_template", self.model.custom_chat_template)
 
         for idx, dataset in enumerate(explorer_input.eval_tasksets):
             if not dataset.path:
@@ -1066,6 +1070,11 @@ def _check_model(self) -> None:
         if not model.critic_model_path:
             model.critic_model_path = model.model_path
 
+        # check template
+        if model.chat_template_path and model.custom_chat_template is None:
+            with open(model.chat_template_path, "r") as f:
+                model.custom_chat_template = f.read()
+
         # check max_model_len, max_prompt_tokens, max_response_tokens
         # if all three are set, check if they are valid
@@ -1192,6 +1201,11 @@ def check_and_update(self) -> Config:  # noqa: C901
         ]
         for args in ["model_path"] + rollout_args + length_args:
             setattr(self.explorer.rollout_model, args, getattr(self.model, args))
+        if (
+            self.explorer.rollout_model.chat_template is None
+            and self.model.custom_chat_template is not None
+        ):
+            self.explorer.rollout_model.chat_template = self.model.custom_chat_template
         for aux_model in self.explorer.auxiliary_models:
             if not aux_model.model_path:
                 raise ValueError("auxiliary model's model_path is required.")
diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py
index 49a241393a..adff85cb7d 100644
--- a/trinity/common/verl_config.py
+++ b/trinity/common/verl_config.py
@@ -15,6 +15,7 @@
 @dataclass
 class Data:
     train_batch_size: int = 1024  # kept to pass RayPPOTrainer._validate_config
+    trust_remote_code: bool = False
 
 
 @dataclass
@@ -34,6 +35,7 @@ class ActorModel:
     custom_chat_template: Optional[str] = None
     enable_activation_offload: bool = False
     use_shm: bool = False
+    trust_remote_code: bool = False  # Whether to trust remote code when loading the model
 
     # lora configs
     lora_rank: int = 0  # The rank of the LoRA model, default to 0. If lora_rank > 0, LoRA module is enabled in trainer
@@ -219,6 +221,7 @@ class CriticModel:
     tokenizer_path: str = ""
     override_config: Dict[str, str] = field(default_factory=dict)
     external_lib: Optional[str] = None
+    trust_remote_code: bool = False  # Whether to trust remote code when loading the model
     enable_gradient_checkpointing: bool = True
     use_remove_padding: bool = True
     fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py
index bc879e6afc..a2c11ff49f 100644
--- a/trinity/trainer/verl_trainer.py
+++ b/trinity/trainer/verl_trainer.py
@@ -194,9 +194,10 @@ def __init__(
         local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)
 
         # instantiate tokenizer
-        tokenizer = hf_tokenizer(local_path)
+        trust_remote_code = config.data.get("trust_remote_code", False)
+        tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
         # processor for multimodal LLM, could be None
-        processor = hf_processor(local_path, use_fast=True)
+        processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)
 
         # define worker classes
         if config.actor_rollout_ref.actor.strategy in ["fsdp", "fsdp2"]:
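
For reference, a minimal sketch of a user config that exercises the options introduced above. The file name `my_chat_template.j2` is hypothetical; per `_check_model`, the file is read into `custom_chat_template` only when that field is unset, and `check_and_update` then propagates the template to the rollout model and the taskset format.

```yaml
# Illustrative excerpt only; paths and token limits are placeholders.
model:
  model_path: ${oc.env:MODEL_PATH}
  chat_template_path: my_chat_template.j2  # loaded into model.custom_chat_template during check_and_update()
  max_prompt_tokens: 4096
  max_response_tokens: 16384
```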