diff --git a/docs/sphinx_doc/source/tutorial/align_with_verl.md b/docs/sphinx_doc/source/tutorial/align_with_verl.md index f3714e2154..6290b59be1 100644 --- a/docs/sphinx_doc/source/tutorial/align_with_verl.md +++ b/docs/sphinx_doc/source/tutorial/align_with_verl.md @@ -60,7 +60,7 @@ To match the default training setup of veRL, we set `synchronizer.sync_style=fix | `data.max_response_length` | `model.max_response_tokens` | - | | `data.filter_overlong_prompts` | `model.enable_prompt_truncation` | Explained later | | `data.truncation` | - | Equivalent to `right` | -| `data.shuffle` | `buffer.explorer_input.taskset.task_selector.selector_type:random` | Taskset-specific | +| `data.shuffle` | `buffer.explorer_input.taskset.task_selector.selector_type:shuffle` | Taskset-specific | 💡 Detailed explanation: diff --git a/docs/sphinx_doc/source_zh/tutorial/align_with_verl.md b/docs/sphinx_doc/source_zh/tutorial/align_with_verl.md index a884dbcf82..b8a65f0f27 100644 --- a/docs/sphinx_doc/source_zh/tutorial/align_with_verl.md +++ b/docs/sphinx_doc/source_zh/tutorial/align_with_verl.md @@ -60,7 +60,7 @@ Trinity-RFT 根据功能将强化微调的大量参数分为几个部分,例 | `data.max_response_length` | `model.max_response_tokens` | - | | `data.filter_overlong_prompts` | `model.enable_prompt_truncation` | 稍后说明 | | `data.truncation` | - | 等同于 `right` | -| `data.shuffle` | `buffer.explorer_input.taskset.task_selector.selector_type:random` | Taskset-specific | +| `data.shuffle` | `buffer.explorer_input.taskset.task_selector.selector_type:shuffle` | Taskset-specific | 💡 详细说明: diff --git a/trinity/common/config.py b/trinity/common/config.py index 9738ec3b8e..7c5a9c5a56 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -892,6 +892,8 @@ def _check_interval(self) -> None: ) def _check_explorer_input(self) -> None: + from trinity.buffer.selector import SELECTORS + if self.mode in {"train", "serve"}: # no need to check explorer_input in serve mode return @@ -932,6 +934,13 @@ def _check_explorer_input(self) -> None: set_if_none(taskset.rollout_args, "max_tokens", self.model.max_response_tokens) set_if_none(taskset.format, "chat_template", self.model.custom_chat_template) + # check if selector is supported + selector = SELECTORS.get(taskset.task_selector.selector_type) + if selector is None: + raise ValueError( + f"Selector {taskset.task_selector.selector_type} is not supported." + ) + for idx, dataset in enumerate(explorer_input.eval_tasksets): if not dataset.path: raise ValueError(f"Eval dataset [{dataset}]'s path is not configured.")