-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtext2embedding_Mimic-CXR.py
More file actions
103 lines (91 loc) · 3.2 KB
/
text2embedding_Mimic-CXR.py
File metadata and controls
103 lines (91 loc) · 3.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
import pandas as pd
import numpy as np
import torch
import random
from tqdm import tqdm
import argparse
import json
from PIL import Image
from openai import OpenAI
import re
def standardize_view_position_direct(df, column_name='ViewPosition'):
    """Collapse decubitus/axial view variants onto their base PA or AP label.

    Uses a direct dictionary lookup:
      - 'PA', 'PA LLD' and 'PA RLD' become 'PA';
      - 'AP', 'AP AXIAL', 'AP LLD' and 'AP RLD' become 'AP'.

    Any other value in the column (e.g. a lateral view label) is left
    unchanged.

    Args:
        df: Input DataFrame. It is copied, not modified in place.
        column_name: Name of the column holding the view label.

    Returns:
        A copy of ``df`` with ``column_name`` standardized.
    """
    variants_by_base = {
        'PA': ('PA', 'PA LLD', 'PA RLD'),
        'AP': ('AP', 'AP AXIAL', 'AP LLD', 'AP RLD'),
    }
    # Invert into a flat "variant -> base label" lookup table.
    lookup = {
        variant: base
        for base, variants in variants_by_base.items()
        for variant in variants
    }

    result = df.copy()
    original_values = result[column_name]
    # Labels absent from the table map to NaN; fillna restores the original.
    result[column_name] = original_values.map(lookup).fillna(original_values)
    return result
def load_text(path):
    """Read a MIMIC-CXR report file and return the text after its header.

    The body is everything after the first occurrence of the marker
    "FINAL REPORT" + newline + space. In that body, every newline followed
    by a single space is collapsed to a plain newline, removing the one-space
    indent on continuation lines.

    Args:
        path: Path to the report text file.

    Returns:
        The cleaned report body. If the marker is absent, the whole file
        content is returned (with the same cleanup) instead of raising
        ``IndexError``, so one odd report cannot abort a long batch run.
    """
    # Explicit encoding so the result does not depend on the platform default.
    with open(path, 'r', encoding='utf-8') as file:
        file_content = file.read()
    # partition keeps everything after the FIRST marker; split(...)[1] would
    # silently drop any text following a second occurrence of the marker.
    _, marker, body = file_content.partition("FINAL REPORT\n ")
    if not marker:
        body = file_content
    return body.replace('\n ', '\n')
def text_processing(full_text):
    """Drop everything before the first FINDINGS:/IMPRESSION: section header.

    Args:
        full_text: Report text to trim.

    Returns:
        ``full_text`` starting at whichever of ``FINDINGS:`` or
        ``IMPRESSION:`` occurs first, or ``full_text`` unchanged if neither
        header is present.
    """
    header_offsets = []
    for header_pattern in (r"FINDINGS:(.*?)", r"IMPRESSION:(.*?)"):
        hit = re.search(header_pattern, full_text, re.DOTALL)
        if hit is not None:
            header_offsets.append(hit.span()[0])

    if not header_offsets:
        return full_text
    # The earliest header wins, so a report with both sections keeps both.
    return full_text[min(header_offsets):]
def text2embedding(client, model, text):
    """Embed a single string with an OpenAI-compatible embeddings endpoint.

    Args:
        client: An ``OpenAI`` client (or compatible object) exposing
            ``embeddings.create``.
        model: Model id to request.
        text: The string to embed; it is sent as a one-element batch.

    Returns:
        The ``embedding`` attribute of the first (only) item in the response.
    """
    reply = client.embeddings.create(model=model, input=[text])
    first_item = reply.data[0]
    return first_item.embedding
def main():
    """Embed original and paraphrased MIMIC-CXR reports for frontal views.

    Steps:
      1. Read the processed dataset CSV.
      2. Standardize ``ViewPosition`` and keep only PA/AP rows.
      3. For each row, embed the findings/impression part of the original
         report and the ``paraphrased_note`` column via an OpenAI-compatible
         embeddings server.
      4. Write both embedding columns to a new CSV.

    Paraphrases equal to ``'Fail'``, or missing (NaN), get the string
    ``'Fail'`` in ``embeddings_2`` instead of an embedding.

    The server endpoint and key default to the original local values. They
    can be overridden with the ``EMBEDDING_API_BASE`` and
    ``EMBEDDING_API_KEY`` environment variables.
    """
    df = pd.read_csv('/data/code/CXR_embedding_research/processed_dataset.csv')
    df = standardize_view_position_direct(df)
    # reset_index() without drop=True keeps the pre-filter row number as an
    # 'index' column, which is also written to the output CSV.
    df = df[df['ViewPosition'].isin(["PA", "AP"])].reset_index()

    # NOTE(review): "abc123" looks like a placeholder key for a local server.
    openai_api_key = os.environ.get("EMBEDDING_API_KEY", "abc123")
    openai_api_base = os.environ.get("EMBEDDING_API_BASE", "http://localhost:8002/v1")
    client = OpenAI(
        api_key=openai_api_key,
        base_url=openai_api_base,
    )

    models = client.models.list()
    if not models.data:
        raise RuntimeError(f"No models are served at {openai_api_base}")
    # Use the first model the server lists.
    model = models.data[0].id

    embedding_1_rows = []
    embedding_2_rows = []
    for idx, row in tqdm(df.iterrows(), total=len(df)):
        note_1 = text_processing(load_text('/data/mimic3_cxr_jpg/' + row['path']))
        embedding_1_rows.append(text2embedding(client, model, note_1))

        note_2 = row['paraphrased_note']
        # An empty CSV cell is read back as a float NaN; treat it like an
        # explicit 'Fail' instead of sending a non-string to the API.
        if pd.isna(note_2) or note_2 == 'Fail':
            embedding_2_rows.append('Fail')
        else:
            embedding_2_rows.append(text2embedding(client, model, note_2))

        if idx % 1000 == 0:
            # tqdm.write prints without corrupting the progress bar.
            tqdm.write(str(idx))

    df['embeddings_1'] = embedding_1_rows
    df['embeddings_2'] = embedding_2_rows
    # Each embedding cell holds a Python list; to_csv writes its string form.
    df.to_csv('/data/mimic3_cxr_jpg/train_with_view_embeddings_aug.csv', encoding='utf8', index=False)


if __name__ == '__main__':
    main()