-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpseudo_labeling_rex.py
More file actions
476 lines (388 loc) · 16.5 KB
/
pseudo_labeling_rex.py
File metadata and controls
476 lines (388 loc) · 16.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
import os
import pandas as pd
import polars as pl
import numpy as np
import torch
import random
from tqdm import tqdm
import argparse
import json
from openai import OpenAI
import re
import ast
import logging
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
# Logging setup
# Shared lock serializing log writes from worker threads (used by
# process_single_row and the progress loop in main).
log_lock = threading.Lock()
# System prompt for the first LLM pass: instructs the model to emit a strict
# 13-key JSON object of disease labels, one value per condition, with
# 1 = present, 0 = explicitly absent, -1 = not mentioned / uncertain.
# The key order here matches label_order in extract_labels_from_text and
# label_columns in main.
labeling_prompt = '''
You are a medical AI assistant specialized in analyzing chest X-ray radiology reports. Your task is to extract structured disease labels from the report text.
Given a radiology report with Findings and Impression sections, identify the presence or absence of the following 13 conditions:
1. Atelectasis - Partial collapse of lung tissue
2. Cardiomegaly - Enlarged heart
3. Consolidation - Dense opacification of lung tissue (often pneumonia)
4. Edema - Fluid accumulation in lungs (pulmonary edema)
5. Enlarged Cardiomediastinum - Widening of the heart/mediastinal silhouette
6. Fracture - Bone fracture (usually ribs)
7. Lung Lesion - Nodule, mass, or other focal lung abnormality
8. Lung Opacity - Any area of increased opacity in the lungs
9. No Finding - Explicitly normal/unremarkable study
10. Pleural Effusion - Fluid in pleural space
11. Pleural Other - Other pleural abnormalities (thickening, plaques, etc.)
12. Pneumonia - Infectious consolidation/infiltrate
13. Pneumothorax - Air in pleural space (collapsed lung)
## Labeling Rules:
For each condition, assign:
- **1** if the condition is EXPLICITLY stated as PRESENT, IDENTIFIED, or OBSERVED
- **0** if the condition is EXPLICITLY stated as ABSENT, NEGATIVE, or NORMAL
- **-1** if the condition is NOT MENTIONED or UNCERTAIN
## Important Guidelines:
1. **Be Conservative**: Only mark as 1 if clearly present. Use -1 for uncertain cases.
2. **"No Finding" logic**:
- Mark as 1 ONLY if report explicitly states "no acute findings", "normal study", "unremarkable"
- If ANY abnormality is found, "No Finding" should be 0
3. **Terminology mapping**:
- "infiltrate" → Consolidation or Pneumonia (check context)
- "opacity" → Lung Opacity
- "airspace disease" → Consolidation or Edema
- "cardiomediastinal silhouette is normal" → Cardiomegaly=0, Enlarged Cardiomediastinum=0
- "clear lungs" → Most findings should be 0, consider "No Finding"=1
4. **Co-occurrence**: Multiple conditions can be 1 simultaneously
## Output Format:
Return ONLY a valid JSON object with exactly these keys and integer values (1, 0, or -1):
```json
{
"Atelectasis": 0,
"Cardiomegaly": 0,
"Consolidation": 0,
"Edema": 0,
"Enlarged Cardiomediastinum": 0,
"Fracture": 0,
"Lung Lesion": 0,
"Lung Opacity": -1,
"No Finding": 1,
"Pleural Effusion": 0,
"Pleural Other": 0,
"Pneumonia": 0,
"Pneumothorax": 0
}
```
Do NOT include any explanatory text, only the JSON object.
'''
# System prompt for the second (QA) LLM pass: given the report plus the
# labels extracted by the first pass, the model returns a JSON verdict with
# is_correct (bool), issues (list of strings), and confidence (float 0-1).
# Consumed by verify_labels.
verification_prompt = '''
You are a medical quality assurance AI specialized in verifying disease label accuracy from radiology reports.
You will be given:
1. The original radiology report (Findings and Impression)
2. The extracted disease labels (13 conditions with values 1/0/-1)
Your task is to verify if the extracted labels are correct based on the report text.
## Verification Criteria:
For each label, check:
- **Labels marked as 1 (Present)**: Is there clear evidence in the report?
- **Labels marked as 0 (Absent)**: Is there explicit mention of absence/normality?
- **Labels marked as -1 (Uncertain)**: Is it truly not mentioned or ambiguous?
## Common Errors to Check:
1. **Over-labeling**: Marking as 1 when evidence is weak or uncertain
2. **Under-labeling**: Marking as -1 when condition is clearly mentioned
3. **"No Finding" conflicts**: "No Finding"=1 but other conditions are also marked as 1
4. **Terminology mismatches**: Missing synonyms (e.g., "infiltrate" for pneumonia)
## Output Format:
Return a JSON object with:
- **"is_correct"**: true/false (whether all labels are accurate)
- **"issues"**: list of specific problems found (empty list if correct)
- **"confidence"**: float 0-1 (how confident you are in the original labels)
```json
{
"is_correct": true,
"issues": [],
"confidence": 0.95
}
```
OR if issues found:
```json
{
"is_correct": false,
"issues": [
"Atelectasis marked as 1 but report only mentions 'minimal bibasilar atelectasis' which is uncertain",
"No Finding marked as 1 but Pleural Effusion is also marked as 1 - contradiction"
],
"confidence": 0.6
}
```
Return ONLY the JSON object, no additional text.
'''
def remove_think_tags(text):
    """Strip every <think>...</think> reasoning span from an LLM response.

    Non-greedy match; re.DOTALL lets spans cross newlines.
    """
    think_span = re.compile(r'<think>.*?</think>', re.DOTALL)
    return think_span.sub('', text)
