22
33from typing import List , Optional
44
5+ import datasets
56import transformers
67from datasets import load_dataset
78
89from trinity .buffer .buffer_reader import BufferReader
910from trinity .common .config import BufferConfig , DatasetConfig
10- from trinity .common .constants import (
11- AlgorithmType ,
12- PromptType ,
13- ReadStrategy ,
14- StorageType ,
15- )
11+ from trinity .common .constants import AlgorithmType , PromptType , ReadStrategy , TaskType
1612from trinity .common .experience import Experience
13+ from trinity .common .rewards import REWARD_FUNCTIONS
14+ from trinity .common .task import Task
15+ from trinity .common .workflows import WORKFLOWS
1716
1817
class FileReaderManager:
    """Registry and factory for file-based buffer readers.

    Reader classes register themselves for an ``AlgorithmType`` via the
    :meth:`register_subclass` decorator; :meth:`create_reader` then
    instantiates the matching reader and enforces the shared
    read-strategy check on its ``read`` method.
    """

    # Maps AlgorithmType -> registered BufferReader subclass.
    subclasses: dict = {}

    @classmethod
    def register_subclass(cls, algorithm_type: AlgorithmType):
        """Return a class decorator registering a reader for ``algorithm_type``.

        First registration wins: re-registering an already-known type is a
        no-op, and the decorated class is returned unchanged either way.
        """

        def decorator(_cls):
            if algorithm_type not in cls.subclasses:
                cls.subclasses[algorithm_type] = _cls
            return _cls

        return decorator

    @classmethod
    def create_reader(cls, meta: DatasetConfig, config: BufferConfig) -> BufferReader:
        """Instantiate the reader registered for ``meta.algorithm_type``.

        The reader's ``read`` method is wrapped (at most once) so that any
        strategy other than FIFO is rejected.

        Raises:
            ValueError: if no reader is registered for the algorithm type,
                or, at read time, if an unsupported strategy is requested.
        """

        def add_read_check(read_func):
            def wrapper(self, strategy: Optional[ReadStrategy] = None, *args, **kwargs):
                if strategy is not None and strategy != ReadStrategy.FIFO:
                    raise ValueError(f"Unsupported read strategy: {strategy}")
                return read_func(self, strategy, *args, **kwargs)

            # Marker so repeated create_reader calls can detect the wrapper.
            wrapper.__read_checked__ = True
            return wrapper

        try:
            reader_cls = cls.subclasses[meta.algorithm_type]
        except KeyError:
            # A bare KeyError is unhelpful here; raise the explicit error
            # the pre-registry implementation used.
            raise ValueError(f"Unsupported algorithm type: {meta.algorithm_type}") from None
        # Bug fix: the original re-wrapped `read` on every call, stacking one
        # extra check wrapper per created reader instance.  Wrap at most once.
        if not getattr(reader_cls.read, "__read_checked__", False):
            reader_cls.read = add_read_check(reader_cls.read)
        return reader_cls(meta, config)
