Skip to content

Commit 173f807

Browse files
authored
Merge pull request #89 from pykt-team/dev
Dev
2 parents d1bd046 + 2d6763e commit 173f807

40 files changed

+2945
-234
lines changed

.gitignore

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,8 @@ examples/wandb_tmp/
5757
examples/archive/
5858
prediction.csv
5959
data/peiyou
60+
examples/iekt_ab/
61+
examples/qikt_improve/
62+
examples/*.txt
63+
examples/nips2022-peiyou/
64+
examples/nips2022-ednet/

configs/data_config.json

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,17 @@
361361
2,
362362
3,
363363
4
364-
]
364+
],
365+
"train_valid_original_file": "train_valid.csv",
366+
"test_original_file": "test.csv",
367+
"test_file": "test_sequences.csv",
368+
"test_window_file": "test_window_sequences.csv",
369+
"test_question_file": "test_question_sequences.csv",
370+
"test_question_window_file": "test_question_window_sequences.csv",
371+
"train_valid_original_file_quelevel": "train_valid_quelevel.csv",
372+
"train_valid_file_quelevel": "train_valid_sequences_quelevel.csv",
373+
"test_file_quelevel": "test_sequences_quelevel.csv",
374+
"test_window_file_quelevel": "test_window_sequences_quelevel.csv",
375+
"test_original_file_quelevel": "test_quelevel.csv"
365376
}
366-
}
377+
}

examples/data_preprocess.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
# process raw data
3737
if args.dataset_name=="peiyou":
3838
dname2paths["peiyou"] = args.file_path
39-
print(f"fpath: {args.file_path}")
39+
print(f"fpath: {args.file_path}")
4040
dname, writef = process_raw_data(args.dataset_name, dname2paths)
4141
print("-"*50)
4242
# split

examples/extract_quelevel_raw_result.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def que_update_ls_report(que_test, que_win_test, report,save_dir):
107107
que_short = que_test[que_test['inter_num'] <= 200].reset_index()
108108
que_long = que_test[que_test['inter_num'] > 200].reset_index()
109109

110+
110111
que_win_short = que_win_test[que_win_test['inter_num']
111112
<= 200].reset_index()
112113
que_win_long = que_win_test[que_win_test['inter_num'] > 200].reset_index()
@@ -119,29 +120,33 @@ def que_update_ls_report(que_test, que_win_test, report,save_dir):
119120
print(f"skip {y_pred_col}")
120121
continue
121122
# long
122-
report_long = get_metrics(que_long, y_true_col="y_true",
123-
y_pred_col=y_pred_col, cut=cut, name=f'{y_pred_col}_long',save_dir=save_dir)
124-
report.update(report_long)
123+
if len(que_long)!=0:
124+
report_long = get_metrics(que_long, y_true_col="y_true",
125+
y_pred_col=y_pred_col, cut=cut, name=f'{y_pred_col}_long',save_dir=save_dir)
126+
report.update(report_long)
125127

126-
# short
127-
report_short = get_metrics(que_short, y_true_col="y_true",
128-
y_pred_col=y_pred_col, cut=cut, name=f'{y_pred_col}_short',save_dir=save_dir)
129-
report.update(report_short)
128+
if len(que_short)!=0:
129+
# short
130+
report_short = get_metrics(que_short, y_true_col="y_true",
131+
y_pred_col=y_pred_col, cut=cut, name=f'{y_pred_col}_short',save_dir=save_dir)
132+
report.update(report_short)
130133

131134
# long + short
132135
report_col = get_metrics(que_test, y_true_col="y_true",
133136
y_pred_col=y_pred_col, cut=cut, name=f'{y_pred_col}',save_dir=save_dir)
134137
report.update(report_col)
135-
136-
# win long
137-
report_win_long = get_metrics(que_win_long, y_true_col="y_true",
138-
y_pred_col=y_pred_col, cut=cut, name=f'{y_pred_col}_win_long',save_dir=save_dir)
139-
report.update(report_win_long)
140-
141-
# win short
142-
report_win_short = get_metrics(que_win_short, y_true_col="y_true",
143-
y_pred_col=y_pred_col, cut=cut, name=f'{y_pred_col}_win_short',save_dir=save_dir)
144-
report.update(report_win_short)
138+
139+
if len(que_win_long)!=0:
140+
# win long
141+
report_win_long = get_metrics(que_win_long, y_true_col="y_true",
142+
y_pred_col=y_pred_col, cut=cut, name=f'{y_pred_col}_win_long',save_dir=save_dir)
143+
report.update(report_win_long)
144+
145+
if len(que_win_short)!=0:
146+
# win short
147+
report_win_short = get_metrics(que_win_short, y_true_col="y_true",
148+
y_pred_col=y_pred_col, cut=cut, name=f'{y_pred_col}_win_short',save_dir=save_dir)
149+
report.update(report_win_short)
145150

146151
# win long + win short
147152
report_win = get_metrics(que_win_test, y_true_col="y_true",
@@ -168,7 +173,7 @@ def add_question_report(save_dir, data_dir, report, stu_inter_num_dict, cut, dat
168173
print("Start 基于题目的长短序列")
169174
que_update_ls_report(que_test, que_win_test, report,save_dir=save_dir) # short long 结果
170175
except:
171-
print("Fail 基于题目的长短序列")
176+
print(f"Fail 基于题目的长短序列,details is {traceback.format_exc()}")
172177

173178
return que_test,que_win_test
174179

examples/extract_raw.sh

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
2+
dataset="ednet"
3+
models="dkt,akt,iekt,qdkt,qikt,saint"
4+
# Run extract_raw_result.py once per comma-separated model; IFS stays ',' from here on, so expansions below are quoted.
5+
IFS=','
6+
for i in $models; do
7+
echo "# $i"
8+
python -u extract_raw_result.py --dataset "$dataset" --model_name "$i"
9+
echo ""
10+
done
11+
12+
13+
# nohup sh extract_raw.sh >extract.log&

examples/extract_raw_result.py

Lines changed: 88 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import argparse
99
from pykt.config import que_type_models
1010
from extract_quelevel_raw_result import get_one_result_help as get_quelevel_one_result_help
11+
import traceback
12+
1113

1214
cut = True
1315

@@ -68,6 +70,8 @@ def get_metrics(df, y_pred_col='y_pred', y_true_col='y_true', name="test", cut=F
6870
"""获取原始指标"""
6971
if not save_dir is None:
7072
save_df(df,name,save_dir)
73+
if len(df)==0:
74+
return {}
7175
print(f"get_metrics,y_pred_col={y_pred_col},name={name}")
7276
# 针对concept_preds
7377
if y_pred_col == "concept_preds":
@@ -142,12 +146,13 @@ def add_concepts(save_dir, data_dir, report, stu_inter_num_dict, cut, data_dict)
142146
df_qid_long, df_qid_short = paser_raw_data(
143147
qid_test, df_qid_test, stu_inter_num_dict=stu_inter_num_dict)
144148
# get report
145-
base_long_report = get_metrics(df_qid_long, cut=cut, name='long',save_dir=save_dir)
146-
report.update(base_long_report)
149+
if len(df_qid_long)!=0:
150+
base_long_report = get_metrics(df_qid_long, cut=cut, name='long',save_dir=save_dir)
151+
report.update(base_long_report)
147152

148-
149-
base_short_report = get_metrics(df_qid_short, cut=cut, name='short',save_dir=save_dir)
150-
report.update(base_short_report)
153+
if len(df_qid_short)!=0:
154+
base_short_report = get_metrics(df_qid_short, cut=cut, name='short',save_dir=save_dir)
155+
report.update(base_short_report)
151156

152157

153158
df_qid = pd.concat([df_qid_long, df_qid_short])
@@ -166,13 +171,14 @@ def add_concept_win(save_dir, data_dir, report, stu_inter_num_dict, cut, data_di
166171
df_qid_test_win = data_dict['df_qid_test_win']
167172
df_qid_win_long, df_qid_win_short = paser_raw_data(qid_test_win, df_qid_test_win, stu_inter_num_dict)
168173

169-
base_win_long_report = get_metrics(df_qid_win_long, cut=cut, name='win_long',save_dir=save_dir)
170-
report.update(base_win_long_report)
174+
if len(df_qid_win_long)!=0:
175+
base_win_long_report = get_metrics(df_qid_win_long, cut=cut, name='win_long',save_dir=save_dir)
176+
report.update(base_win_long_report)
171177

172178

173-
174-
base_win_short_report = get_metrics(df_qid_win_short, cut=cut, name='win_short',save_dir=save_dir)
175-
report.update(base_win_short_report)
179+
if len(df_qid_win_short)!=0:
180+
base_win_short_report = get_metrics(df_qid_win_short, cut=cut, name='win_short',save_dir=save_dir)
181+
report.update(base_win_short_report)
176182

177183

178184
df_qid_win = pd.concat([df_qid_win_long, df_qid_win_short])
@@ -208,20 +214,23 @@ def concept_update_l2(save_dir, data_dir, report, stu_inter_num_dict, cut, data_
208214

209215
# 筛选需要计算的指标
210216
df_result_long = df_result[df_result['cidx'].isin(keep_cidx)].copy()
211-
df_result_win_long = df_result_win[df_result_win['cidx'].isin(keep_cidx)].copy()
217+
if len(df_result_long)!=0:
218+
report.update(get_metrics(df_result_long, cut=cut, name='b200',save_dir=save_dir))
212219

213-
# update report
214-
report.update(get_metrics(df_result_long, cut=cut, name='b200',save_dir=save_dir))
215-
report.update(get_metrics(df_result_win_long, cut=cut, name='win_b200',save_dir=save_dir))
220+
df_result_win_long = df_result_win[df_result_win['cidx'].isin(keep_cidx)].copy()
221+
if len(df_result_win_long)!=0:
222+
report.update(get_metrics(df_result_win_long, cut=cut, name='win_b200',save_dir=save_dir))
216223

217224

218225
# 小于200的指标
219226
df_result_short = df_result[~df_result['cidx'].isin(keep_cidx)].copy()
227+
if len(df_result_short)!=0:
228+
report.update(get_metrics(df_result_short, cut=cut, name='s200',save_dir=save_dir))
229+
220230
df_result_win_short = df_result_win[~df_result_win['cidx'].isin(keep_cidx)].copy()
231+
if len(df_result_short)!=0:
232+
report.update(get_metrics(df_result_win_short, cut=cut, name='win_s200',save_dir=save_dir))
221233

222-
# update report
223-
report.update(get_metrics(df_result_short, cut=cut, name='s200',save_dir=save_dir))
224-
report.update(get_metrics(df_result_win_short, cut=cut, name='win_s200',save_dir=save_dir))
225234

226235
# question
227236

@@ -272,44 +281,56 @@ def update_question_df(df):
272281

273282
def que_update_ls_report(que_test, que_win_test, report,save_dir):
274283
# split
275-
que_short = que_test[que_test['inter_num'] <= 200].reset_index()
276-
que_long = que_test[que_test['inter_num'] > 200].reset_index()
277-
278-
que_win_short = que_win_test[que_win_test['inter_num']
279-
<= 200].reset_index()
280-
que_win_long = que_win_test[que_win_test['inter_num'] > 200].reset_index()
284+
if len(que_test)!=0:
285+
que_short = que_test[que_test['inter_num'] <= 200].reset_index()
286+
que_long = que_test[que_test['inter_num'] > 200].reset_index()
287+
else:
288+
que_short = pd.DataFrame()
289+
que_long = pd.DataFrame()
290+
291+
if len(que_win_test)!=0:
292+
que_win_short = que_win_test[que_win_test['inter_num']
293+
<= 200].reset_index()
294+
que_win_long = que_win_test[que_win_test['inter_num'] > 200].reset_index()
295+
else:
296+
que_win_short = pd.DataFrame()
297+
que_win_long = pd.DataFrame()
281298

282299

283300
# update
284301
for y_pred_col in ['concept_preds', 'late_mean', 'late_vote', 'late_all', 'early_preds']:
285302
print(f"que_update_ls_report start {y_pred_col}")
286-
if y_pred_col not in que_test.columns:
303+
if len(que_test)!=0 and y_pred_col not in que_test.columns:
287304
print(f"skip {y_pred_col}")
288305
continue
289306
# long
290-
report_long = get_metrics(que_long, y_true_col="late_trues",
291-
y_pred_col=y_pred_col, cut=cut, name=f'{y_pred_col}_long',save_dir=save_dir)
292-
report.update(report_long)
307+
if len(que_long)!=0:
308+
report_long = get_metrics(que_long, y_true_col="late_trues",
309+
y_pred_col=y_pred_col, cut=cut, name=f'{y_pred_col}_long',save_dir=save_dir)
310+
report.update(report_long)
293311

294312
# short
295-
report_short = get_metrics(que_short, y_true_col="late_trues",
296-
y_pred_col=y_pred_col, cut=cut, name=f'{y_pred_col}_short',save_dir=save_dir)
297-
report.update(report_short)
313+
if len(que_short)!=0:
314+
report_short = get_metrics(que_short, y_true_col="late_trues",
315+
y_pred_col=y_pred_col, cut=cut, name=f'{y_pred_col}_short',save_dir=save_dir)
316+
report.update(report_short)
298317

299318
# long + short
300319
report_col = get_metrics(que_test, y_true_col="late_trues",
301320
y_pred_col=y_pred_col, cut=cut, name=f'{y_pred_col}',save_dir=save_dir)
302321
report.update(report_col)
303322

304323
# win long
305-
report_win_long = get_metrics(que_win_long, y_true_col="late_trues",
306-
y_pred_col=y_pred_col, cut=cut, name=f'{y_pred_col}_win_long',save_dir=save_dir)
307-
report.update(report_win_long)
324+
if len(que_win_long)!=0:
325+
report_win_long = get_metrics(que_win_long, y_true_col="late_trues",
326+
y_pred_col=y_pred_col, cut=cut, name=f'{y_pred_col}_win_long',save_dir=save_dir)
327+
report.update(report_win_long)
308328

309329
# win short
310-
report_win_short = get_metrics(que_win_short, y_true_col="late_trues",
311-
y_pred_col=y_pred_col, cut=cut, name=f'{y_pred_col}_win_short',save_dir=save_dir)
312-
report.update(report_win_short)
330+
if len(que_win_short)!=0:
331+
report_win_short = get_metrics(que_win_short, y_true_col="late_trues",
332+
y_pred_col=y_pred_col, cut=cut, name=f'{y_pred_col}_win_short',save_dir=save_dir)
333+
report.update(report_win_short)
313334

314335
# win long + win short
315336
report_win = get_metrics(que_win_test, y_true_col="late_trues",
@@ -359,49 +380,53 @@ def que_update_l2(que_test, que_win_test, report,save_dir):
359380
def add_question_report(save_dir, data_dir, report, stu_inter_num_dict, cut, data_dict):
360381
config = json.load(open(os.path.join(save_dir,"config.json")))
361382
emb_type = config['params']['emb_type']
362-
363-
364-
que_test = pd.read_csv(os.path.join(
365-
save_dir, f"{emb_type}_test_question_predictions.txt"), sep='\t')
366-
que_test = update_question_df(que_test)
367-
383+
368384
df_que_test = data_dict['df_que_test']
369-
370-
que_win_test = pd.read_csv(os.path.join(
371-
save_dir, f"{emb_type}_test_question_window_predictions.txt"), sep='\t')
372-
que_win_test = update_question_df(que_win_test)
373-
374-
df_que_win_test = data_dict['df_que_win_test']
375-
376385
# 映射学生
377386
orirow_2_uid = {}
378387
for _, row in df_que_test.iterrows():
379388
orirow_2_uid[int(row['orirow'].split(',')[0])] = row['uid']
389+
390+
try:
391+
que_test = pd.read_csv(os.path.join(
392+
save_dir, f"{emb_type}_test_question_predictions.txt"), sep='\t')
393+
que_test = update_question_df(que_test)
394+
# map
395+
que_test['uid'] = que_test['orirow'].map(orirow_2_uid)
396+
que_test['inter_num'] = que_test['uid'].map(stu_inter_num_dict)
397+
save_df(que_test,'que_test',save_dir)
398+
except:
399+
que_test = pd.DataFrame()
380400

381-
# map
382-
que_test['uid'] = que_test['orirow'].map(orirow_2_uid)
383-
que_test['inter_num'] = que_test['uid'].map(stu_inter_num_dict)
384-
385-
que_win_test['uid'] = que_win_test['orirow'].map(orirow_2_uid)
386-
que_win_test['inter_num'] = que_win_test['uid'].map(stu_inter_num_dict)
401+
402+
try:
403+
que_win_test = pd.read_csv(os.path.join(
404+
save_dir, f"{emb_type}_test_question_window_predictions.txt"), sep='\t')
405+
que_win_test = update_question_df(que_win_test)
406+
df_que_win_test = data_dict['df_que_win_test']
407+
408+
que_win_test['uid'] = que_win_test['orirow'].map(orirow_2_uid)
409+
que_win_test['inter_num'] = que_win_test['uid'].map(stu_inter_num_dict)
410+
save_df(que_win_test,'que_win_test',save_dir)
411+
except:
412+
que_win_test = pd.DataFrame()
387413

388414
# print("Start 基于题目的长短序列")
389415
try:
390416
print("Start 基于题目的长短序列")
391417
que_update_ls_report(que_test, que_win_test, report,save_dir=save_dir) # short long 结果
392418
except:
393-
print("Fail 基于题目的长短序列")
419+
print(f"Fail 基于题目的长短序列,details is {traceback.format_exc()}")  # report the swallowed error, as the other handlers do
394420

395421
# print("Start 基于题目的>200部分")
396422

397423
try:
398424
print("Start 基于题目的>200部分")
399425
que_update_l2(que_test, que_win_test, report,save_dir=save_dir) # 大于200部分的结果
400426
except:
401-
print("Start 基于题目的>200部分")
427+
print(f"Fail 基于题目的>200部分,details is {traceback.format_exc()}")  # failure branch: was printing the "Start" message
402428

403-
save_df(que_test,'que_test',save_dir)
404-
save_df(que_win_test,'que_win_test',save_dir)
429+
405430
que_test = None
406431
return que_test,que_win_test
407432

@@ -428,21 +453,21 @@ def get_one_result(root_save_dir, stu_inter_num_dict, data_dict, cut, skip=False
428453
print("Start 知识点非win")
429454
add_concepts(save_dir, data_dir, report,stu_inter_num_dict, cut, data_dict)
430455
except:
431-
print("Fail 知识点非win")
456+
print(f"Fail 知识点非win,details is {traceback.format_exc()}")
432457

433458
#知识点win
434459
try:
435460
print("Start 知识点win")
436461
add_concept_win(save_dir, data_dir, report,stu_inter_num_dict, cut, data_dict)
437462
except:
438-
print("Fail 知识点win")
463+
print(f"Fail 知识点win,details is {traceback.format_exc()}")
439464

440465
#长短序列
441466
try:
442467
print("Start 知识点长短序列")
443468
concept_update_l2(save_dir, data_dir, report, stu_inter_num_dict, cut, data_dict)
444469
except:
445-
print("Fail 知识点长短序列")
470+
print(f"Fail 知识点长短序列,details is {traceback.format_exc()}")
446471

447472

448473
if not data_dict['df_que_test'] is None:
@@ -496,7 +521,7 @@ def get_one_result_help(dataset, model_name,model_root_dir,data_root_dir):
496521
if __name__ == "__main__":
497522
# model_root_dir = "/root/autodl-nas/liuqiongqiong/bakt/pykt-toolkit/examples/best_model_path"
498523
# model_root_dir = "/root/autodl-nas/project/pykt_nips2022/examples/best_model_path"
499-
model_root_dir = "/root/autodl-nas/project/pykt_nips2022/examples/best_model_path_1002"
524+
model_root_dir = "/root/autodl-nas/project/full_result_pykt/best_model_path"
500525
# model_root_dir = "/root/autodl-nas/project/pykt_qikt/examples/best_model_path"
501526
data_root_dir = '/root/autodl-nas/project/pykt_nips2022/data'
502527

0 commit comments

Comments
 (0)