def extract_json_from_text(text):
    """Extract the first JSON object embedded in *text*.

    Handles responses where the model wraps the JSON in prose or code
    fences. The regex supports one level of brace nesting, which is
    sufficient for the flat label/verification dicts used in this script.

    Returns:
        The parsed dict, or None if no parseable object is found.
    """
    # Find a {...} span, allowing one nested level of braces.
    match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', text, re.DOTALL)
    if not match:
        return None
    json_str = match.group()
    try:
        return json.loads(json_str)
    except json.JSONDecodeError:
        # Fallback for Python-literal-style dicts (single quotes, True/None).
        # The original bare `except:` swallowed everything, including
        # KeyboardInterrupt; narrowed to the exceptions ast.literal_eval
        # actually raises on bad input.
        try:
            return ast.literal_eval(json_str)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            return None
def verify_labels(client, model, findings, impression, labels_dict):
    """
    Run a QA LLM pass that audits previously extracted labels for one report.

    Args:
        client: OpenAI client
        model: Model name
        findings: Findings section text
        impression: Impression section text
        labels_dict: Dictionary of extracted labels

    Returns:
        dict with keys: is_correct (bool), issues (list), confidence (float)
        or None if verification failed
    """
    user_message = (
        f"Original Report:\n"
        f"Findings: {findings}\n"
        f"Impression: {impression}\n"
        f"Extracted Labels:\n"
        f"{json.dumps(labels_dict, indent=2)}"
    )
    try:
        completion = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": verification_prompt},
                {"role": "user", "content": user_message},
            ],
            temperature=0.1,
        )
        reply = remove_think_tags(completion.choices[0].message.content)
        verdict = extract_json_from_text(reply)
        # Require at least the is_correct key; anything else is unusable.
        if verdict and "is_correct" in verdict:
            return verdict
        return None
    except Exception as e:
        # Best-effort: a failed verification call is reported as None, never raised.
        logging.warning(f"Verification failed: {str(e)}")
        return None
