11import os
2- import traceback
32import threading
3+ import traceback
4+ from functools import partial
45from numbers import Number
56from typing import Any , Dict , List , Union
6- from functools import partial
7- from data_juicer .utils .constant import Fields
87
98import ray
9+ from data_juicer .utils .constant import Fields
1010
1111from trinity .common .config import BufferConfig , DataPipelineConfig , RewardShapingConfig
1212from trinity .common .constants import DataProcessorPipelineType , OpType
@@ -102,7 +102,7 @@ def __init__(
102102 def run (self , thread_event : threading .Event = None ):
103103 """Run the active iterator."""
104104 # step 1. parse the dj config
105- logger .info (' Parsing the Data-Juicer config...' )
105+ logger .info (" Parsing the Data-Juicer config..." )
106106 try :
107107 (
108108 dj_config ,
@@ -115,15 +115,15 @@ def run(self, thread_event: threading.Event = None):
115115 return 1 , "config parsing failed."
116116
117117 # step 2. prepare rft-dataset from the input buffers
118- logger .info (' Preparing Rft-Dataset from input buffers...' )
118+ logger .info (" Preparing Rft-Dataset from input buffers..." )
119119 try :
120120 dataset = RftDataset (self .config , self .buffer_config )
121121 except Exception :
122122 traceback .print_exc ()
123123 return 2 , "RftDataset loading failed."
124124
125125 # step 3. load processor
126- logger .info (' Loading data processors...' )
126+ logger .info (" Loading data processors..." )
127127 try :
128128 if hit_cleaner :
129129 cleaner = DataCleaner (
@@ -151,7 +151,7 @@ def run(self, thread_event: threading.Event = None):
151151 break
152152
153153 # step 4. load data from the input buffers for the next batch
154- logger .info (' Loading data from input buffers for the next batch...' )
154+ logger .info (" Loading data from input buffers for the next batch..." )
155155 try :
156156 dataset .read_from_buffer ()
157157 except StopIteration :
@@ -161,7 +161,7 @@ def run(self, thread_event: threading.Event = None):
161161 return 4 , "RftDataset loading from buffers failed."
162162
163163 # step 5. apply processors to calculate scores of different dimensions
164- logger .info (' Applying data processors to calculate stats...' )
164+ logger .info (" Applying data processors to calculate stats..." )
165165 try :
166166 res_dataset = dataset
167167 if hit_cleaner :
@@ -177,7 +177,7 @@ def run(self, thread_event: threading.Event = None):
177177 # step 6. calculate the average and final scores, including priority
178178 try :
179179 if hit_cleaner :
180- logger .info (' Calculating the average and final scores...' )
180+ logger .info (" Calculating the average and final scores..." )
181181 scored_dataset = self ._group_scores (res_dataset )
182182 scored_dataset = self ._compute_priority_scores (scored_dataset )
183183 else :
@@ -188,7 +188,11 @@ def run(self, thread_event: threading.Event = None):
188188
189189 # step 7. reward shaping. Only available for experience pipeline and the reward shaping config is set
190190 try :
191- if self .pipeline_type == DataProcessorPipelineType .EXPERIENCE and len (self .config .reward_shaping ) > 0 :
191+ if (
192+ self .pipeline_type == DataProcessorPipelineType .EXPERIENCE
193+ and self .config .reward_shaping is not None
194+ and len (self .config .reward_shaping ) > 0
195+ ):
192196 logger .info ("Rewarding shaping..." )
193197 reshaped_dataset = self ._reward_shaping (scored_dataset )
194198 else :
@@ -215,7 +219,7 @@ def run(self, thread_event: threading.Event = None):
215219
216220 # step 10. export the result to the output buffer
217221 try :
218- logger .info (' Writing processed data to output buffer...' )
222+ logger .info (" Writing processed data to output buffer..." )
219223 res_dataset .write_to_buffer ()
220224 except Exception :
221225 traceback .print_exc ()
@@ -325,13 +329,21 @@ def _reward_shaping_single(self, sample, reward_shaping_config: RewardShapingCon
325329 if tgt_stats not in sample [Fields .stats ]:
326330 return sample
327331 if op_type == OpType .ADD :
328- sample [self .config .format .reward_key ] += reward_shaping_config .weight * sample [Fields .stats ][tgt_stats ]
332+ sample [self .config .format .reward_key ] += (
333+ reward_shaping_config .weight * sample [Fields .stats ][tgt_stats ]
334+ )
329335 elif op_type == OpType .MUL :
330- sample [self .config .format .reward_key ] *= reward_shaping_config .weight * sample [Fields .stats ][tgt_stats ]
336+ sample [self .config .format .reward_key ] *= (
337+ reward_shaping_config .weight * sample [Fields .stats ][tgt_stats ]
338+ )
331339 elif op_type == OpType .SUB :
332- sample [self .config .format .reward_key ] -= reward_shaping_config .weight * sample [Fields .stats ][tgt_stats ]
340+ sample [self .config .format .reward_key ] -= (
341+ reward_shaping_config .weight * sample [Fields .stats ][tgt_stats ]
342+ )
333343 elif op_type == OpType .DIV :
334- sample [self .config .format .reward_key ] /= reward_shaping_config .weight * sample [Fields .stats ][tgt_stats ]
344+ sample [self .config .format .reward_key ] /= (
345+ reward_shaping_config .weight * sample [Fields .stats ][tgt_stats ]
346+ )
335347 return sample
336348
337349 def _reward_shaping (self , rft_dataset : RftDataset ) -> RftDataset :
@@ -342,7 +354,9 @@ def _reward_shaping(self, rft_dataset: RftDataset) -> RftDataset:
342354 # get reward shaping configs
343355 reward_shaping_configs = self .config .reward_shaping
344356 for reward_shaping_config in reward_shaping_configs :
345- dataset = dataset .map (partial (self ._reward_shaping_single , reward_shaping_config = reward_shaping_config ))
357+ dataset = dataset .map (
358+ partial (self ._reward_shaping_single , reward_shaping_config = reward_shaping_config )
359+ )
346360
347361 rft_dataset .data = dataset
348362 return rft_dataset
0 commit comments