@@ -27,8 +27,6 @@ def __init__(self, input_args: InputArgs):
2727 self .input_args : InputArgs = input_args
2828 self .llm : Optional [BaseLLM ] = None
2929 self .summary : SummaryModel = SummaryModel ()
30- self .bad_info_list : List [ResultInfo ] = []
31- self .good_info_list : List [ResultInfo ] = []
3230
3331 def load_data (self ) -> Generator [MetaData , None , None ]:
3432 """
@@ -68,19 +66,11 @@ def execute(self) -> List[SummaryModel]:
6866 eval_group = group_name ,
6967 input_path = input_path ,
7068 output_path = output_path if self .input_args .save_data else '' ,
71- create_time = create_time ,
72- score = 0 ,
73- num_good = 0 ,
74- num_bad = 0 ,
75- total = 0 ,
76- type_ratio = {},
77- name_ratio = {}
69+ create_time = create_time
7870 )
7971 self .evaluate ()
8072 self .summary = self .summarize (self .summary )
81- self .summary .finish_time = time .strftime ('%Y%m%d_%H%M%S' , time .localtime ())
82- if self .input_args .save_data :
83- self .save_data (output_path , self .input_args , self .bad_info_list , self .good_info_list , self .summary )
73+ self .write_summary (self .summary .output_path , self .input_args , self .summary )
8474
8575 return [self .summary ]
8676
@@ -98,8 +88,6 @@ def evaluate(self):
9888 pbar = tqdm (total = None , unit = 'items' )
9989
10090 def process_batch (batch : List ):
101- save_flag = False
102-
10391 futures = []
10492 for group_type , group in Model .get_group (self .input_args .eval_group ).items ():
10593 if group_type == 'rule' :
@@ -111,46 +99,19 @@ def process_batch(batch: List):
11199
112100 for future in concurrent .futures .as_completed (futures ):
113101 result_info = future .result ()
114- # calculate summary ratio
102+ for t in result_info .type_list :
103+ self .summary .type_ratio [t ] += 1
104+ for n in result_info .name_list :
105+ self .summary .name_ratio [n ] += 1
115106 if result_info .error_status :
116- self .bad_info_list .append (result_info )
117107 self .summary .num_bad += 1
118- for t in result_info .type_list :
119- if t not in self .summary .type_ratio :
120- self .summary .type_ratio [t ] = 1
121- else :
122- self .summary .type_ratio [t ] += 1
123- for n in result_info .name_list :
124- if n not in self .summary .name_ratio :
125- self .summary .name_ratio [n ] = 1
126- else :
127- self .summary .name_ratio [n ] += 1
128108 else :
129- if self .input_args .save_correct :
130- self .good_info_list .append (result_info )
131- for t in result_info .type_list :
132- if t not in self .summary .type_ratio :
133- self .summary .type_ratio [t ] = 1
134- else :
135- self .summary .type_ratio [t ] += 1
136- for n in result_info .name_list :
137- if n not in self .summary .name_ratio :
138- self .summary .name_ratio [n ] = 1
139- else :
140- self .summary .name_ratio [n ] += 1
109+ self .summary .num_good += 1
141110 self .summary .total += 1
142- if self . summary . total % self . input_args . interval_size == 0 :
143- save_flag = True
111+
112+ self . write_single_data ( self . summary . output_path , self . input_args , result_info )
144113 pbar .update ()
145- # save data in file
146- if self .input_args .save_data :
147- if save_flag :
148- tmp_summary = self .summarize (self .summary )
149- tmp_summary .finish_time = time .strftime ('%Y%m%d_%H%M%S' , time .localtime ())
150- tmp_output_path = self .summary .output_path
151- self .save_data (tmp_output_path , self .input_args , self .bad_info_list , self .good_info_list , tmp_summary )
152- self .bad_info_list = []
153- self .good_info_list = []
114+ self .write_summary (self .summary .output_path , self .input_args , self .summarize (self .summary ))
154115 while True :
155116 batch = list (itertools .islice (data_iter , self .input_args .batch_size ))
156117 if not batch :
@@ -270,9 +231,9 @@ def evaluate_prompt(self, group: List[BasePrompt], d: MetaData) -> ResultInfo:
270231
271232 def summarize (self , summary : SummaryModel ) -> SummaryModel :
272233 new_summary = copy .deepcopy (summary )
234+ new_summary .finish_time = time .strftime ('%Y%m%d_%H%M%S' , time .localtime ())
273235 if new_summary .total == 0 :
274236 return new_summary
275- new_summary .num_good = new_summary .total - new_summary .num_bad
276237 new_summary .score = round (new_summary .num_good / new_summary .total * 100 , 2 )
277238 for t in new_summary .type_ratio :
278239 new_summary .type_ratio [t ] = round (new_summary .type_ratio [t ] / new_summary .total , 6 )
@@ -282,52 +243,38 @@ def summarize(self, summary: SummaryModel) -> SummaryModel:
282243 new_summary .name_ratio = dict (sorted (new_summary .name_ratio .items ()))
283244 return new_summary
284245
285- def get_summary (self ):
286- return self .summary
246+ def write_single_data (self , path : str , input_args : InputArgs , result_info : ResultInfo ):
247+ if not input_args .save_data :
248+ return
287249
288- def get_bad_info_list (self ):
289- return self .bad_info_list
290-
291- def get_good_info_list (self ):
292- return self .good_info_list
250+ if not input_args .save_correct and not result_info .error_status :
251+ return
293252
294- def save_data (
295- self ,
296- path : str ,
297- input_args : InputArgs ,
298- bad_info_list : List [ResultInfo ],
299- good_info_list : List [ResultInfo ],
300- summary : SummaryModel ,
301- ):
302- for result_info in bad_info_list :
303- for new_name in result_info .name_list :
304- t = str (new_name ).split ('-' )[0 ]
305- n = str (new_name ).split ('-' )[1 ]
306- p_t = os .path .join (path , t )
307- if not os .path .exists (p_t ):
308- os .makedirs (p_t )
309- f_n = os .path .join (path , t , n ) + ".jsonl"
310- with open (f_n , 'a' , encoding = 'utf-8' ) as f :
311- if input_args .save_raw :
312- str_json = json .dumps (result_info .to_raw_dict (), ensure_ascii = False )
313- else :
314- str_json = json .dumps (result_info .to_dict (), ensure_ascii = False )
315- f .write (str_json + '\n ' )
316- if input_args .save_correct :
317- for result_info in good_info_list :
318- for new_name in result_info .name_list :
319- t = str (new_name ).split ('-' )[0 ]
320- n = str (new_name ).split ('-' )[1 ]
321- p_t = os .path .join (path , t )
322- if not os .path .exists (p_t ):
323- os .makedirs (p_t )
324- f_n = os .path .join (path , t , n ) + ".jsonl"
325- with open (f_n , 'a' , encoding = 'utf-8' ) as f :
326- if input_args .save_raw :
327- str_json = json .dumps (result_info .to_raw_dict (), ensure_ascii = False )
328- else :
329- str_json = json .dumps (result_info .to_dict (), ensure_ascii = False )
330- f .write (str_json + '\n ' )
253+ for new_name in result_info .name_list :
254+ t = str (new_name ).split ('-' )[0 ]
255+ n = str (new_name ).split ('-' )[1 ]
256+ p_t = os .path .join (path , t )
257+ if not os .path .exists (p_t ):
258+ os .makedirs (p_t )
259+ f_n = os .path .join (path , t , n ) + ".jsonl"
260+ with open (f_n , 'a' , encoding = 'utf-8' ) as f :
261+ if input_args .save_raw :
262+ str_json = json .dumps (result_info .to_raw_dict (), ensure_ascii = False )
263+ else :
264+ str_json = json .dumps (result_info .to_dict (), ensure_ascii = False )
265+ f .write (str_json + '\n ' )
331266
267+ def write_summary (self , path : str , input_args : InputArgs , summary : SummaryModel ):
268+ if not input_args .save_data :
269+ return
332270 with open (path + '/summary.json' , 'w' , encoding = 'utf-8' ) as f :
333271 json .dump (summary .to_dict (), f , indent = 4 , ensure_ascii = False )
272+
273+ def get_summary (self ):
274+ pass
275+
276+ def get_bad_info_list (self ):
277+ pass
278+
279+ def get_good_info_list (self ):
280+ pass
0 commit comments