Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions examples/grpo_gsm8k/gsm8k.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
data:
# basic info
dataset_path: '/PATH/TO/DATASET/'
dataset_path: 'openai/gsm8k'
config_name: "main"
train_split: 'train'
eval_split: ''
eval_split: 'test'
format_config:
prompt_key: 'question'
response_key: 'answer'
Expand All @@ -24,7 +25,7 @@ model:
model_path: '/PATH/TO/MODEL/'
max_prompt_tokens: 256
max_response_tokens: 1024
checkpoint_path: '/PATH/TO/CHECKPOINT/'
checkpoint_path: ""
cluster:
node_num: 1
gpu_per_node: 8
Expand All @@ -35,7 +36,7 @@ buffer:
name: gsm8k_buffer
storage_type: queue
algorithm_type: ppo
path: 'sqlite:////gsm8k.db'
path: 'sqlite:///gsm8k.db'
# sft_warmup_dataset: # Uncomment these to enable sft warmup
# name: warmup_data
# storage_type: file
Expand Down
1 change: 1 addition & 0 deletions trinity/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class DataConfig:

dataset_path: str = ""
train_split: str = "train"
config_name: Optional[str] = None
eval_split: Optional[str] = None # TODO: check data format
format_config: FormatConfig = field(default_factory=FormatConfig)

Expand Down
14 changes: 12 additions & 2 deletions trinity/common/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,15 @@ def task_generator(
yield task


def load_hf_dataset(config: DataConfig, split: str):
    """Load the Hugging Face dataset referenced by *config*.

    Uses ``config.dataset_path`` as the dataset identifier and, when
    ``config.config_name`` is set, passes it through as the dataset's
    configuration name (e.g. ``"main"`` for ``openai/gsm8k``).

    Args:
        config: Data configuration carrying the dataset path and optional
            configuration name.
        split: Which split to load (e.g. ``"train"`` or ``"test"``).

    Returns:
        The loaded dataset for the requested split.
    """
    # Only forward the config name positionally when one was provided;
    # otherwise fall back to the dataset's default configuration.
    if config.config_name is None:
        return load_dataset(config.dataset_path, split=split)
    return load_dataset(config.dataset_path, config.config_name, split=split)


@dataclass
class TaskSet:
"""A TaskSet class that defines a set of tasks and their associated reward functions."""
Expand All @@ -125,7 +134,8 @@ def load(
# disable datasets caching to avoid reuse old-version dataset
datasets.disable_caching()
if task_type == TaskType.EVAL:
dataset = load_dataset(config.dataset_path)[config.eval_split]
assert config.eval_split is not None, "eval_split must be provided for eval taskset."
dataset = load_hf_dataset(config, config.eval_split)
else: # default
if task_type != TaskType.EVAL and config.db_url != "":
logger.info(f"Loading dataset from database with url: {config.db_url}")
Expand All @@ -134,7 +144,7 @@ def load(
dataset = Dataset.from_sql(RftDatasetModel.__tablename__, f"{db_type}:///{db_name}")
elif config.dataset_path != "":
logger.info(f"Loading dataset from local file with path: {config.dataset_path}.")
dataset = load_dataset(config.dataset_path)[config.train_split]
dataset = load_hf_dataset(config, config.train_split)
else:
raise ValueError("No dataset path or db url provided.")
datasets.enable_caching()
Expand Down