Skip to content

Commit 6b582b8

Browse files
committed
updated codebase
1 parent 97ba047 commit 6b582b8

File tree

1 file changed

+120
-0
lines changed

1 file changed

+120
-0
lines changed

trialgpt_ranking/rank_results.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
__author__ = "qiao"
2+
3+
"""
4+
Rank the trials given the matching and aggregation results
5+
"""
6+
7+
import json
8+
import sys
9+
10+
eps = 1e-9
11+
12+
def get_matching_score(matching):
13+
# count only the valid ones
14+
included = 0
15+
not_inc = 0
16+
na_inc = 0
17+
no_info_inc = 0
18+
19+
excluded = 0
20+
not_exc = 0
21+
na_exc = 0
22+
no_info_exc = 0
23+
24+
# first count inclusions
25+
for criteria, info in matching["inclusion"].items():
26+
27+
if len(info) != 3:
28+
continue
29+
30+
if info[2] == "included":
31+
included += 1
32+
elif info[2] == "not included":
33+
not_inc += 1
34+
elif info[2] == "not applicable":
35+
na_inc += 1
36+
elif info[2] == "not enough information":
37+
no_info_inc += 1
38+
39+
# then count exclusions
40+
for criteria, info in matching["exclusion"].items():
41+
42+
if len(info) != 3:
43+
continue
44+
45+
if info[2] == "excluded":
46+
excluded += 1
47+
elif info[2] == "not excluded":
48+
not_exc += 1
49+
elif info[2] == "not applicable":
50+
na_exc += 1
51+
elif info[2] == "not enough information":
52+
no_info_exc += 1
53+
54+
# get the matching score
55+
score = 0
56+
57+
score += included / (included + not_inc + no_info_inc + eps)
58+
59+
if not_inc > 0:
60+
score -= 1
61+
62+
if excluded > 0:
63+
score -= 1
64+
65+
return score
66+
67+
68+
def get_agg_score(assessment):
69+
try:
70+
rel_score = float(assessment["relevance_score_R"])
71+
eli_score = float(assessment["eligibility_score_E"])
72+
except:
73+
rel_score = 0
74+
eli_score = 0
75+
76+
score = (rel_score + eli_score) / 100
77+
78+
return score
79+
80+
81+
if __name__ == "__main__":
82+
# args are the results paths
83+
matching_results_path = sys.argv[1]
84+
agg_results_path = sys.argv[2]
85+
86+
# loading the results
87+
matching_results = json.load(open(matching_results_path))
88+
agg_results = json.load(open(agg_results_path))
89+
90+
# loop over the patients
91+
for patient_id, label2trial2results in matching_results.items():
92+
93+
trial2score = {}
94+
95+
for _, trial2results in label2trial2results.items():
96+
for trial_id, results in trial2results.items():
97+
98+
matching_score = get_matching_score(results)
99+
100+
if patient_id not in agg_results or trial_id not in agg_results[patient_id]:
101+
print(f"Patient {patient_id} Trial {trial_id} not in the aggregation results.")
102+
agg_score = 0
103+
else:
104+
agg_score = get_agg_score(agg_results[patient_id][trial_id])
105+
106+
trial_score = matching_score + agg_score
107+
108+
trial2score[trial_id] = trial_score
109+
110+
sorted_trial2score = sorted(trial2score.items(), key=lambda x: -x[1])
111+
112+
print()
113+
print(f"Patient ID: {patient_id}")
114+
print("Clinical trial ranking:")
115+
116+
for trial, score in sorted_trial2score:
117+
print(trial, score)
118+
119+
print("===")
120+
print()

0 commit comments

Comments
 (0)