 import os
 import traceback
 from numbers import Number
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional, Union
+from functools import partial
+from data_juicer.utils.constant import Fields

 import ray

-from trinity.common.config import BufferConfig, DataPipelineConfig
+from trinity.common.config import BufferConfig, DataPipelineConfig, RewardShapingConfig
+from trinity.common.constants import DataProcessorPipelineType, OpType
 from trinity.data.controllers.default_ops import DIMENSION_STATS_KEYS
 from trinity.data.controllers.task_parser import DataTaskParser
 from trinity.data.core.dataset import RftDataset
@@ -23,9 +26,24 @@ def __init__(
         self,
         config: DataPipelineConfig,
         buffer_config: BufferConfig,
+        pipeline_type: Optional[Union[DataProcessorPipelineType, str]] = DataProcessorPipelineType.TASK,
     ):
+        """
+        The initialization method.
+
+        :param config: the data pipeline config.
+        :param buffer_config: the buffer config.
+        :param pipeline_type: the type of the activated pipeline.
+        """
         self.config = config
         self.buffer_config = buffer_config
+        self.pipeline_type = pipeline_type
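+        # normalize the pipeline type: default to TASK when unset and accept string values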
+        if self.pipeline_type is None:
+            self.pipeline_type = DataProcessorPipelineType.TASK
+        if isinstance(self.pipeline_type, str):
+            self.pipeline_type = DataProcessorPipelineType(pipeline_type)
+
+        # check if the llm agent is required
         if self.config.agent_model_name is not None and self.config.agent_model_config is not None:
             # get the api key
             api_key = os.environ.get("OPENAI_API_KEY")
@@ -42,6 +60,8 @@ def __init__(
             )
         else:
             self.llm_agent = None
+
+        # init task parser
         self.task_parser = DataTaskParser(config, self.llm_agent)

         # Priority weights
@@ -153,34 +173,42 @@ def run(self):
             traceback.print_exc()
             return 6, "Grouping and computing priority score failed."

-        # step 7. track lineage if they are changed
+        # step 7. reward shaping, only available for the experience pipeline when the reward shaping config is set
+        try:
+            if self.pipeline_type == DataProcessorPipelineType.EXPERIENCE and len(self.config.reward_shaping) > 0:
+                reshaped_dataset = self._reward_shaping(scored_dataset)
+            else:
+                reshaped_dataset = scored_dataset
+        except Exception:
+            traceback.print_exc()
+            return 7, "Reward shaping failed."
+
+        # step 8. track lineage if they are changed
         try:
-            res_dataset = scored_dataset
+            res_dataset = reshaped_dataset
         except Exception:
             traceback.print_exc()
-            return 7, "Tracking lineage failed."
+            return 8, "Tracking lineage failed."

-        # step 8
+        # step 9. sort the dataset by the computed priority
         try:
             if "priority" in res_dataset.data.features:
                 res_dataset.sort_by("priority", reverse=True)
         except Exception:
             traceback.print_exc()
-            return 8, "Sorting results by priority failed."
+            return 9, "Sorting results by priority failed."

-        # step 9. sort and export the result to the output buffer
+        # step 10. export the result to the output buffer
         try:
             res_dataset.write_to_buffer()
         except Exception:
             traceback.print_exc()
-            return 9, "Exporting result to output buffer failed."
+            return 10, "Exporting result to output buffer failed."

         return 0, "success"

     def _group_scores(self, dataset: RftDataset) -> RftDataset:
         # for perplexity, normalize them with the max value.
-        from data_juicer.utils.constant import Fields
-
         stats_min_max = {}
         for stats in dataset.data.features[Fields.stats]:
             all_stats = [
@@ -268,6 +296,35 @@ def _compute_priority_scores(self, dataset: RftDataset) -> RftDataset:
         dataset.data = dataset.data.map(self._compute_combined_score)
         return dataset

+    def _reward_shaping_single(self, sample, reward_shaping_config: RewardShapingConfig):
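+        """Shape the reward of a single sample according to one reward shaping config."""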
+        tgt_stats = reward_shaping_config.stats_key
+        op_type = reward_shaping_config.op_type
+        # if the target stats does not exist, skip this config and return the original sample
+        if tgt_stats not in sample[Fields.stats]:
+            return sample
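+        # apply the configured arithmetic op to the reward: reward <op>= weight * stats[tgt_stats]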
+        if op_type == OpType.ADD:
+            sample[self.config.format.reward_key] += reward_shaping_config.weight * sample[Fields.stats][tgt_stats]
+        elif op_type == OpType.MUL:
+            sample[self.config.format.reward_key] *= reward_shaping_config.weight * sample[Fields.stats][tgt_stats]
+        elif op_type == OpType.SUB:
+            sample[self.config.format.reward_key] -= reward_shaping_config.weight * sample[Fields.stats][tgt_stats]
+        elif op_type == OpType.DIV:
+            sample[self.config.format.reward_key] /= reward_shaping_config.weight * sample[Fields.stats][tgt_stats]
+        return sample
+
+    def _reward_shaping(self, rft_dataset: RftDataset) -> RftDataset:
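+        """Apply all configured reward shaping ops to the rewards of the given RftDataset."""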
+        dataset = rft_dataset.data
+        # check if there is a reward column in the dataset. If not, skip!
+        if self.config.format.reward_key not in dataset.features:
+            return rft_dataset
+        # get reward shaping configs
+        reward_shaping_configs = self.config.reward_shaping
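+        # apply each reward shaping op to every sample in the dataset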
+        for reward_shaping_config in reward_shaping_configs:
+            dataset = dataset.map(partial(self._reward_shaping_single, reward_shaping_config=reward_shaping_config))
+
+        rft_dataset.data = dataset
+        return rft_dataset
+
     @ray.method(num_returns=1)
     def select_batch(self, dataset: RftDataset, batch_size: int) -> List[Dict[str, Any]]:
         """Select a batch of samples for training"""