Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions examples/grpo_gsm8k/gsm8k.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
data:
# basic info
dataset_path: '/PATH/TO/DATASET/'
dataset_path: 'openai/gsm8k'
subset_name: "main"
train_split: 'train'
eval_split: ''
eval_split: 'test'
format_config:
prompt_key: 'question'
response_key: 'answer'
Expand All @@ -24,7 +25,7 @@ model:
model_path: '/PATH/TO/MODEL/'
max_prompt_tokens: 256
max_response_tokens: 1024
checkpoint_path: '/PATH/TO/CHECKPOINT/'
checkpoint_path: ""
cluster:
node_num: 1
gpu_per_node: 8
Expand All @@ -35,7 +36,7 @@ buffer:
name: gsm8k_buffer
storage_type: queue
algorithm_type: ppo
path: 'sqlite:////gsm8k.db'
path: 'sqlite:///gsm8k.db'
# sft_warmup_dataset: # Uncomment these to enable sft warmup
# name: warmup_data
# storage_type: file
Expand Down
1 change: 1 addition & 0 deletions trinity/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class DataConfig:

dataset_path: str = ""
train_split: str = "train"
subset_name: Optional[str] = None
eval_split: Optional[str] = None # TODO: check data format
format_config: FormatConfig = field(default_factory=FormatConfig)

Expand Down
14 changes: 12 additions & 2 deletions trinity/common/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,15 @@ def task_generator(
yield task


def load_hf_dataset(config: DataConfig, split: str):
    """Load a Hugging Face dataset described by ``config``.

    Forwards ``config.subset_name`` as the dataset configuration name only
    when one is set; otherwise the dataset is loaded with its default config.

    Args:
        config: Data configuration holding ``dataset_path`` and an optional
            ``subset_name`` (the HF dataset config name, e.g. ``"main"``).
        split: Name of the split to load (e.g. ``"train"`` or ``"test"``).

    Returns:
        The dataset object returned by ``datasets.load_dataset``.
    """
    # Build positional args dynamically so the optional config name is
    # passed only when present — load_dataset treats a missing name as
    # "use the default configuration".
    args = [config.dataset_path]
    if config.subset_name is not None:
        args.append(config.subset_name)
    return load_dataset(*args, split=split)


@dataclass
class TaskSet:
"""A TaskSet class that defines a set of tasks and their associated reward functions."""
Expand All @@ -125,7 +134,8 @@ def load(
# disable datasets caching to avoid reuse old-version dataset
datasets.disable_caching()
if task_type == TaskType.EVAL:
dataset = load_dataset(config.dataset_path)[config.eval_split]
assert config.eval_split is not None, "eval_split must be provided for eval taskset."
dataset = load_hf_dataset(config, config.eval_split)
else: # default
if task_type != TaskType.EVAL and config.db_url != "":
logger.info(f"Loading dataset from database with url: {config.db_url}")
Expand All @@ -134,7 +144,7 @@ def load(
dataset = Dataset.from_sql(RftDatasetModel.__tablename__, f"{db_type}:///{db_name}")
elif config.dataset_path != "":
logger.info(f"Loading dataset from local file with path: {config.dataset_path}.")
dataset = load_dataset(config.dataset_path)[config.train_split]
dataset = load_hf_dataset(config, config.train_split)
else:
raise ValueError("No dataset path or db url provided.")
datasets.enable_caching()
Expand Down