Skip to content

Commit 321e406

Browse files
committed
added scripts to compute the results
1 parent ca1c0b4 commit 321e406

File tree

1 file changed

+145
-0
lines changed

1 file changed

+145
-0
lines changed

get_ranking_results.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
__author__ = "qiao"
2+
3+
"""
4+
Get the reranking results from the trial matching csv
5+
"""
6+
7+
from collections import Counter
8+
import pandas as pd
9+
import json
10+
import random
11+
random.seed(2023)
12+
import sys
13+
14+
from sklearn.metrics import ndcg_score
15+
from sklearn.metrics import roc_curve, auc
16+
17+
18+
def get_metrics(labels, scores, k=10, max_label=2):
	"""Compute precision@k and MRR@k for a single ranked candidate list.

	Candidates are ranked by descending score (stable sort, so ties keep
	their input order).  Precision is normalized by ``max_label`` so that a
	top-k consisting entirely of maximally-relevant items scores 1.0; with
	the default ``max_label=2`` this matches the original hard-coded
	``sum(labels) / (2 * len(...))`` formula for graded 0/1/2 labels.

	Args:
		labels: relevance labels per candidate (graded, e.g. 0/1/2).
		scores: ranking scores per candidate; higher means ranked earlier.
		k: rank cutoff for both metrics (default 10, as in the original).
		max_label: maximum possible label value, used to normalize
			precision (default 2).

	Returns:
		Tuple ``(prec, mrr)`` of floats.  ``(0.0, 0.0)`` for empty input.
	"""
	if not labels:
		# Original code raised ZeroDivisionError here; an empty ranking
		# has no relevant items, so report zero for both metrics.
		return 0.0, 0.0

	ranked = sorted(zip(labels, scores), key=lambda pair: pair[1], reverse=True)

	top_labels = [label for label, _ in ranked[:k]]
	top_scores = [score for _, score in ranked[:k]]

	if len(top_scores) == k and top_scores[0] == top_scores[-1]:
		# The entire top-k is one tie group, so the cutoff inside the group
		# is arbitrary: average over *every* item sharing that score rather
		# than an arbitrary k of them.
		tied_labels = [label for label, score in ranked if score == top_scores[0]]
		prec = sum(tied_labels) / (max_label * len(tied_labels))
	else:
		# Covers both the full top-k case (len == k) and short lists
		# (len < k); the original wrote the same formula in two branches.
		prec = sum(top_labels) / (max_label * len(top_labels))

	# Reciprocal rank of the first relevant (label > 0) item within the
	# top-k; 0 if none is found (i.e. MRR is cut off at rank k).
	mrr = 0
	for rank, label in enumerate(top_labels, start=1):
		if label > 0:
			mrr = 1 / rank
			break

	return prec, mrr
43+
44+
45+
if __name__ == "__main__":
	# Combine the per-cohort trial-matching csvs produced for one model,
	# derive several ranking scores from the criterion-level counts, then
	# report NDCG@10, Prec@10, MRR and AUC for each scoring scheme.

	df_list = []

	# Model name is the single CLI argument; it selects the result files.
	model = sys.argv[1]

	for cohort in ["sigir", "2021", "2022"]:
		df = pd.read_csv(f"results/trial_matching_{cohort}_{model}.csv")
		# Prefix patient ids with the cohort so ids stay unique after concat.
		df["patient id"] = df["patient id"].apply(lambda x: cohort + " " + str(x))
		df_list.append(df)

	df = pd.concat(df_list)

	# Random baseline scores (module seeds random at import for reproducibility).
	num_rows = len(df)
	random_scores = [random.uniform(0, 1) for _ in range(num_rows)]

	# Remove "not applicable" criteria from the criterion-count denominators.
	df["inclusion"] = df["inclusion"] - df["inclusion not applicable"]
	df["exclusion"] = df["exclusion"] - df["exclusion not applicable"]

	df["random"] = random_scores
	# Fractions of inclusion criteria met / unmet (unmet counts negatively).
	df["% inc"] = df["included"] / df["inclusion"]
	df["% not inc"] = - df["not included"] / df["inclusion"]
	df["bool not inc"] = - (df["not included"] > 0).astype(float)

	# Fractions of exclusion criteria triggered / avoided (triggered is bad).
	df["% exc"] = - df["excluded"] / df["exclusion"]
	df["% not exc"] = df["not excluded"] / df["exclusion"]
	df["bool exc"] = - (df["excluded"] > 0).astype(float)

	# Combined score: inclusion coverage plus hard penalties for any
	# triggered exclusion / unmet inclusion, with a small additive
	# relevance/eligibility term (divided by 100) as a tiebreaker.
	df["comb"] = df["% inc"] + df["bool exc"] + df["bool not inc"] + (df["relevance"] + df["eligibility"]) / 100

	# Rows whose denominators above were zero became NaN; drop them.
	df = df.dropna()

	patient_index = df.groupby("patient id")

	score_names = ["comb", "% inc", "% not inc", "% exc", "% not exc", "bool not inc", "bool exc", "random", "eligibility", "relevance"]
	ndcg_list = {score_name: [] for score_name in score_names}
	prec_list = {score_name: [] for score_name in score_names}
	mrr_list = {score_name: [] for score_name in score_names}
	auc_list = {score_name: [] for score_name in score_names}

	for patient_id, patient_data in patient_index:
		labels = patient_data["label"].tolist()

		# Skip patients whose labels are all identical: every ranking is
		# equivalent and the metrics are uninformative.  (The original
		# performed this exact check twice, once via Counter and once via
		# set; a single set-based check is sufficient.)
		if len(set(labels)) <= 1:
			continue

		for score_name in score_names:
			scores = patient_data[score_name].tolist()

			# NDCG@10 over the graded labels.
			ndcg = ndcg_score([labels], [scores], k=10)
			ndcg_list[score_name].append(ndcg)

			prec, mrr = get_metrics(labels, scores)
			prec_list[score_name].append(prec)
			mrr_list[score_name].append(mrr)

			# AUC is only computed for the TREC 2021/2022 cohorts.
			if "sigir" in patient_id:
				continue

			# Keep only relevant trials (label > 0) and negate the scores:
			# with pos_label=1 this measures how well low scores separate
			# label-1 from label-2 trials.
			filt_labels = []
			filt_scores = []

			for label, score in zip(labels, scores):
				if int(label) > 0:
					filt_labels.append(label)
					filt_scores.append(-score)

			# AUC is undefined when only one class remains.
			if len(set(filt_labels)) == 1:
				continue

			fpr, tpr, thr = roc_curve(filt_labels, filt_scores, pos_label=1)
			auc_list[score_name].append(auc(fpr, tpr))

	# Report macro-averages (mean over patients) per scoring scheme.
	print("Ranking NDCG@10")
	for score_name in score_names:
		ndcgs = ndcg_list[score_name]

		print(score_name, sum(ndcgs) / len(ndcgs))

	print("Ranking Prec@10")
	for score_name in score_names:
		precs = prec_list[score_name]

		print(score_name, sum(precs) / len(precs))

	print("Ranking MRR")
	for score_name in score_names:
		mrrs = mrr_list[score_name]

		print(score_name, sum(mrrs) / len(mrrs))

	print("Auc")
	for score_name in score_names:
		aucs = auc_list[score_name]

		print(score_name, sum(aucs) / len(aucs))

0 commit comments

Comments
 (0)