def extract_labels_from_text(client, model, findings, impression, retry=3, enable_verification=True):
    """
    Extract disease labels from clinical text using LLM with optional verification.

    Args:
        client: OpenAI client
        model: Model name
        findings: Findings section text
        impression: Impression section text
        retry: Number of retry attempts
        enable_verification: Whether to use the second-pass verification step

    Returns:
        List of 13 integer labels (1/0/-1) or None if failed
    """
    # Fixed output order; must stay in sync with label_columns in main().
    # Hoisted out of the retry loop (was rebuilt on every attempt).
    label_order = [
        "Atelectasis", "Cardiomegaly", "Consolidation", "Edema",
        "Enlarged Cardiomediastinum", "Fracture", "Lung Lesion",
        "Lung Opacity", "No Finding", "Pleural Effusion",
        "Pleural Other", "Pneumonia", "Pneumothorax"
    ]
    text = f"Findings: {findings}\nImpression: {impression}"
    for attempt in range(retry):
        try:
            # Step 1: Initial labeling
            completion = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": labeling_prompt},
                    {"role": "user", "content": text}
                ],
                temperature=0.1  # Low temperature for consistent extraction
            )
            response = remove_think_tags(completion.choices[0].message.content)
            # Extract JSON; None means the model emitted no parseable object.
            labels_dict = extract_json_from_text(response)
            if labels_dict is None:
                if attempt < retry - 1:
                    continue
                return None
            # Convert to ordered list; missing keys default to -1 (uncertain).
            labels = [labels_dict.get(key, -1) for key in label_order]
            # Validate: every value must be -1, 0, or 1.
            if not all(label in (-1, 0, 1) for label in labels):
                if attempt < retry - 1:
                    continue
                return None
            if not enable_verification:
                # No verification requested, return labels directly.
                return labels
            # Step 2: Verification pass.
            verification = verify_labels(client, model, findings, impression, labels_dict)
            if verification is None:
                # Verification API failed, but labels are valid - accept them.
                logging.warning(f"Verification API failed, accepting labels anyway (attempt {attempt+1}/{retry})")
                return labels
            if verification.get("is_correct", False):
                # Labels verified as correct.
                confidence = verification.get("confidence", 1.0)
                logging.debug(f"Labels verified correct with confidence {confidence:.2f}")
                return labels
            # Labels have issues; retry if attempts remain.
            issues = verification.get("issues", [])
            confidence = verification.get("confidence", 0.0)
            logging.info(f"Verification failed (attempt {attempt+1}/{retry}): confidence={confidence:.2f}, issues={issues}")
            if attempt < retry - 1:
                continue
            # Last attempt failed verification; return the labels anyway.
            logging.warning(f"Final attempt failed verification, returning labels anyway: {issues}")
            return labels
        except Exception:
            if attempt < retry - 1:
                continue
            # Bare `raise` preserves the original traceback (was `raise e`,
            # which re-raises from this frame and obscures the source).
            raise
    return None
def process_single_row(args):
    """Label one DataFrame row; returns (idx, labels, success, error_message)."""
    idx, row, client, model, enable_verification = args
    start_time = datetime.now()
    try:
        labels = extract_labels_from_text(
            client, model,
            row['Findings'],
            row['Impression'],
            retry=3,
            enable_verification=enable_verification
        )
        success = labels is not None
        if not success:
            # Extraction failed: fall back to all-uncertain labels.
            labels = [-1] * 13
        duration = (datetime.now() - start_time).total_seconds()
        # Only log failures (success is logged every 1000 in main loop)
        if not success:
            with log_lock:
                logging.warning(f"Failed extraction - idx: {idx}, PatientID: {row['PatientID']}, time: {duration:.2f}s")
        return idx, labels, success, None
    except Exception as e:
        duration = (datetime.now() - start_time).total_seconds()
        with log_lock:
            logging.error(f"Error - idx: {idx}, time: {duration:.2f}s, error: {str(e)}, Thread: {threading.current_thread().name}")
        # Return all -1 on error so the row is still written back.
        return idx, [-1] * 13, False, str(e)
