Merged (changes from 10 commits)
10 changes: 5 additions & 5 deletions docs/sphinx_doc/source/tutorial/example_data_functionalities.md
@@ -36,7 +36,7 @@ data:
dataset_path: '/path/to/gsm8k'
dataset_config:
split: 'train' # only need the train split
- format_config: # set the field mappings
+ format: # set the field mappings
prompt_key: 'question'
response_key: 'answer'
# database related. The result dataset will be stored in the database.
@@ -51,7 +51,7 @@ Here you can set the basic information for the GSM-8K dataset, database informat

+ `dataset_path`: the path to the raw dataset.
+ `dataset_config`: extra config arguments for loading the raw dataset. Mainly for the `load_dataset` method in HuggingFace `datasets` library.
- + `format_config`: some dataset format config items, which are used to map original data field names to unified ones.
+ + `format`: some dataset format config items, which are used to map original data field names to unified ones.
+ `db_url`: the URL of the postgresql database to store the result dataset.
+ `total_epochs`: the total number of epochs to train on this dataset.
+ `batch_size`: the training batch size.
@@ -68,7 +68,7 @@ data:
dataset_path: '/path/to/gsm8k'
dataset_config:
split: 'train' # only need the train split
- format_config: # set the field mappings
+ format: # set the field mappings
prompt_key: 'question'
response_key: 'answer'
# database related. The result dataset will be stored in the database.
@@ -114,7 +114,7 @@ data:
dataset_path: '/path/to/gsm8k'
dataset_config:
split: 'train' # only need the train split
- format_config: # set the field mappings
+ format: # set the field mappings
prompt_key: 'question'
response_key: 'answer'
# database related. The result dataset will be stored in the database.
@@ -190,7 +190,7 @@ data:
dataset_path: 'tests/test_data/test_human_annotator'
dataset_config:
split: 'train' # only need the train split
- format_config: # set the field mappings
+ format: # set the field mappings
prompt_key: 'prompt'
chosen_key: 'chosen'
rejected_key: 'rejected'
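Taken together, the hunks above leave the GSM-8K data section looking roughly like the sketch below. This is a consolidated reading of the diff, not a verified config; the dataset path is the placeholder used in the tutorial.

```yaml
data:
  dataset_path: '/path/to/gsm8k'   # placeholder path from the tutorial
  dataset_config:
    split: 'train'                 # only the train split is needed
  format:                          # renamed from `format_config` in this change
    prompt_key: 'question'
    response_key: 'answer'
```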
2 changes: 1 addition & 1 deletion docs/sphinx_doc/source/tutorial/example_dpo.md
@@ -51,7 +51,7 @@ buffer:
train_dataset:
storage_type: file
path: <$DATASET_PATH/human_like_dpo_dataset>
- kwargs:
+ format:
prompt_type: <prompt_type> # messages/plaintext
prompt_key: <prompt_key>
chosen_key: <chosen_key>
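After the rename, the DPO buffer entry would look roughly as follows. This is a sketch assembled from the hunk above; the angle-bracketed values are the tutorial's placeholders, and `rejected_key` is assumed by analogy with the preference data shown earlier, since this hunk is truncated before it.

```yaml
buffer:
  train_dataset:
    storage_type: file
    path: <$DATASET_PATH/human_like_dpo_dataset>
    format:                         # renamed from `kwargs` in this change
      prompt_type: <prompt_type>    # messages/plaintext
      prompt_key: <prompt_key>
      chosen_key: <chosen_key>
      rejected_key: <rejected_key>  # assumed; truncated in the hunk above
```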
2 changes: 1 addition & 1 deletion docs/sphinx_doc/source/tutorial/example_reasoning_basic.md
@@ -84,7 +84,7 @@ buffer:
sft_warmup_dataset:
storage_type: file
path: <$DATASET_PATH/{sft_data}>
- kwargs:
+ format:
prompt_type: <prompt_type> # messages/plaintext/chatpair
prompt_key: <prompt_key>
response_key: <response_key>
12 changes: 5 additions & 7 deletions docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -24,9 +24,7 @@ data:
dataset_path: '/PATH/TO/DATASET'
train_split: 'train'
eval_split: ''
- dataset_config:
-   split: 'train'
- format_config:
+ format:
prompt_key: 'question'
response_key: 'answer'

Expand All @@ -40,11 +38,11 @@ data:
default_reward_fn_type: 'countdown_reward'
```

- - `data.dataset_path`: The path to the dataset.
+ <!-- - `data.dataset_path`: The path to the dataset. -->
- `data.train_split`: The split name of the dataset used for training. Default is `train`.
- `data.eval_split`: The split name of the dataset used for eval.
- - `data.dataset_config`: The configuration for the dataset. <!-- TODO: may only used in Data-Juicer -->
- - `data.format_config`: The configuration for the format of the dataset.
+ <!-- - `data.dataset_config`: The configuration for the dataset. TODO: may only used in Data-Juicer -->
+ - `data.format`: The configuration for the format of the dataset.
- `data.db_url`: The URL of the database.
- `data.max_retry_times`: The maximum number of retries when loading the dataset from database.
- `data.max_retry_interval`: The maximum interval between retries when loading the dataset from database.
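For orientation, the keys documented in these bullets fit together roughly as in the sketch below. The structure follows the snippet at the top of this file; the database URL and retry numbers are illustrative values only, not defaults taken from the code.

```yaml
data:
  dataset_path: '/PATH/TO/DATASET'
  train_split: 'train'
  eval_split: ''
  format:
    prompt_key: 'question'
    response_key: 'answer'
  db_url: 'postgresql://user:password@localhost:5432/trinity'  # illustrative only
  max_retry_times: 3      # illustrative only
  max_retry_interval: 1   # illustrative only
  default_workflow_type: 'math_workflow'
  default_reward_fn_type: 'countdown_reward'
```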
@@ -53,7 +51,7 @@ data:
- `data.default_workflow_type`: The default workflow type used for training.
- `data.default_reward_fn_type`: The default reward function type used for training.

- <!-- TODO explain the dataset_config and format_config -->
+ <!-- TODO explain the dataset_config and format -->

## Model

2 changes: 1 addition & 1 deletion examples/async_gsm8k/explorer.yaml
@@ -5,7 +5,7 @@ data:
subset_name: ''
train_split: 'train'
eval_split: 'test'
- format_config:
+ format:
prompt_key: 'question'
response_key: 'answer'
# downstream loading related
2 changes: 1 addition & 1 deletion examples/async_gsm8k/trainer.yaml
@@ -5,7 +5,7 @@ data:
subset_name: ''
train_split: 'train'
eval_split: 'test'
- format_config:
+ format:
prompt_key: 'question'
response_key: 'answer'
# downstream loading related
4 changes: 2 additions & 2 deletions examples/dpo_humanlike/dpo.yaml
@@ -5,7 +5,7 @@ data:
train_split: "train"
dataset_path: ''
default_workflow_type: 'math_workflow'
- format_config:
+ format:
prompt_key: ''
response_key: ''
model:
@@ -23,7 +23,7 @@ buffer:
name: dpo_buffer
storage_type: file
path: '/PATH/TO/DATASET/'
- kwargs:
+ format:
prompt_type: plaintext # plaintext/messages
prompt_key: prompt
chosen_key: chosen
2 changes: 1 addition & 1 deletion examples/grpo_alfworld/alfworld.yaml
@@ -5,7 +5,7 @@ data:
default_workflow_type: 'alfworld_workflow'
train_split: 'train'
eval_split: ''
- format_config:
+ format:
prompt_key: 'game_file'
model:
model_path: '/PATH/TO/MODEL/CHECKPOINT/'
2 changes: 1 addition & 1 deletion examples/grpo_gsm8k/gsm8k.yaml
@@ -4,7 +4,7 @@ data:
subset_name: "main"
train_split: 'train'
eval_split: 'test'
- format_config:
+ format:
prompt_key: 'question'
response_key: 'answer'
# data active iterator related
2 changes: 1 addition & 1 deletion examples/grpo_math/math.yaml
@@ -4,7 +4,7 @@ data:
# dataset_config:
train_split: train
eval_split: test
- format_config:
+ format:
prompt_key: 'question'
response_key: 'gt_answer'
# db related
2 changes: 1 addition & 1 deletion examples/grpo_sciworld/sciworld.yaml
@@ -5,7 +5,7 @@ data:
default_workflow_type: 'sciworld_workflow'
train_split: 'train'
eval_split: ''
- format_config:
+ format:
prompt_key: 'game_file'
model:
model_path: '/PATH/TO/MODEL/CHECKPOINT/'
2 changes: 1 addition & 1 deletion examples/grpo_webshop/webshop.yaml
@@ -5,7 +5,7 @@ data:
default_workflow_type: 'webshop_workflow'
train_split: 'train'
eval_split: ''
- format_config:
+ format:
prompt_key: 'task_id'
model:
model_path: '/PATH/TO/MODEL/CHECKPOINT/'
2 changes: 1 addition & 1 deletion examples/opmd_gsm8k/opmd_gsm8k.yaml
@@ -3,7 +3,7 @@ data:
batch_size: 96
dataset_path: '{path to datasets}/gsm8k'
default_workflow_type: 'math_workflow'
- format_config:
+ format:
prompt_key: 'question'
response_key: 'answer'
model:
2 changes: 1 addition & 1 deletion examples/ppo_countdown/countdown.yaml
@@ -6,7 +6,7 @@ data:
train_split: 'train'
eval_split: ''
default_reward_fn_type: 'countdown_reward'
- format_config:
+ format:
prompt_key: 'question'
response_key: 'answer'
model:
4 changes: 2 additions & 2 deletions tests/buffer/queue_test.py
@@ -3,7 +3,7 @@
from tests.tools import RayUnittestBase
from trinity.buffer.reader.queue_reader import QueueReader
from trinity.buffer.writer.queue_writer import QueueWriter
- from trinity.common.config import BufferConfig, DatasetConfig
+ from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.constants import AlgorithmType, StorageType
from trinity.common.experience import Experience

@@ -13,7 +13,7 @@ def test_queue_buffer(self):
total_num = 8
put_batch_size = 2
read_batch_size = 4
- meta = DatasetConfig(
+ meta = StorageConfig(
name="test_buffer",
namespace="test_namespace",
algorithm_type=AlgorithmType.PPO,
4 changes: 2 additions & 2 deletions tests/buffer/sql_test.py
@@ -5,7 +5,7 @@

from trinity.buffer.reader.sql_reader import SQLReader
from trinity.buffer.writer.sql_writer import SQLWriter
- from trinity.common.config import BufferConfig, DatasetConfig
+ from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.constants import AlgorithmType, StorageType
from trinity.common.experience import Experience

@@ -17,7 +17,7 @@ def test_create_sql_buffer(self) -> None:
total_num = 8
put_batch_size = 2
read_batch_size = 4
- meta = DatasetConfig(
+ meta = StorageConfig(
name="test_buffer",
algorithm_type=AlgorithmType.PPO,
path=f"sqlite:///{db_path}",
20 changes: 9 additions & 11 deletions tests/data/core/dataset_test.py
@@ -3,7 +3,7 @@
import os
import unittest

- from trinity.common.config import DataConfig, FormatConfig
+ from trinity.common.config import DataProcessorConfig, FormatConfig
from trinity.common.rewards import AccuracyReward
from trinity.common.task import TaskSet
from trinity.common.workflows import MathWorkflow, SimpleWorkflow
@@ -15,31 +15,29 @@ class TestRftDataset(unittest.TestCase):
"""Test cases for RftDataset"""

def setUp(self) -> None:
- self.data_config = DataConfig(
-     dataset_path=os.path.join(
+ self.data_config = DataProcessorConfig(
+     source_data_path=os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"..",
"..",
"test_data",
"test_10",
),
- dataset_config={"split": "train"},
- format_config=FormatConfig(
+ format=FormatConfig(
prompt_key="problem",
response_key="solution",
solution_key="solution",
),
)
- self.data_config_sample_level_setting = DataConfig(
-     dataset_path=os.path.join(
+ self.data_config_sample_level_setting = DataProcessorConfig(
+     source_data_path=os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"..",
"..",
"test_data",
"test_10_with_rewfn_workflow",
),
- dataset_config={"split": "train"},
- format_config=FormatConfig(
+ format=FormatConfig(
prompt_key="problem",
response_key="solution",
solution_key="solution",
@@ -64,8 +62,8 @@ def test_format_dataset(self):
# apply formatters
dataset.format(
formatters=[
- BoxedMathAnswerFormatter(config=self.data_config.format_config),
- RLHFFormatter(config=self.data_config.format_config),
+ BoxedMathAnswerFormatter(config=self.data_config.format),
+ RLHFFormatter(config=self.data_config.format),
]
)
self.assertNotEqual(dataset.data, original_data)