diff --git a/trialgpt_retrieval/hybrid_fusion_retrieval.py b/trialgpt_retrieval/hybrid_fusion_retrieval.py
index d6c7dd1..93d6482 100644
--- a/trialgpt_retrieval/hybrid_fusion_retrieval.py
+++ b/trialgpt_retrieval/hybrid_fusion_retrieval.py
@@ -223,3 +223,54 @@ def get_medcpt_corpus_index(corpus):
 
     with open(output_path, "w") as f:
         json.dump(qid2nctids, f, indent=4)
+
+    # Convert the flat {patient_id: [NCTID, ...]} retrieval result into a list
+    # of per-patient records carrying the patient note and trial metadata.
+    trial_info_path = "trial_info.json"
+    queries_path = f"dataset/{corpus}/queries.jsonl"
+
+    with open(trial_info_path, "r", encoding="utf-8") as f:
+        trial_info = json.load(f)
+
+    # queries.jsonl: one JSON object per line with "_id" and "text" keys.
+    patients = {}
+    with open(queries_path, "r", encoding="utf-8") as f:
+        for line in f:
+            line = line.strip()
+            if not line:
+                # Tolerate blank lines (e.g. a trailing newline) rather than
+                # crashing in json.loads.
+                continue
+            obj = json.loads(line)
+            patients[obj["_id"]] = obj["text"]
+
+    transformed_output = []
+
+    for patient_id, nctid_list in qid2nctids.items():
+        patient_entry = {
+            "patient_id": patient_id,
+            # Patients missing from queries.jsonl get an empty note.
+            "patient": patients.get(patient_id, ""),
+            "0": [],
+            "1": [],
+            "2": [],
+        }
+
+        for nctid in nctid_list:
+            trial_metadata = trial_info.get(nctid)
+            if not trial_metadata:
+                # NCTIDs without (non-empty) metadata are skipped.
+                continue
+            # Build a new dict instead of writing "NCTID" into the shared
+            # trial_info entry, which may be reused for other patients.
+            # All retrieved trials go under label "0"; "1" and "2" stay empty.
+            patient_entry["0"].append({**trial_metadata, "NCTID": nctid})
+
+        transformed_output.append(patient_entry)
+
+    # NOTE: this writes to the same path as the dump above, so the flat
+    # qid2nctids JSON is replaced by the transformed records.
+    transformed_output_path = output_path
+    with open(transformed_output_path, "w") as f:
+        json.dump(transformed_output, f, indent=4)
+