3743
3844
39- class SFTDataReader :
45+ @FileReaderManager .register_subclass (AlgorithmType .SFT )
46+ class SFTDataReader (BufferReader ):
4047 """Reader for SFT file data."""
4148
4249 def __init__ (self , meta : DatasetConfig , config : BufferConfig ):
@@ -46,11 +53,11 @@ def __init__(self, meta: DatasetConfig, config: BufferConfig):
4653 self .prompt_key = meta .kwargs .get ("prompt_key" , "prompt" )
4754 self .response_key = meta .kwargs .get ("response_key" , "response" )
4855 self .read_batch_size = config .read_batch_size
49- self .dataset = load_dataset (meta .path )[self .train_split ]
56+ self .dataset = load_dataset (meta .path )[self .train_split ] # TODO: support resume
5057 self .data_iter = self .dataset .iter (self .read_batch_size , drop_last_batch = True )
5158 self .tokenizer = transformers .AutoTokenizer .from_pretrained (config .tokenizer_path )
5259
53- def read (self ) -> List :
60+ def read (self , strategy : Optional [ ReadStrategy ] = None ) -> List :
5461 try :
5562 batch_data = next (self .data_iter )
5663 except StopIteration :
@@ -111,15 +118,16 @@ def read(self) -> List:
111118 return exp_list
112119
113120
114- class DPODataReader :
121+ @FileReaderManager .register_subclass (AlgorithmType .DPO )
122+ class DPODataReader (BufferReader ):
115123 def __init__ (self , meta : DatasetConfig , config : BufferConfig ):
116124 self .train_split = meta .kwargs .get ("train_split" , "train" )
117125 self .prompt_type = PromptType (meta .kwargs .get ("prompt_type" , "messages" ))
118126 self .prompt_key = meta .kwargs .get ("prompt_key" , "prompt" )
119127 self .chosen_key = meta .kwargs .get ("chosen_key" , "chosen" )
120128 self .rejected_key = meta .kwargs .get ("rejected_key" , "rejected" )
121129 self .read_batch_size = config .read_batch_size
122- self .dataset = load_dataset (meta .path )[self .train_split ]
130+ self .dataset = load_dataset (meta .path )[self .train_split ] # TODO: support resume
123131 self .data_iter = self .dataset .iter (self .read_batch_size , drop_last_batch = True )
124132 self .tokenizer = transformers .AutoTokenizer .from_pretrained (config .tokenizer_path )
125133
@@ -131,7 +139,7 @@ def _get_assistant_message(self, item) -> dict:
131139 else :
132140 return item
133141
134- def read (self ) -> List :
142+ def read (self , strategy : Optional [ ReadStrategy ] = None ) -> List :
135143 try :
136144 batch_data = next (self .data_iter )
137145 except StopIteration :
@@ -178,3 +186,59 @@ def read(self) -> List:
178186 )
179187 exp_list .append (experience )
180188 return exp_list
189+
190+
@FileReaderManager.register_subclass(AlgorithmType.ROLLOUT)
class RolloutDataReader(BufferReader):
    """Reader that turns rows of a rollout dataset file into `Task` objects.

    Unlike the SFT/DPO readers, `read` returns one `Task` per call and
    cycles over the dataset indefinitely (wrap-around by index modulo).
    """

    def __init__(self, meta: DatasetConfig, config: BufferConfig):
        self.split = meta.kwargs.get("split", "train")
        name = meta.kwargs.get("name", None)
        # Disable datasets caching to avoid reusing an old-version dataset.
        datasets.disable_caching()
        try:
            self.dataset = load_dataset(
                meta.path, name=name, split=self.split
            )  # TODO: may from db_url
            # TODO: optionally load from database (config.db_url) for
            # non-EVAL task types instead of a file path.
        finally:
            # Bug fix: re-enable caching even if load_dataset raises; the
            # original leaked the globally-disabled caching state on error.
            datasets.enable_caching()
        if len(self.dataset) == 0:
            # Fail fast: an empty dataset would otherwise surface later as a
            # confusing ZeroDivisionError in read().
            raise ValueError(f"Empty dataset: path={meta.path}, split={self.split}")
        self.index = meta.kwargs.get("index", 0)  # TODO: apply shuffle

        # Column names used to extract task fields from each sample row.
        self.prompt_key = meta.format_config.prompt_key
        self.response_key = meta.format_config.response_key
        self.workflow_key = meta.format_config.workflow_key
        self.reward_fn_key = meta.format_config.reward_fn_key

        self.task_type = meta.kwargs.get("task_type", TaskType.EXPLORE)
        # Fallbacks used when a row does not carry its own workflow/reward-fn.
        self.default_workflow_cls = WORKFLOWS.get(meta.kwargs.get("default_workflow_type", None))
        self.default_reward_fn_cls = REWARD_FUNCTIONS.get(
            meta.kwargs.get("default_reward_fn_type", None)
        )
        # Eval tasks always run a single epoch; explore tasks may loop more.
        # NOTE(review): total_epochs is stored but not yet enforced by read().
        self.total_epochs = (
            meta.kwargs.get("total_epochs", 1) if self.task_type == TaskType.EXPLORE else 1
        )

    def read(self, strategy: Optional[ReadStrategy] = None):
        """Build and return the `Task` for the current row, then advance.

        `strategy` is accepted for interface compatibility; the strategy
        check itself is applied by `FileReaderManager.create_reader`.
        """
        sample = self.dataset[self.index % len(self.dataset)]
        task_desc = sample[self.prompt_key] if self.prompt_key in sample else None
        truth = sample[self.response_key] if self.response_key in sample else None
        # Per-row workflow/reward-fn override the configured defaults.
        workflow_class = (
            WORKFLOWS.get(sample[self.workflow_key])
            if self.workflow_key in sample
            else self.default_workflow_cls
        )
        reward_fn = (
            REWARD_FUNCTIONS.get(sample[self.reward_fn_key])
            if self.reward_fn_key in sample
            else self.default_reward_fn_cls
        )
        task = Task(
            task_desc=task_desc,
            truth=truth,
            workflow=workflow_class,
            reward_fn=reward_fn,
            raw=sample,
            task_type=self.task_type,
        )
        self.index += 1
        return task