def main():
    """Pseudo-label the ReX dataset in parallel and write the labeled CSV."""
    # Parse command line arguments
    parser = argparse.ArgumentParser(description='Pseudo labeling for ReX dataset with verification')
    # NOTE(review): --enable-verification is a no-op: with action='store_true'
    # and default=True the value is True whether or not the flag is passed.
    # Only --disable-verification actually changes the outcome.
    parser.add_argument('--enable-verification', action='store_true', default=True,
                        help='Enable verification step (default: True)')
    parser.add_argument('--disable-verification', action='store_true', default=False,
                        help='Disable verification step')
    parser.add_argument('--max-workers', type=int, default=64,
                        help='Number of parallel workers (default: 64)')
    args = parser.parse_args()
    # Determine verification setting (disable flag wins; see NOTE above).
    enable_verification = args.enable_verification and not args.disable_verification
    # Logging configuration - both file and console
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    # File handler (hardcoded path - assumes this host's layout; confirm before reuse)
    file_handler = logging.FileHandler('/data/code/CXR_embedding_research/pseudo_labeling_logs.txt')
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter('%(asctime)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S'))
    logger.addHandler(file_handler)
    # Console handler (for progress logs)
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(logging.Formatter('%(asctime)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S'))
    logger.addHandler(console_handler)
    # Load ReX dataset (read via polars, then convert to pandas for .at/.iterrows use below)
    logging.info("Loading ReX dataset...")
    rex_df = pl.read_csv('/data/ReXGradient-160K/metadata/train_with_view_embeddings_aug.csv').to_pandas()
    total_rows = len(rex_df)
    logging.info(f"Total rows to process: {total_rows}")
    logging.info(f"Verification enabled: {enable_verification}")
    # Initialize OpenAI client
    # NOTE(review): "sk-1234" with a litellm base_url looks like a local
    # LiteLLM-proxy placeholder key, not a real OpenAI credential - confirm.
    openai_api_key = "sk-1234"
    openai_api_base = "http://litellm-litellm-1:4000/v1"
    client = OpenAI(
        api_key=openai_api_key,
        base_url=openai_api_base,
    )
    model = 'Qwen3-235B-A22B-Instruct-2507-INT4-W4A16'
    logging.info(f"Using model: {model}")
    # Initialize label columns (order must match label_order in extract_labels_from_text)
    label_columns = [
        'Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema',
        'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion',
        'Lung Opacity', 'No Finding', 'Pleural Effusion',
        'Pleural Other', 'Pneumonia', 'Pneumothorax'
    ]
    for col in label_columns:
        rex_df[col] = -1  # Initialize with -1 (uncertain)
    # Prepare arguments for parallel processing (one tuple per row)
    process_args = [(idx, row, client, model, enable_verification) for idx, row in rex_df.iterrows()]
    # Process with ThreadPoolExecutor (threads are fine: work is I/O-bound API calls)
    max_workers = args.max_workers
    logging.info(f"Starting pseudo labeling with {max_workers} workers...")
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit all tasks
        future_to_idx = {executor.submit(process_single_row, args): args[0] for args in process_args}
        completed_count = 0
        success_count = 0
        fail_count = 0
        # Process completed tasks as they finish (arbitrary order; idx keys the write-back)
        for future in as_completed(future_to_idx):
            try:
                idx, labels, success, error = future.result()
                # Update DataFrame
                for i, col in enumerate(label_columns):
                    rex_df.at[idx, col] = labels[i]
                completed_count += 1
                if success:
                    success_count += 1
                else:
                    fail_count += 1
                # Log progress every 1000 samples
                if completed_count % 1000 == 0:
                    with log_lock:
                        logging.info(f"Progress: {completed_count}/{total_rows} ({completed_count/total_rows*100:.1f}%) - Success: {success_count}, Failed: {fail_count}")
            except Exception as e:
                # future.result() raised: process_single_row itself crashed.
                idx = future_to_idx[future]
                with log_lock:
                    logging.error(f"Future processing error - idx: {idx}, error: {str(e)}")
                # Set all labels to -1 for this row
                for col in label_columns:
                    rex_df.at[idx, col] = -1
                fail_count += 1
                completed_count += 1
    logging.info(f"Pseudo labeling completed - Total: {total_rows}, Success: {success_count}, Failed: {fail_count}")
    # Save results
    output_path = '/data/ReXGradient-160K/metadata/train_with_view_embeddings_aug_labeled.csv'
    rex_df.to_csv(output_path, index=False)
    logging.info(f"Results saved to: {output_path}")
    # Print statistics
    print(f"\n{'='*60}")
    print(f"Pseudo Labeling Complete!")
    print(f"{'='*60}")
    print(f"Total samples: {total_rows}")
    print(f"Successfully labeled: {success_count} ({success_count/total_rows*100:.1f}%)")
    print(f"Failed: {fail_count} ({fail_count/total_rows*100:.1f}%)")
    print(f"Output saved to: {output_path}")
    print(f"{'='*60}\n")
    # Print label distribution (per-class counts of 1 / 0 / -1)
    print("\nLabel Distribution:")
    print("-" * 60)
    for col in label_columns:
        positive = (rex_df[col] == 1).sum()
        negative = (rex_df[col] == 0).sum()
        uncertain = (rex_df[col] == -1).sum()
        print(f"{col:30s} | +1: {positive:6d} | 0: {negative:6d} | -1: {uncertain:6d}")
    print("-" * 60)
if __name__ == "__main__":
    main()