Skip to content

Commit a9c650b

Browse files
authored
fix: some HF datasets may need a specified config name (#25)
1 parent 24c407e commit a9c650b

File tree

3 files changed

+18
-6
lines changed

3 files changed

+18
-6
lines changed

examples/grpo_gsm8k/gsm8k.yaml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
data:
22
# basic info
3-
dataset_path: '/PATH/TO/DATASET/'
3+
dataset_path: 'openai/gsm8k'
4+
subset_name: "main"
45
train_split: 'train'
5-
eval_split: ''
6+
eval_split: 'test'
67
format_config:
78
prompt_key: 'question'
89
response_key: 'answer'
@@ -24,7 +25,7 @@ model:
2425
model_path: '/PATH/TO/MODEL/'
2526
max_prompt_tokens: 256
2627
max_response_tokens: 1024
27-
checkpoint_path: '/PATH/TO/CHECKPOINT/'
28+
checkpoint_path: ""
2829
cluster:
2930
node_num: 1
3031
gpu_per_node: 8
@@ -35,7 +36,7 @@ buffer:
3536
name: gsm8k_buffer
3637
storage_type: queue
3738
algorithm_type: ppo
38-
path: 'sqlite:////gsm8k.db'
39+
path: 'sqlite:///gsm8k.db'
3940
# sft_warmup_dataset: # Uncomment these to enable sft warmup
4041
# name: warmup_data
4142
# storage_type: file

trinity/common/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ class DataConfig:
4545

4646
dataset_path: str = ""
4747
train_split: str = "train"
48+
subset_name: Optional[str] = None
4849
eval_split: Optional[str] = None # TODO: check data format
4950
format_config: FormatConfig = field(default_factory=FormatConfig)
5051

trinity/common/task.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,15 @@ def task_generator(
101101
yield task
102102

103103

104+
def load_hf_dataset(config: DataConfig, split: str):
105+
"""Load a Hugging Face dataset with optional configuration name."""
106+
if config.subset_name is not None:
107+
hf_dataset = load_dataset(config.dataset_path, config.subset_name, split=split)
108+
else:
109+
hf_dataset = load_dataset(config.dataset_path, split=split)
110+
return hf_dataset
111+
112+
104113
@dataclass
105114
class TaskSet:
106115
"""A TaskSet class that defines a set of tasks and their associated reward functions."""
@@ -125,7 +134,8 @@ def load(
125134
# disable datasets caching to avoid reuse old-version dataset
126135
datasets.disable_caching()
127136
if task_type == TaskType.EVAL:
128-
dataset = load_dataset(config.dataset_path)[config.eval_split]
137+
assert config.eval_split is not None, "eval_split must be provided for eval taskset."
138+
dataset = load_hf_dataset(config, config.eval_split)
129139
else: # default
130140
if task_type != TaskType.EVAL and config.db_url != "":
131141
logger.info(f"Loading dataset from database with url: {config.db_url}")
@@ -134,7 +144,7 @@ def load(
134144
dataset = Dataset.from_sql(RftDatasetModel.__tablename__, f"{db_type}:///{db_name}")
135145
elif config.dataset_path != "":
136146
logger.info(f"Loading dataset from local file with path: {config.dataset_path}.")
137-
dataset = load_dataset(config.dataset_path)[config.train_split]
147+
dataset = load_hf_dataset(config, config.train_split)
138148
else:
139149
raise ValueError("No dataset path or db url provided.")
140150
datasets.enable_caching()

0 commit comments

Comments (0)