From aeeb531a76b5aa1755ca430ed49a5e1e856ba23b Mon Sep 17 00:00:00 2001 From: 0x404 <871206929@qq.com> Date: Fri, 25 Apr 2025 18:07:06 +0800 Subject: [PATCH 1/6] fix: some hf dataset may need specified config name --- examples/grpo_gsm8k/gsm8k.yaml | 9 +++++---- trinity/common/config.py | 1 + trinity/common/task.py | 14 ++++++++++++-- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/examples/grpo_gsm8k/gsm8k.yaml b/examples/grpo_gsm8k/gsm8k.yaml index e2ca38cd63..7fe58567ec 100644 --- a/examples/grpo_gsm8k/gsm8k.yaml +++ b/examples/grpo_gsm8k/gsm8k.yaml @@ -1,8 +1,9 @@ data: # basic info - dataset_path: '/PATH/TO/DATASET/' + dataset_path: 'openai/gsm8k' + config_name: "main" train_split: 'train' - eval_split: '' + eval_split: 'test' format_config: prompt_key: 'question' response_key: 'answer' @@ -24,7 +25,7 @@ model: model_path: '/PATH/TO/MODEL/' max_prompt_tokens: 256 max_response_tokens: 1024 - checkpoint_path: '/PATH/TO/CHECKPOINT/' + checkpoint_path: "" cluster: node_num: 1 gpu_per_node: 8 @@ -35,7 +36,7 @@ buffer: name: gsm8k_buffer storage_type: queue algorithm_type: ppo - path: 'sqlite:////gsm8k.db' + path: 'sqlite:///gsm8k.db' # sft_warmup_dataset: # Uncomment these to enable sft warmup # name: warmup_data # storage_type: file diff --git a/trinity/common/config.py b/trinity/common/config.py index 46551e55d3..cc85493e63 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -45,6 +45,7 @@ class DataConfig: dataset_path: str = "" train_split: str = "train" + config_name: Optional[str] = None eval_split: Optional[str] = None # TODO: check data format format_config: FormatConfig = field(default_factory=FormatConfig) diff --git a/trinity/common/task.py b/trinity/common/task.py index e9f88724fb..a386c3fc0f 100644 --- a/trinity/common/task.py +++ b/trinity/common/task.py @@ -101,6 +101,15 @@ def task_generator( yield task +def load_hf_dataset(config: DataConfig, split: str): + """Load a Hugging Face dataset with optional configuration name.""" + if config.config_name is not None: + hf_dataset = load_dataset(config.dataset_path, config.config_name, split=split) + else: + hf_dataset = load_dataset(config.dataset_path, split=split) + return hf_dataset + + @dataclass class TaskSet: """A TaskSet class that defines a set of tasks and their associated reward functions.""" @@ -125,7 +134,8 @@ def load( # disable datasets caching to avoid reuse old-version dataset datasets.disable_caching() if task_type == TaskType.EVAL: - dataset = load_dataset(config.dataset_path)[config.eval_split] + assert config.eval_split is not None, "eval_split must be provided for eval taskset." + dataset = load_hf_dataset(config, config.eval_split) else: # default if task_type != TaskType.EVAL and config.db_url != "": logger.info(f"Loading dataset from database with url: {config.db_url}") @@ -134,7 +144,7 @@ def load( dataset = Dataset.from_sql(RftDatasetModel.__tablename__, f"{db_type}:///{db_name}") elif config.dataset_path != "": logger.info(f"Loading dataset from local file with path: {config.dataset_path}.") - dataset = load_dataset(config.dataset_path)[config.train_split] + dataset = load_hf_dataset(config, config.train_split) else: raise ValueError("No dataset path or db url provided.") datasets.enable_caching() From 625536f9abf22901bf93acccddc54ec6c93bd5c3 Mon Sep 17 00:00:00 2001 From: Qunhong Zeng <871206929@qq.com> Date: Sun, 27 Apr 2025 10:43:55 +0800 Subject: [PATCH 2/6] config_name -> subset_name Co-authored-by: Yuchang Sun <52027540+hiyuchang@users.noreply.github.com> --- trinity/common/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trinity/common/config.py b/trinity/common/config.py index cc85493e63..8fdd9f8e27 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -45,7 +45,7 @@ class DataConfig: dataset_path: str = "" train_split: str = "train" - config_name: Optional[str] = None + subset_name: Optional[str] = None eval_split: Optional[str] = None # TODO: check data format format_config: FormatConfig = field(default_factory=FormatConfig) From ff2bd10dcf1547d784a4fb4e7611e7425fa28e15 Mon Sep 17 00:00:00 2001 From: Qunhong Zeng <871206929@qq.com> Date: Sun, 27 Apr 2025 10:44:06 +0800 Subject: [PATCH 3/6] config_name -> subset_name Co-authored-by: Yuchang Sun <52027540+hiyuchang@users.noreply.github.com> --- trinity/common/task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trinity/common/task.py b/trinity/common/task.py index a386c3fc0f..93a34246fb 100644 --- a/trinity/common/task.py +++ b/trinity/common/task.py @@ -104,7 +104,7 @@ def task_generator( def load_hf_dataset(config: DataConfig, split: str): """Load a Hugging Face dataset with optional configuration name.""" if config.config_name is not None: - hf_dataset = load_dataset(config.dataset_path, config.config_name, split=split) + hf_dataset = load_dataset(config.dataset_path, config.subset_name, split=split) else: hf_dataset = load_dataset(config.dataset_path, split=split) return hf_dataset From 7e109f27911e843820804e685937f1c034d980d0 Mon Sep 17 00:00:00 2001 From: Qunhong Zeng <871206929@qq.com> Date: Sun, 27 Apr 2025 10:45:48 +0800 Subject: [PATCH 4/6] config_name -> subset_name --- examples/grpo_gsm8k/gsm8k.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/grpo_gsm8k/gsm8k.yaml b/examples/grpo_gsm8k/gsm8k.yaml index 7fe58567ec..9229e40ee1 100644 --- a/examples/grpo_gsm8k/gsm8k.yaml +++ b/examples/grpo_gsm8k/gsm8k.yaml @@ -1,7 +1,7 @@ data: # basic info dataset_path: 'openai/gsm8k' - config_name: "main" + subset_name: "main" train_split: 'train' eval_split: 'test' format_config: From 0477d8c8ee1a7ad8db2762b2c9dbea1ce141c981 Mon Sep 17 00:00:00 2001 From: Qunhong Zeng <871206929@qq.com> Date: Sun, 27 Apr 2025 10:47:14 +0800 Subject: [PATCH 5/6] config_name -> subset_name --- trinity/common/task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trinity/common/task.py b/trinity/common/task.py index 93a34246fb..bef638b6c9 100644 --- a/trinity/common/task.py +++ b/trinity/common/task.py @@ -103,7 +103,7 @@ def task_generator( def load_hf_dataset(config: DataConfig, split: str): """Load a Hugging Face dataset with optional configuration name.""" - if config.config_name is not None: + if config.subset_name is not None: hf_dataset = load_dataset(config.dataset_path, config.subset_name, split=split) else: hf_dataset = load_dataset(config.dataset_path, split=split) From c9c3290adef74fb93b4e523402f52556ac2fdb9b Mon Sep 17 00:00:00 2001 From: Qunhong Zeng <871206929@qq.com> Date: Sun, 27 Apr 2025 02:53:51 +0000 Subject: [PATCH 6/6] style: fix trim trailing whitespace --- examples/grpo_gsm8k/gsm8k.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/grpo_gsm8k/gsm8k.yaml b/examples/grpo_gsm8k/gsm8k.yaml index 9229e40ee1..677027ea19 100644 --- a/examples/grpo_gsm8k/gsm8k.yaml +++ b/examples/grpo_gsm8k/gsm8k.yaml @@ -1,7 +1,7 @@ data: # basic info dataset_path: 'openai/gsm8k' - subset_name: "main" + subset_name: "main" train_split: 'train' eval_split: 'test' format_config: