88import argparse
99from pykt .config import que_type_models
1010from extract_quelevel_raw_result import get_one_result_help as get_quelevel_one_result_help
11+ import traceback
12+
1113
1214cut = True
1315
@@ -68,6 +70,8 @@ def get_metrics(df, y_pred_col='y_pred', y_true_col='y_true', name="test", cut=F
6870 """获取原始指标"""
6971 if not save_dir is None :
7072 save_df (df ,name ,save_dir )
73+ if len (df )== 0 :
74+ return {}
7175 print (f"get_metrics,y_pred_col={ y_pred_col } ,name={ name } " )
7276 # 针对concept_preds
7377 if y_pred_col == "concept_preds" :
@@ -142,12 +146,13 @@ def add_concepts(save_dir, data_dir, report, stu_inter_num_dict, cut, data_dict)
142146 df_qid_long , df_qid_short = paser_raw_data (
143147 qid_test , df_qid_test , stu_inter_num_dict = stu_inter_num_dict )
144148 # get report
145- base_long_report = get_metrics (df_qid_long , cut = cut , name = 'long' ,save_dir = save_dir )
146- report .update (base_long_report )
149+ if len (df_qid_long )!= 0 :
150+ base_long_report = get_metrics (df_qid_long , cut = cut , name = 'long' ,save_dir = save_dir )
151+ report .update (base_long_report )
147152
148-
149- base_short_report = get_metrics (df_qid_short , cut = cut , name = 'short' ,save_dir = save_dir )
150- report .update (base_short_report )
153+ if len ( df_qid_short ) != 0 :
154+ base_short_report = get_metrics (df_qid_short , cut = cut , name = 'short' ,save_dir = save_dir )
155+ report .update (base_short_report )
151156
152157
153158 df_qid = pd .concat ([df_qid_long , df_qid_short ])
@@ -166,13 +171,14 @@ def add_concept_win(save_dir, data_dir, report, stu_inter_num_dict, cut, data_di
166171 df_qid_test_win = data_dict ['df_qid_test_win' ]
167172 df_qid_win_long , df_qid_win_short = paser_raw_data (qid_test_win , df_qid_test_win , stu_inter_num_dict )
168173
169- base_win_long_report = get_metrics (df_qid_win_long , cut = cut , name = 'win_long' ,save_dir = save_dir )
170- report .update (base_win_long_report )
174+ if len (df_qid_win_long )!= 0 :
175+ base_win_long_report = get_metrics (df_qid_win_long , cut = cut , name = 'win_long' ,save_dir = save_dir )
176+ report .update (base_win_long_report )
171177
172178
173-
174- base_win_short_report = get_metrics (df_qid_win_short , cut = cut , name = 'win_short' ,save_dir = save_dir )
175- report .update (base_win_short_report )
179+ if len ( df_qid_win_short ) != 0 :
180+ base_win_short_report = get_metrics (df_qid_win_short , cut = cut , name = 'win_short' ,save_dir = save_dir )
181+ report .update (base_win_short_report )
176182
177183
178184 df_qid_win = pd .concat ([df_qid_win_long , df_qid_win_short ])
@@ -208,20 +214,23 @@ def concept_update_l2(save_dir, data_dir, report, stu_inter_num_dict, cut, data_
208214
209215 # 筛选需要计算的指标
210216 df_result_long = df_result [df_result ['cidx' ].isin (keep_cidx )].copy ()
211- df_result_win_long = df_result_win [df_result_win ['cidx' ].isin (keep_cidx )].copy ()
217+ if len (df_result_long )!= 0 :
218+ report .update (get_metrics (df_result_long , cut = cut , name = 'b200' ,save_dir = save_dir ))
212219
213- # update report
214- report . update ( get_metrics ( df_result_long , cut = cut , name = 'b200' , save_dir = save_dir ))
215- report .update (get_metrics (df_result_win_long , cut = cut , name = 'win_b200' ,save_dir = save_dir ))
220+ df_result_win_long = df_result_win [ df_result_win [ 'cidx' ]. isin ( keep_cidx )]. copy ()
221+ if len ( df_result_win_long ) != 0 :
222+ report .update (get_metrics (df_result_win_long , cut = cut , name = 'win_b200' ,save_dir = save_dir ))
216223
217224
218225 # 小于200的指标
219226 df_result_short = df_result [~ df_result ['cidx' ].isin (keep_cidx )].copy ()
227+ if len (df_result_short )!= 0 :
228+ report .update (get_metrics (df_result_short , cut = cut , name = 's200' ,save_dir = save_dir ))
229+
220230 df_result_win_short = df_result_win [~ df_result_win ['cidx' ].isin (keep_cidx )].copy ()
231+ if len (df_result_short )!= 0 :
232+ report .update (get_metrics (df_result_win_short , cut = cut , name = 'win_s200' ,save_dir = save_dir ))
221233
222- # update report
223- report .update (get_metrics (df_result_short , cut = cut , name = 's200' ,save_dir = save_dir ))
224- report .update (get_metrics (df_result_win_short , cut = cut , name = 'win_s200' ,save_dir = save_dir ))
225234
226235# question
227236
@@ -272,44 +281,56 @@ def update_question_df(df):
272281
273282def que_update_ls_report (que_test , que_win_test , report ,save_dir ):
274283 # split
275- que_short = que_test [que_test ['inter_num' ] <= 200 ].reset_index ()
276- que_long = que_test [que_test ['inter_num' ] > 200 ].reset_index ()
277-
278- que_win_short = que_win_test [que_win_test ['inter_num' ]
279- <= 200 ].reset_index ()
280- que_win_long = que_win_test [que_win_test ['inter_num' ] > 200 ].reset_index ()
284+ if len (que_test )!= 0 :
285+ que_short = que_test [que_test ['inter_num' ] <= 200 ].reset_index ()
286+ que_long = que_test [que_test ['inter_num' ] > 200 ].reset_index ()
287+ else :
288+ que_short = pd .DataFrame ()
289+ que_long = pd .DataFrame ()
290+
291+ if len (que_win_test )!= 0 :
292+ que_win_short = que_win_test [que_win_test ['inter_num' ]
293+ <= 200 ].reset_index ()
294+ que_win_long = que_win_test [que_win_test ['inter_num' ] > 200 ].reset_index ()
295+ else :
296+ que_win_short = pd .DataFrame ()
297+ que_win_long = pd .DataFrame ()
281298
282299
283300 # update
284301 for y_pred_col in ['concept_preds' , 'late_mean' , 'late_vote' , 'late_all' , 'early_preds' ]:
285302 print (f"que_update_ls_report start { y_pred_col } " )
286- if y_pred_col not in que_test .columns :
303+ if len ( que_test ) != 0 and y_pred_col not in que_test .columns :
287304 print (f"skip { y_pred_col } " )
288305 continue
289306 # long
290- report_long = get_metrics (que_long , y_true_col = "late_trues" ,
291- y_pred_col = y_pred_col , cut = cut , name = f'{ y_pred_col } _long' ,save_dir = save_dir )
292- report .update (report_long )
307+ if len (que_long )!= 0 :
308+ report_long = get_metrics (que_long , y_true_col = "late_trues" ,
309+ y_pred_col = y_pred_col , cut = cut , name = f'{ y_pred_col } _long' ,save_dir = save_dir )
310+ report .update (report_long )
293311
294312 # short
295- report_short = get_metrics (que_short , y_true_col = "late_trues" ,
296- y_pred_col = y_pred_col , cut = cut , name = f'{ y_pred_col } _short' ,save_dir = save_dir )
297- report .update (report_short )
313+ if len (que_short )!= 0 :
314+ report_short = get_metrics (que_short , y_true_col = "late_trues" ,
315+ y_pred_col = y_pred_col , cut = cut , name = f'{ y_pred_col } _short' ,save_dir = save_dir )
316+ report .update (report_short )
298317
299318 # long + short
300319 report_col = get_metrics (que_test , y_true_col = "late_trues" ,
301320 y_pred_col = y_pred_col , cut = cut , name = f'{ y_pred_col } ' ,save_dir = save_dir )
302321 report .update (report_col )
303322
304323 # win long
305- report_win_long = get_metrics (que_win_long , y_true_col = "late_trues" ,
306- y_pred_col = y_pred_col , cut = cut , name = f'{ y_pred_col } _win_long' ,save_dir = save_dir )
307- report .update (report_win_long )
324+ if len (que_win_long )!= 0 :
325+ report_win_long = get_metrics (que_win_long , y_true_col = "late_trues" ,
326+ y_pred_col = y_pred_col , cut = cut , name = f'{ y_pred_col } _win_long' ,save_dir = save_dir )
327+ report .update (report_win_long )
308328
309329 # win short
310- report_win_short = get_metrics (que_win_short , y_true_col = "late_trues" ,
311- y_pred_col = y_pred_col , cut = cut , name = f'{ y_pred_col } _win_short' ,save_dir = save_dir )
312- report .update (report_win_short )
330+ if len (que_win_short )!= 0 :
331+ report_win_short = get_metrics (que_win_short , y_true_col = "late_trues" ,
332+ y_pred_col = y_pred_col , cut = cut , name = f'{ y_pred_col } _win_short' ,save_dir = save_dir )
333+ report .update (report_win_short )
313334
314335 # win long + win short
315336 report_win = get_metrics (que_win_test , y_true_col = "late_trues" ,
@@ -359,49 +380,53 @@ def que_update_l2(que_test, que_win_test, report,save_dir):
def _load_question_predictions(save_dir, file_name, df_name, orirow_2_uid, stu_inter_num_dict):
    """Load one question-level prediction file and attach student information.

    The tab-separated file ``file_name`` inside ``save_dir`` is read, passed
    through ``update_question_df``, given ``uid`` and ``inter_num`` columns,
    and written back out via ``save_df`` under ``df_name``.

    Returns:
        The enriched DataFrame, or an empty DataFrame when any step fails
        (for example when the prediction file does not exist).
    """
    try:
        df = pd.read_csv(os.path.join(save_dir, file_name), sep='\t')
        df = update_question_df(df)
        # Map each prediction row to its student, then to that student's
        # interaction count.
        df['uid'] = df['orirow'].map(orirow_2_uid)
        df['inter_num'] = df['uid'].map(stu_inter_num_dict)
        save_df(df, df_name, save_dir)
    except Exception:
        # Best-effort: keep going with an empty frame, but say why.
        print(f"Fail load {file_name},details is {traceback.format_exc()}")
        df = pd.DataFrame()
    return df


def add_question_report(save_dir, data_dir, report, stu_inter_num_dict, cut, data_dict):
    """Add question-level metrics to ``report``.

    Args:
        save_dir: Model directory; must contain ``config.json`` and is where
            the ``{emb_type}_test_question*_predictions.txt`` files are read.
        data_dir: Not used in this function.
        report: Dict that the ``que_update_*`` helpers update in place.
        stu_inter_num_dict: Mapping from ``uid`` to that student's number of
            interactions.
        cut: Not used in this function.
        data_dict: Must provide ``'df_que_test'``, a DataFrame with ``orirow``
            (comma-separated string) and ``uid`` columns.

    Returns:
        A tuple ``(None, que_win_test)``. The first element is always ``None``
        because ``que_test`` is cleared before returning.
    """
    # Close the config file deterministically instead of leaking the handle.
    with open(os.path.join(save_dir, "config.json")) as config_file:
        config = json.load(config_file)
    emb_type = config['params']['emb_type']

    df_que_test = data_dict['df_que_test']
    # Map the first original row index of every sequence to its student id.
    orirow_2_uid = {
        int(orirow.split(',')[0]): uid
        for orirow, uid in zip(df_que_test['orirow'], df_que_test['uid'])
    }

    que_test = _load_question_predictions(
        save_dir, f"{emb_type}_test_question_predictions.txt",
        'que_test', orirow_2_uid, stu_inter_num_dict)
    que_win_test = _load_question_predictions(
        save_dir, f"{emb_type}_test_question_window_predictions.txt",
        'que_win_test', orirow_2_uid, stu_inter_num_dict)

    # Question-level long / short sequence results.
    try:
        print("Start 基于题目的长短序列")
        que_update_ls_report(que_test, que_win_test, report, save_dir=save_dir)
    except Exception:
        print(f"Fail 基于题目的长短序列,details is {traceback.format_exc()}")

    # Question-level results for the >200 part.
    try:
        print("Start 基于题目的>200部分")
        que_update_l2(que_test, que_win_test, report, save_dir=save_dir)
    except Exception:
        # The failure branch used to print "Start", which hid the failure.
        print(f"Fail 基于题目的>200部分,details is {traceback.format_exc()}")

    # NOTE(review): que_test is dropped before returning, so callers always
    # receive None as the first element; kept as-is for compatibility.
    que_test = None
    return que_test, que_win_test
407432
@@ -428,21 +453,21 @@ def get_one_result(root_save_dir, stu_inter_num_dict, data_dict, cut, skip=False
428453 print ("Start 知识点非win" )
429454 add_concepts (save_dir , data_dir , report ,stu_inter_num_dict , cut , data_dict )
430455 except :
431- print ("Fail 知识点非win" )
456+ print (f "Fail 知识点非win,details is { traceback . format_exc () } " )
432457
433458 #知识点win
434459 try :
435460 print ("Start 知识点win" )
436461 add_concept_win (save_dir , data_dir , report ,stu_inter_num_dict , cut , data_dict )
437462 except :
438- print ("Fail 知识点win" )
463+ print (f "Fail 知识点win,details is { traceback . format_exc () } " )
439464
440465 #长短序列
441466 try :
442467 print ("Start 知识点长短序列" )
443468 concept_update_l2 (save_dir , data_dir , report , stu_inter_num_dict , cut , data_dict )
444469 except :
445- print ("Fail 知识点长短序列" )
470+ print (f "Fail 知识点长短序列,details is { traceback . format_exc () } " )
446471
447472
448473 if not data_dict ['df_que_test' ] is None :
@@ -496,7 +521,7 @@ def get_one_result_help(dataset, model_name,model_root_dir,data_root_dir):
496521if __name__ == "__main__" :
497522 # model_root_dir = "/root/autodl-nas/liuqiongqiong/bakt/pykt-toolkit/examples/best_model_path"
498523 # model_root_dir = "/root/autodl-nas/project/pykt_nips2022/examples/best_model_path"
499- model_root_dir = "/root/autodl-nas/project/pykt_nips2022/examples/best_model_path_1002 "
524+ model_root_dir = "/root/autodl-nas/project/full_result_pykt/best_model_path "
500525 # model_root_dir = "/root/autodl-nas/project/pykt_qikt/examples/best_model_path"
501526 data_root_dir = '/root/autodl-nas/project/pykt_nips2022/data'
502527
0 commit comments