Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/sphinx_doc/source/tutorial/example_dpo.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ buffer:
train_dataset:
storage_type: file
path: <$DATASET_PATH/human_like_dpo_dataset>
kwargs:
format_config:
prompt_type: <prompt_type> # messages/plaintext
prompt_key: <prompt_key>
chosen_key: <chosen_key>
Expand Down
2 changes: 1 addition & 1 deletion docs/sphinx_doc/source/tutorial/example_reasoning_basic.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ buffer:
sft_warmup_dataset:
storage_type: file
path: <$DATASET_PATH/{sft_data}>
kwargs:
format_config:
prompt_type: <prompt_type> # messages/plaintext/chatpair
prompt_key: <prompt_key>
response_key: <response_key>
Expand Down
6 changes: 2 additions & 4 deletions docs/sphinx_doc/source/tutorial/trinity_configs.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ data:
dataset_path: '/PATH/TO/DATASET'
train_split: 'train'
eval_split: ''
dataset_config:
split: 'train'
format_config:
prompt_key: 'question'
response_key: 'answer'
Expand All @@ -40,10 +38,10 @@ data:
default_reward_fn_type: 'countdown_reward'
```

- `data.dataset_path`: The path to the dataset.
<!-- - `data.dataset_path`: The path to the dataset. -->
- `data.train_split`: The split name of the dataset used for training. Default is `train`.
- `data.eval_split`: The split name of the dataset used for eval.
- `data.dataset_config`: The configuration for the dataset. <!-- TODO: may only used in Data-Juicer -->
<!-- - `data.dataset_config`: The configuration for the dataset. TODO: may only used in Data-Juicer -->
- `data.format_config`: The configuration for the format of the dataset.
- `data.db_url`: The URL of the database.
- `data.max_retry_times`: The maximum number of retries when loading the dataset from database.
Expand Down
2 changes: 1 addition & 1 deletion examples/dpo_humanlike/dpo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ buffer:
name: dpo_buffer
storage_type: file
path: '/PATH/TO/DATASET/'
kwargs:
format_config:
prompt_type: plaintext # plaintext/messages
prompt_key: prompt
chosen_key: chosen
Expand Down
4 changes: 2 additions & 2 deletions tests/buffer/queue_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from tests.tools import RayUnittestBase
from trinity.buffer.reader.queue_reader import QueueReader
from trinity.buffer.writer.queue_writer import QueueWriter
from trinity.common.config import BufferConfig, DatasetConfig
from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.constants import AlgorithmType, StorageType
from trinity.common.experience import Experience

Expand All @@ -13,7 +13,7 @@ def test_queue_buffer(self):
total_num = 8
put_batch_size = 2
read_batch_size = 4
meta = DatasetConfig(
meta = StorageConfig(
name="test_buffer",
namespace="test_namespace",
algorithm_type=AlgorithmType.PPO,
Expand Down
4 changes: 2 additions & 2 deletions tests/buffer/sql_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from trinity.buffer.reader.sql_reader import SQLReader
from trinity.buffer.writer.sql_writer import SQLWriter
from trinity.common.config import BufferConfig, DatasetConfig
from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.constants import AlgorithmType, StorageType
from trinity.common.experience import Experience

Expand All @@ -17,7 +17,7 @@ def test_create_sql_buffer(self) -> None:
total_num = 8
put_batch_size = 2
read_batch_size = 4
meta = DatasetConfig(
meta = StorageConfig(
name="test_buffer",
algorithm_type=AlgorithmType.PPO,
path=f"sqlite:///{db_path}",
Expand Down
12 changes: 5 additions & 7 deletions tests/data/core/dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import unittest

from trinity.common.config import DataConfig, FormatConfig
from trinity.common.config import DataProcessorConfig, FormatConfig
from trinity.common.rewards import AccuracyReward
from trinity.common.task import TaskSet
from trinity.common.workflows import MathWorkflow, SimpleWorkflow
Expand All @@ -15,30 +15,28 @@ class TestRftDataset(unittest.TestCase):
"""Test cases for RftDataset"""

def setUp(self) -> None:
self.data_config = DataConfig(
dataset_path=os.path.join(
self.data_config = DataProcessorConfig(
raw_data_path=os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"..",
"..",
"test_data",
"test_10",
),
dataset_config={"split": "train"},
format_config=FormatConfig(
prompt_key="problem",
response_key="solution",
solution_key="solution",
),
)
self.data_config_sample_level_setting = DataConfig(
dataset_path=os.path.join(
self.data_config_sample_level_setting = DataProcessorConfig(
raw_data_path=os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"..",
"..",
"test_data",
"test_10_with_rewfn_workflow",
),
dataset_config={"split": "train"},
format_config=FormatConfig(
prompt_key="problem",
response_key="solution",
Expand Down
27 changes: 11 additions & 16 deletions tests/data/core/formatter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import unittest

from trinity.common.config import DataConfig, FormatConfig
from trinity.common.config import DataProcessorConfig, FormatConfig
from trinity.data.core.dataset import RftDataset
from trinity.data.core.formatter import (
BoxedMathAnswerFormatter,
Expand All @@ -18,15 +18,14 @@ class TestBoxedMathDataset(unittest.TestCase):
"""Test cases for RftDataset"""

def setUp(self) -> None:
self.data_config = DataConfig(
dataset_path=os.path.join(
self.data_config = DataProcessorConfig(
raw_data_path=os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"..",
"..",
"test_data",
"test_10",
),
dataset_config={"split": "train"},
format_config=FormatConfig(
prompt_key="problem",
response_key="answer",
Expand Down Expand Up @@ -60,15 +59,14 @@ class TestRLHFFormatter(unittest.TestCase):
"""Test cases for RLHFFormatter"""

def setUp(self) -> None:
self.data_config = DataConfig(
dataset_path=os.path.join(
self.data_config = DataProcessorConfig(
raw_data_path=os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"..",
"..",
"test_data",
"test_10",
),
dataset_config={"split": "train"},
format_config=FormatConfig(
prompt_key="problem",
chat_template="User: {}\nAssistant: ",
Expand Down Expand Up @@ -109,15 +107,14 @@ class TestRewardFormatter(unittest.TestCase):
"""Test cases for RewardFormatter"""

def setUp(self) -> None:
self.data_config = DataConfig(
dataset_path=os.path.join(
self.data_config = DataProcessorConfig(
raw_data_path=os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"..",
"..",
"test_data",
"test_10",
),
dataset_config={"split": "train"},
format_config=FormatConfig(
prompt_key="problem",
chosen_key="chosen",
Expand Down Expand Up @@ -167,15 +164,14 @@ class TestSFTFormatter(unittest.TestCase):
"""Test cases for SFTFormatter"""

def setUp(self) -> None:
self.data_config = DataConfig(
dataset_path=os.path.join(
self.data_config = DataProcessorConfig(
raw_data_path=os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"..",
"..",
"test_data",
"test_10",
),
dataset_config={"split": "train"},
format_config=FormatConfig(
prompt_key="problem",
response_key="answer",
Expand Down Expand Up @@ -221,15 +217,14 @@ class TestComposedFormatter(unittest.TestCase):
"""Test cases for ComposedFormatter"""

def setUp(self) -> None:
self.data_config = DataConfig(
dataset_path=os.path.join(
self.data_config = DataProcessorConfig(
raw_data_path=os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"..",
"..",
"test_data",
"test_10",
),
dataset_config={"split": "train"},
format_config=FormatConfig(
prompt_key="problem",
response_key="answer",
Expand Down
3 changes: 1 addition & 2 deletions tests/explorer/explorer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ def setUp(self):
self.config.monitor.project = "Trinity-unittest"
self.config.model.checkpoint_path = get_checkpoint_path()
self.config.synchronizer.sync_interval = 2
self.config.explorer.eval_interval = 4
self.config.trainer.eval_interval = 4
self.config.global_config.eval_interval = 4

@abstractmethod
def test_explorer(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/explorer/runner_pool_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch

from trinity.buffer.reader.queue_reader import QueueReader
from trinity.common.config import DatasetConfig, load_config
from trinity.common.config import StorageConfig, load_config
from trinity.common.constants import AlgorithmType, StorageType
from trinity.common.experience import Experience
from trinity.common.models.model import InferenceModel
Expand Down Expand Up @@ -68,7 +68,7 @@ def setUp(self):
self.config.explorer.max_timeout = 5
self.config.buffer.read_batch_size = 2
self.config.buffer.pad_token_id = 0
self.config.buffer.train_dataset = DatasetConfig(
self.config.buffer.train_dataset = StorageConfig(
name="test",
namespace="test_runner_pool",
storage_type=StorageType.QUEUE,
Expand Down
14 changes: 4 additions & 10 deletions tests/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import ray
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

from trinity.common.config import Config, DataConfig, FormatConfig, load_config
from trinity.common.config import Config, DataProcessorConfig, FormatConfig, load_config


def get_template_config() -> Config:
Expand All @@ -32,17 +32,11 @@ def get_checkpoint_path() -> str:
return path


def get_unittest_dataset_config(dataset_name: str = "countdown") -> DataConfig:
def get_unittest_dataset_config(dataset_name: str = "countdown") -> DataProcessorConfig:
"""Countdown sample dataset for 8 steps"""
if dataset_name == "countdown":
return DataConfig(
total_epochs=2,
batch_size=4,
default_workflow_type="math_workflow",
default_reward_fn_type="countdown_reward",
dataset_path=os.path.join(os.path.dirname(__file__), "template", "data", "countdown"),
train_split="train",
eval_split="test",
return DataProcessorConfig(
raw_data_path=os.path.join(os.path.dirname(__file__), "template", "data", "countdown"),
format_config=FormatConfig(
prompt_key="question",
response_key="answer",
Expand Down
3 changes: 1 addition & 2 deletions tests/trainer/trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ def setUp(self):
)
self.config.synchronizer.sync_interval = 2
self.config.synchronizer.sync_method = SyncMethod.NCCL
self.config.explorer.eval_interval = 4
self.config.trainer.eval_interval = 4
self.config.global_config.eval_interval = 4

@abstractmethod
def test_trainer(self):
Expand Down
16 changes: 8 additions & 8 deletions trinity/buffer/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from trinity.buffer.buffer_reader import BufferReader
from trinity.buffer.buffer_writer import BufferWriter
from trinity.common.config import BufferConfig, Config, DatasetConfig
from trinity.common.config import BufferConfig, Config, StorageConfig
from trinity.common.constants import StorageType


Expand All @@ -13,22 +13,22 @@ class Buffer:
"""Responsible for storing experiences."""

def __init__(self, config: Config):
self.buffer_mapping: dict[str, DatasetConfig] = {}
self.buffer_mapping: dict[str, StorageConfig] = {}
self._register_from_config(config)

def get_dataset_info(self, dataset_name: str) -> DatasetConfig:
def get_dataset_info(self, dataset_name: str) -> StorageConfig:
dataset_config = self.buffer_mapping.get(dataset_name, None)
if dataset_config is None:
raise ValueError(f"{dataset_name} not found.")
return dataset_config

def register_dataset(self, dataset_config: DatasetConfig) -> None:
def register_dataset(self, dataset_config: StorageConfig) -> None:
if dataset_config.name in self.buffer_mapping:
raise ValueError(f"{dataset_config.name} already exists.")
self.buffer_mapping[dataset_config.name] = dataset_config


def get_buffer_reader(dataset_config: DatasetConfig, buffer_config: BufferConfig) -> BufferReader:
def get_buffer_reader(dataset_config: StorageConfig, buffer_config: BufferConfig) -> BufferReader:
"""Get a buffer reader for the given dataset name."""
if dataset_config.storage_type == StorageType.SQL:
from trinity.buffer.reader.sql_reader import SQLReader
Expand All @@ -39,14 +39,14 @@ def get_buffer_reader(dataset_config: DatasetConfig, buffer_config: BufferConfig

return QueueReader(dataset_config, buffer_config)
elif dataset_config.storage_type == StorageType.FILE:
from trinity.buffer.reader.file_reader import FileReader
from trinity.buffer.reader.file_reader import FileReaderManager

return FileReader(dataset_config, buffer_config)
return FileReaderManager.create_reader(dataset_config, buffer_config)
else:
raise ValueError(f"{dataset_config.storage_type} not supported.")


def get_buffer_writer(dataset_config: DatasetConfig, buffer_config: BufferConfig) -> BufferWriter:
def get_buffer_writer(dataset_config: StorageConfig, buffer_config: BufferConfig) -> BufferWriter:
"""Get a buffer writer for the given dataset name."""
if dataset_config.storage_type == StorageType.SQL:
from trinity.buffer.writer.sql_writer import SQLWriter
Expand Down
4 changes: 2 additions & 2 deletions trinity/buffer/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import ray

from trinity.buffer.writer.sql_writer import SQLWriter
from trinity.common.config import BufferConfig, DatasetConfig
from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.constants import StorageType


Expand All @@ -16,7 +16,7 @@ class QueueActor:

FINISH_MESSAGE = "$FINISH$"

def __init__(self, dataset_config: DatasetConfig, config: BufferConfig) -> None:
def __init__(self, dataset_config: StorageConfig, config: BufferConfig) -> None:
self.config = config
self.capacity = getattr(config, "capacity", 10000)
self.queue = asyncio.Queue(self.capacity)
Expand Down
Loading