diff --git a/examples/grpo_gsm8k/gsm8k.yaml b/examples/grpo_gsm8k/gsm8k.yaml index e2ca38cd63..677027ea19 100644 --- a/examples/grpo_gsm8k/gsm8k.yaml +++ b/examples/grpo_gsm8k/gsm8k.yaml @@ -1,8 +1,9 @@ data: # basic info - dataset_path: '/PATH/TO/DATASET/' + dataset_path: 'openai/gsm8k' + subset_name: "main" train_split: 'train' - eval_split: '' + eval_split: 'test' format_config: prompt_key: 'question' response_key: 'answer' @@ -24,7 +25,7 @@ model: model_path: '/PATH/TO/MODEL/' max_prompt_tokens: 256 max_response_tokens: 1024 - checkpoint_path: '/PATH/TO/CHECKPOINT/' + checkpoint_path: "" cluster: node_num: 1 gpu_per_node: 8 @@ -35,7 +36,7 @@ buffer: name: gsm8k_buffer storage_type: queue algorithm_type: ppo - path: 'sqlite:////gsm8k.db' + path: 'sqlite:///gsm8k.db' # sft_warmup_dataset: # Uncomment these to enable sft warmup # name: warmup_data # storage_type: file diff --git a/trinity/common/config.py b/trinity/common/config.py index 46551e55d3..8fdd9f8e27 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -45,6 +45,7 @@ class DataConfig: dataset_path: str = "" train_split: str = "train" + subset_name: Optional[str] = None eval_split: Optional[str] = None # TODO: check data format format_config: FormatConfig = field(default_factory=FormatConfig) diff --git a/trinity/common/task.py b/trinity/common/task.py index e9f88724fb..bef638b6c9 100644 --- a/trinity/common/task.py +++ b/trinity/common/task.py @@ -101,6 +101,15 @@ def task_generator( yield task +def load_hf_dataset(config: DataConfig, split: str): + """Load a Hugging Face dataset with optional configuration name.""" + if config.subset_name is not None: + hf_dataset = load_dataset(config.dataset_path, config.subset_name, split=split) + else: + hf_dataset = load_dataset(config.dataset_path, split=split) + return hf_dataset + + @dataclass class TaskSet: """A TaskSet class that defines a set of tasks and their associated reward functions.""" @@ -125,7 +134,8 @@ def load( # disable datasets caching to avoid reuse old-version dataset datasets.disable_caching() if task_type == TaskType.EVAL: - dataset = load_dataset(config.dataset_path)[config.eval_split] + assert config.eval_split is not None, "eval_split must be provided for eval taskset." + dataset = load_hf_dataset(config, config.eval_split) else: # default if task_type != TaskType.EVAL and config.db_url != "": logger.info(f"Loading dataset from database with url: {config.db_url}") @@ -134,7 +144,7 @@ def load( dataset = Dataset.from_sql(RftDatasetModel.__tablename__, f"{db_type}:///{db_name}") elif config.dataset_path != "": logger.info(f"Loading dataset from local file with path: {config.dataset_path}.") - dataset = load_dataset(config.dataset_path)[config.train_split] + dataset = load_hf_dataset(config, config.train_split) else: raise ValueError("No dataset path or db url provided.") datasets.enable_caching()