Skip to content

Commit 4f5fd6a

Browse files
committed
Refactor on TaskSet
1 parent a182b1f commit 4f5fd6a

File tree

7 files changed

+299
-210
lines changed

7 files changed

+299
-210
lines changed

trinity/buffer/buffer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@ def get_buffer_reader(dataset_config: DatasetConfig, buffer_config: BufferConfig
3939

4040
return QueueReader(dataset_config, buffer_config)
4141
elif dataset_config.storage_type == StorageType.FILE:
42-
from trinity.buffer.reader.file_reader import FileReader
42+
from trinity.buffer.reader.file_reader import FileReaderManager
4343

44-
return FileReader(dataset_config, buffer_config)
44+
return FileReaderManager.create_reader(dataset_config, buffer_config)
4545
else:
4646
raise ValueError(f"{dataset_config.storage_type} not supported.")
4747

trinity/buffer/reader/file_reader.py

Lines changed: 92 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,41 +2,48 @@
22

33
from typing import List, Optional
44

5+
import datasets
56
import transformers
67
from datasets import load_dataset
78

89
from trinity.buffer.buffer_reader import BufferReader
910
from trinity.common.config import BufferConfig, DatasetConfig
10-
from trinity.common.constants import (
11-
AlgorithmType,
12-
PromptType,
13-
ReadStrategy,
14-
StorageType,
15-
)
11+
from trinity.common.constants import AlgorithmType, PromptType, ReadStrategy, TaskType
1612
from trinity.common.experience import Experience
13+
from trinity.common.rewards import REWARD_FUNCTIONS
14+
from trinity.common.task import Task
15+
from trinity.common.workflows import WORKFLOWS
1716

1817

19-
class FileReader(BufferReader):
20-
"""Reader of the File buffer."""
18+
class FileReaderManager:
19+
subclasses: dict = {}
2120

22-
def __init__(self, meta: DatasetConfig, config: BufferConfig) -> None:
23-
assert meta.storage_type == StorageType.FILE
24-
if meta.algorithm_type == AlgorithmType.SFT:
25-
self.reader = SFTDataReader(meta, config)
26-
elif meta.algorithm_type == AlgorithmType.DPO:
27-
self.reader = DPODataReader(meta, config)
28-
else:
29-
# TODO: support read rollout task
30-
raise ValueError(f"Unsupported algorithm type: {meta.algorithm_type}")
21+
@classmethod
22+
def register_subclass(cls, algorithm_type: AlgorithmType):
23+
def decorator(_cls):
24+
if algorithm_type not in cls.subclasses:
25+
cls.subclasses[algorithm_type] = _cls
26+
return _cls
3127

32-
def read(self, strategy: Optional[ReadStrategy] = None) -> List:
33-
"""Read data from the buffer."""
34-
if strategy is not None and strategy != ReadStrategy.FIFO:
35-
raise ValueError(f"Unsupported read strategy: {strategy}")
36-
return self.reader.read()
28+
return decorator
29+
30+
@classmethod
31+
def create_reader(cls, meta: DatasetConfig, config: BufferConfig) -> BufferReader:
32+
def add_read_check(read_func):
33+
def wrapper(self, strategy: Optional[ReadStrategy] = None, *args, **kwargs):
34+
if strategy is not None and strategy != ReadStrategy.FIFO:
35+
raise ValueError(f"Unsupported read strategy: {strategy}")
36+
return read_func(self, strategy, *args, **kwargs)
37+
38+
return wrapper
39+
40+
subclasses = cls.subclasses[meta.algorithm_type]
41+
subclasses.read = add_read_check(subclasses.read)
42+
return subclasses(meta, config)
3743

3844

39-
class SFTDataReader:
45+
@FileReaderManager.register_subclass(AlgorithmType.SFT)
46+
class SFTDataReader(BufferReader):
4047
"""Reader for SFT file data."""
4148

4249
def __init__(self, meta: DatasetConfig, config: BufferConfig):
@@ -46,11 +53,11 @@ def __init__(self, meta: DatasetConfig, config: BufferConfig):
4653
self.prompt_key = meta.kwargs.get("prompt_key", "prompt")
4754
self.response_key = meta.kwargs.get("response_key", "response")
4855
self.read_batch_size = config.read_batch_size
49-
self.dataset = load_dataset(meta.path)[self.train_split]
56+
self.dataset = load_dataset(meta.path)[self.train_split] # TODO: support resume
5057
self.data_iter = self.dataset.iter(self.read_batch_size, drop_last_batch=True)
5158
self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path)
5259

53-
def read(self) -> List:
60+
def read(self, strategy: Optional[ReadStrategy] = None) -> List:
5461
try:
5562
batch_data = next(self.data_iter)
5663
except StopIteration:
@@ -111,15 +118,16 @@ def read(self) -> List:
111118
return exp_list
112119

113120

