+import copy
 import os
 import time
 import uuid
@@ -53,7 +54,7 @@ def __getstate__(self):
     def __setstate__(self, state):
         self.__dict__.update(state)
 
-    def _initialize_spark(self):
+    def initialize_spark(self):
         """Initialize Spark session if not already provided."""
         if self.spark_session is not None:
             return self.spark_session, self.spark_session.sparkContext
@@ -63,11 +64,18 @@ def _initialize_spark(self):
         else:
             raise ValueError('Both spark_session and spark_conf are None. Please provide one.')
 
+    def cleanup(self, spark):
+        """Clean up Spark resources."""
+        if spark:
+            spark.stop()
+            if spark.sparkContext:
+                spark.sparkContext.stop()
+
     def load_data(self) -> RDD:
         """Load and return the RDD data."""
         return self.spark_rdd
 
-    def execute(self) -> List[SummaryModel]:
+    def execute(self) -> SummaryModel:
         """Main execution method for Spark evaluation."""
         create_time = time.strftime('%Y%m%d_%H%M%S', time.localtime())
 
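Worth noting for reviewers: `SparkSession.stop()` already shuts down the underlying `SparkContext`, so the nested `sparkContext.stop()` in the new `cleanup()` is a defensive extra rather than a required step. A minimal standalone sketch of the teardown order (toy app name and local master are assumptions, not part of this diff):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").appName("cleanup-demo").getOrCreate()
sc = spark.sparkContext

spark.stop()  # also stops the underlying SparkContext
sc.stop()     # idempotent: stopping an already-stopped context is safe
```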
@@ -80,7 +88,7 @@ def execute(self) -> List[SummaryModel]:
         self.llm = Model.get_llm(llm_name)
 
         print("============= Init PySpark =============")
-        spark, sc = self._initialize_spark()
+        spark, sc = self.initialize_spark()
         self._sc = sc
         print("============== Init Done ===============")
 
@@ -98,7 +106,7 @@ def execute(self) -> List[SummaryModel]:
 
         # Evaluate data
         data_info_list = data_rdd.map(
-            lambda x: self._evaluate_item(x, broadcast_group, broadcast_llm)
+            lambda x: self.evaluate_item(x, broadcast_group, broadcast_llm)
         ).persist()  # Cache the evaluated data for multiple uses
 
         # Filter and count bad/good items
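The rename leaves the broadcast pattern intact: `broadcast_group` and `broadcast_llm` are presumably created with `sc.broadcast(...)` earlier in `execute()` (outside this diff) and read through `.value` inside the mapped function, and the resulting RDD is persisted because it is consumed twice, once for the bad filter and once for the good filter. A minimal sketch of that pattern with made-up rule data (`min_len` and the demo strings are invented for illustration):

```python
from pyspark import SparkContext

sc = SparkContext("local[2]", "broadcast-demo")

# One read-only copy of the rule table per executor, instead of
# re-serializing it into every task closure.
rules = sc.broadcast({"min_len": 5})

data = sc.parallelize(["tiny", "long enough text"])
checked = data.map(lambda s: (s, len(s) >= rules.value["min_len"])).persist()

# persist() pays off because the RDD is consumed more than once,
# mirroring the bad/good filter-and-count passes in execute().
print(checked.filter(lambda kv: not kv[1]).count())  # 1 bad item
print(checked.filter(lambda kv: kv[1]).count())      # 1 good item
sc.stop()
```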
@@ -119,26 +127,24 @@ def execute(self) -> List[SummaryModel]:
                 score=round((total - num_bad) / total * 100, 2) if total > 0 else 0,
                 num_good=total - num_bad,
                 num_bad=num_bad,
-                total=total,
-                type_ratio={},
-                name_ratio={}
+                total=total
             )
             # Generate detailed summary
-            self._summarize_results()
-
-            self.summary.finish_time = time.strftime('%Y%m%d_%H%M%S', time.localtime())
-
-            return [self.summary]
+            self.summary = self.summarize(self.summary)
+            return self.summary
 
         except Exception as e:
             raise e
         finally:
             if not self.input_args.save_data:
-                self._cleanup(spark)
+                self.cleanup(spark)
             else:
                 self.spark_session = spark
 
-    def _evaluate_item(self, data_rdd_item, broadcast_group, broadcast_llm) -> Dict[str, Any]:
+    def evaluate(self):
+        pass
+
+    def evaluate_item(self, data_rdd_item, broadcast_group, broadcast_llm) -> Dict[str, Any]:
         """Evaluate a single data item using broadcast variables."""
         data: MetaData = data_rdd_item
         result_info = ResultInfo(data_id=data.data_id, prompt=data.prompt, content=data.content)
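For reviewers checking the arithmetic: the `score` expression in this hunk is the percentage of good items, guarded against division by zero on empty input. A quick worked example with made-up counts:

```python
total, num_bad = 200, 15
score = round((total - num_bad) / total * 100, 2) if total > 0 else 0
print(score)  # 92.5
```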
@@ -158,9 +164,9 @@ def _evaluate_item(self, data_rdd_item, broadcast_group, broadcast_llm) -> Dict[
 
         for group_type, group_items in group.items():
             if group_type == 'rule':
-                r_i = self._evaluate_rule(group_items, data)
+                r_i = self.evaluate_rule(group_items, data)
             elif group_type == 'prompt':
-                r_i = self._evaluate_prompt(group_items, data, llm)
+                r_i = self.evaluate_prompt(group_items, data, llm)
             else:
                 raise RuntimeError(f'Unsupported group type: {group_type}')
 
@@ -186,7 +192,7 @@ def _evaluate_item(self, data_rdd_item, broadcast_group, broadcast_llm) -> Dict[
 
         return result_info.to_dict()
 
-    def _evaluate_rule(self, group: List[BaseRule], data: MetaData) -> ResultInfo:
+    def evaluate_rule(self, group: List[BaseRule], data: MetaData) -> ResultInfo:
         """Evaluate data against a group of rules."""
         result_info = ResultInfo(data_id=data.data_id, prompt=data.prompt, content=data.content)
 
@@ -218,7 +224,7 @@ def _evaluate_rule(self, group: List[BaseRule], data: MetaData) -> ResultInfo:
 
         return result_info
 
-    def _evaluate_prompt(self, group: List[BasePrompt], data: MetaData, llm: BaseLLM) -> ResultInfo:
+    def evaluate_prompt(self, group: List[BasePrompt], data: MetaData, llm: BaseLLM) -> ResultInfo:
         """Evaluate data against a group of prompts using LLM."""
         if llm is None:
             raise ValueError("LLM is required for prompt evaluation")
@@ -254,37 +260,42 @@ def _evaluate_prompt(self, group: List[BasePrompt], data: MetaData, llm: BaseLLM
 
         return result_info
 
-    def _summarize_results(self):
+    def summarize(self, summary: SummaryModel) -> SummaryModel:
         """Generate summary statistics from bad info list."""
-        if not self.bad_info_list:
-            return
-
-        # Calculate type ratios
-        type_counts = (
-            self.bad_info_list
-            .flatMap(lambda x: [(t, 1) for t in x['type_list']])
-            .reduceByKey(lambda a, b: a + b)
-            .collectAsMap()
-        )
-        self.summary.type_ratio = {
-            k: round(v / self.summary.total, 6)
-            for k, v in type_counts.items()
-        }
-
-        # Calculate name ratios
-        name_counts = (
-            self.bad_info_list
-            .flatMap(lambda x: [(n, 1) for n in x['name_list']])
-            .reduceByKey(lambda a, b: a + b)
-            .collectAsMap()
-        )
-        self.summary.name_ratio = {
-            k: round(v / self.summary.total, 6)
-            for k, v in name_counts.items()
-        }
-
-        self.summary.type_ratio = dict(sorted(self.summary.type_ratio.items()))
-        self.summary.name_ratio = dict(sorted(self.summary.name_ratio.items()))
+        def collect_ratio(data_info_list, key_name: str, total_count: int):
+            data_info_counts = (
+                data_info_list
+                .flatMap(lambda x: [(t, 1) for t in x[key_name]])
+                .reduceByKey(lambda a, b: a + b)
+                .collectAsMap()
+            )
+            return {
+                k: round(v / total_count, 6)
+                for k, v in data_info_counts.items()
+            }
+
+
+        new_summary = copy.deepcopy(self.summary)
+        if not self.bad_info_list and not self.good_info_list:
+            return new_summary
+        if not self.bad_info_list and self.good_info_list:
+            if not self.input_args.save_correct:
+                return new_summary
+
+        new_summary.type_ratio = collect_ratio(self.bad_info_list, 'type_list', new_summary.total)
+        new_summary.name_ratio = collect_ratio(self.bad_info_list, 'name_list', new_summary.total)
+
+        if self.input_args.save_correct:
+            type_ratio_correct = collect_ratio(self.good_info_list, 'type_list', new_summary.total)
+            name_ratio_correct = collect_ratio(self.good_info_list, 'name_list', new_summary.total)
+            new_summary.type_ratio.update(type_ratio_correct)
+            new_summary.name_ratio.update(name_ratio_correct)
+
+        new_summary.type_ratio = dict(sorted(new_summary.type_ratio.items()))
+        new_summary.name_ratio = dict(sorted(new_summary.name_ratio.items()))
+
+        new_summary.finish_time = time.strftime('%Y%m%d_%H%M%S', time.localtime())
+        return new_summary
 
     def get_summary(self):
         return self.summary
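To make the refactored `collect_ratio` helper concrete, here is a standalone sketch of the same `flatMap` -> `reduceByKey` -> `collectAsMap` pipeline with toy records (the dict shapes follow the `type_list`/`name_list` fields in this diff; the label names, counts, and total are invented):

```python
from pyspark import SparkContext

sc = SparkContext("local[2]", "ratio-demo")

# Toy stand-ins for bad_info_list records.
bad_info = sc.parallelize([
    {"type_list": ["QUALITY_BAD"], "name_list": ["RuleA"]},
    {"type_list": ["QUALITY_BAD"], "name_list": ["RuleA", "RuleB"]},
])

# Count each label once per record it appears in, then pull the
# small aggregated map back to the driver.
counts = (
    bad_info
    .flatMap(lambda x: [(t, 1) for t in x["type_list"]])
    .reduceByKey(lambda a, b: a + b)
    .collectAsMap()
)
total = 4  # pretend four items were evaluated overall
print({k: round(v / total, 6) for k, v in counts.items()})  # {'QUALITY_BAD': 0.5}
sc.stop()
```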
@@ -314,16 +325,3 @@ def get_good_info_list(self):
             }
         })
         return self.good_info_list
-
-    def save_data(self, start_time):
-        """Save output data to specified path."""
-        output_path = os.path.join(self.input_args.output_path, start_time)
-        model_path = os.path.join(output_path, self.input_args.eval_group)
-        os.makedirs(model_path, exist_ok=True)
-
-    def _cleanup(self, spark):
-        """Clean up Spark resources."""
-        if spark:
-            spark.stop()
-            if spark.sparkContext:
-                spark.sparkContext.stop()