114-
class DPODataReader:
121+
@FileReaderManager.register_subclass(AlgorithmType.DPO)
122+
class DPODataReader(BufferReader):
115123
def __init__(self, meta: DatasetConfig, config: BufferConfig):
116124
self.train_split = meta.kwargs.get("train_split", "train")
117125
self.prompt_type = PromptType(meta.kwargs.get("prompt_type", "messages"))
118126
self.prompt_key = meta.kwargs.get("prompt_key", "prompt")
119127
self.chosen_key = meta.kwargs.get("chosen_key", "chosen")
120128
self.rejected_key = meta.kwargs.get("rejected_key", "rejected")
121129
self.read_batch_size = config.read_batch_size
122-
self.dataset = load_dataset(meta.path)[self.train_split]
130+
self.dataset = load_dataset(meta.path)[self.train_split] # TODO: support resume
123131
self.data_iter = self.dataset.iter(self.read_batch_size, drop_last_batch=True)
124132
self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path)
125133

@@ -131,7 +139,7 @@ def _get_assistant_message(self, item) -> dict:
131139
else:
132140
return item
133141

134-
def read(self) -> List:
142+
def read(self, strategy: Optional[ReadStrategy] = None) -> List:
135143
try:
136144
batch_data = next(self.data_iter)
137145
except StopIteration:
@@ -178,3 +186,59 @@ def read(self) -> List:
178186
)
179187
exp_list.append(experience)
180188
return exp_list
189+
190+
191+
@FileReaderManager.register_subclass(AlgorithmType.ROLLOUT)
192+
class RolloutDataReader(BufferReader):
193+
def __init__(self, meta: DatasetConfig, config: BufferConfig):
194+
self.split = meta.kwargs.get("split", "train")
195+
name = meta.kwargs.get("name", None)
196+
# disable datasets caching to avoid reuse old-version dataset
197+
datasets.disable_caching()
198+
self.dataset = load_dataset(meta.path, name=name, split=self.split) # TODO: may from db_url
199+
# if task_type != TaskType.EVAL and config.db_url != "":
200+
# logger.info(f"Loading dataset from database with url: {config.db_url}")
201+
# db_type = config.db_url.split(":")[0]
202+
# db_name = config.db_url.split("/")[-1]
203+
# dataset = Dataset.from_sql(RftDatasetModel.__tablename__, f"{db_type}:///{db_name}")
204+
datasets.enable_caching()
205+
self.index = meta.kwargs.get("index", 0) # TODO: apply shuffle
206+
207+
self.prompt_key = meta.format_config.prompt_key
208+
self.response_key = meta.format_config.response_key
209+
self.workflow_key = meta.format_config.workflow_key
210+
self.reward_fn_key = meta.format_config.reward_fn_key
211+
212+
self.task_type = meta.kwargs.get("task_type", TaskType.EXPLORE)
213+
self.default_workflow_cls = WORKFLOWS.get(meta.kwargs.get("default_workflow_type", None))
214+
self.default_reward_fn_cls = REWARD_FUNCTIONS.get(
215+
meta.kwargs.get("default_reward_fn_type", None)
216+
)
217+
self.total_epochs = (
218+
meta.kwargs.get("total_epochs", 1) if self.task_type == TaskType.EXPLORE else 1
219+
)
220+
221+
def read(self, strategy: Optional[ReadStrategy] = None):
222+
sample = self.dataset[self.index % len(self.dataset)]
223+
task_desc = sample[self.prompt_key] if self.prompt_key in sample else None
224+
truth = sample[self.response_key] if self.response_key in sample else None
225+
workflow_class = (
226+
WORKFLOWS.get(sample[self.workflow_key])
227+
if self.workflow_key in sample
228+
else self.default_workflow_cls
229+
)
230+
reward_fn = (
231+
REWARD_FUNCTIONS.get(sample[self.reward_fn_key])
232+
if self.reward_fn_key in sample
233+
else self.default_reward_fn_cls
234+
)
235+
task = Task(
236+
task_desc=task_desc,
237+
truth=truth,
238+
workflow=workflow_class,
239+
reward_fn=reward_fn,
240+
raw=sample,
241+
task_type=self.task_type,
242+
)
243+
self.index += 1
244+
return task

trinity/common/config.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
class FormatConfig:
2323
"""Configuration for data formatting"""
2424

25-
prompt_key: str = ""
26-
response_key: str = ""
25+
prompt_key: str = "prompt"
26+
response_key: str = "response"
2727
chat_template: str = ""
2828

2929
# for sample-level task controlling
@@ -36,8 +36,8 @@ class FormatConfig:
3636
reward_key: str = ""
3737

3838
# for dpo dataset
39-
chosen_key: str = ""
40-
rejected_key: str = ""
39+
chosen_key: str = "chosen"
40+
rejected_key: str = "rejected"
4141

4242
# for unpaired preference dataset
4343
label_key: str = ""
@@ -110,6 +110,7 @@ class DatasetConfig:
110110
algorithm_type: AlgorithmType = AlgorithmType.PPO
111111
path: Optional[str] = None
112112
namespace: str = "" # automatically generated
113+
format_config: FormatConfig = field(default_factory=FormatConfig)
113114
kwargs: Dict[str, Any] = field(default_factory=dict)
114115

115116

0 commit comments

Comments
 (0)