-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_pseudo_labeling.py
More file actions
96 lines (83 loc) · 3.54 KB
/
test_pseudo_labeling.py
File metadata and controls
96 lines (83 loc) · 3.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""
Test script for pseudo labeling on a small sample with verification
Run this before processing the full dataset
"""
import os
import pandas as pd
import json
from openai import OpenAI
from pseudo_labeling_rex import extract_labels_from_text, verify_labels, extract_json_from_text, remove_think_tags
# Condition names, paired positionally with the values returned by
# extract_labels_from_text. Assumes that function returns labels in exactly
# this order -- TODO confirm against pseudo_labeling_rex.
_LABEL_COLUMNS = (
    'Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema',
    'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion',
    'Lung Opacity', 'No Finding', 'Pleural Effusion',
    'Pleural Other', 'Pneumonia', 'Pneumothorax',
)

_DEFAULT_CSV_PATH = '/data/ReXGradient-160K/metadata/train_with_view_embeddings_aug.csv'
_RULE = "=" * 60


def _as_text(value):
    """Return a report cell as a string, mapping missing cells (None/NaN) to ""."""
    # pandas reads empty CSV cells as float NaN, which has no len() and cannot
    # be sliced; normalize it so display and extraction both get a string.
    if pd.isna(value):
        return ""
    return str(value)


def _preview(text, limit=200):
    """Return *text* indented by one space, truncated to *limit* chars plus "..."."""
    return f" {text[:limit]}..." if len(text) > limit else f" {text}"


def _describe_label(label):
    """Map an extracted label value to a (symbol, description) pair for display."""
    if label == 1:
        return "✓", "Present (+1)"
    if label == 0:
        return "✗", "Absent (0)"
    if label == -1:
        return "?", "Uncertain (-1)"
    # Don't report an unexpected value as "-1"; show what was actually returned.
    return "?", f"Unexpected ({label!r})"


def _print_labels(labels):
    """Print a table pairing each condition name with its extracted label."""
    print("\n✅ Extracted Labels:")
    print("-" * 60)
    print(f"{'Condition':<30s} | Label")
    print("-" * 60)
    for column, label in zip(_LABEL_COLUMNS, labels):
        symbol, description = _describe_label(label)
        print(f"{column:<30s} | {symbol} {description}")
    print("-" * 60)
    if len(labels) != len(_LABEL_COLUMNS):
        # zip() above stops at the shorter sequence; make a mismatch visible.
        print(f"⚠️  Expected {len(_LABEL_COLUMNS)} labels, got {len(labels)}")


def _create_client(api_key, api_base):
    """Create an OpenAI-compatible client and return it with the first listed model id.

    Raises:
        RuntimeError: If the endpoint lists no models.
    """
    client = OpenAI(
        api_key=api_key,
        base_url=api_base,
    )
    served = client.models.list().data
    if not served:
        raise RuntimeError(f"No models are listed by the endpoint at {api_base}")
    return client, served[0].id


def _print_next_steps():
    """Print the commands for running the full pseudo labeling job."""
    print("\nIf the results look reasonable, run the full pseudo labeling:")
    print(" # With verification (default):")
    print(" python /data/code/CXR_embedding_research/pseudo_labeling_rex.py")
    print("")
    print(" # Without verification:")
    print(" python /data/code/CXR_embedding_research/pseudo_labeling_rex.py --disable-verification")
    print(_RULE)


def test_on_samples(n_samples=5, enable_verification=True, *,
                    csv_path=_DEFAULT_CSV_PATH,
                    api_key="sk-1234",
                    api_base="http://localhost:4000/v1"):
    """Run pseudo labeling on the first ``n_samples`` ReX reports and print the results.

    This is a manual smoke test meant to be run before the full
    ``pseudo_labeling_rex.py`` job. It prints to stdout and returns nothing.

    Args:
        n_samples: Number of leading CSV rows to read and label.
        enable_verification: Forwarded to ``extract_labels_from_text``.
        csv_path: ReX metadata CSV. It must contain ``PatientID``,
            ``Findings`` and ``Impression`` columns.
        api_key: API key sent to the OpenAI-compatible endpoint.
        api_base: Base URL of the OpenAI-compatible endpoint. The first model
            it lists is used.

    Raises:
        RuntimeError: If the endpoint lists no models.
    """
    print(_RULE)
    print("Testing Pseudo Labeling on Sample Data")
    print(f"Verification: {'ENABLED' if enable_verification else 'DISABLED'}")
    print(_RULE)

    print("\nLoading ReX dataset...")
    rex_df = pd.read_csv(csv_path, nrows=n_samples)
    total = len(rex_df)
    print(f"Loaded {total} samples")

    client, model = _create_client(api_key, api_base)
    print(f"Using model: {model}\n")

    failures = 0
    for position, (_, row) in enumerate(rex_df.iterrows(), start=1):
        findings = _as_text(row['Findings'])
        impression = _as_text(row['Impression'])

        print(f"\n{_RULE}")
        # Count against rows actually loaded; the CSV may hold fewer than n_samples.
        print(f"Sample {position}/{total}")
        print(_RULE)
        print(f"Patient ID: {row['PatientID']}")
        print("\nFindings:")
        print(_preview(findings))
        print("\nImpression:")
        print(_preview(impression))

        print("\nExtracting labels...")
        labels = extract_labels_from_text(
            client, model,
            findings,
            impression,
            retry=3,
            enable_verification=enable_verification,
        )
        if labels is None:
            failures += 1
            print("❌ Failed to extract labels!")
            continue
        _print_labels(labels)

    print(f"\n{_RULE}")
    print("Test Complete!")
    print(f"Labeled {total - failures}/{total} samples")
    print(_RULE)
    _print_next_steps()


# This file is named test_*.py, so pytest would otherwise collect this function
# and run it with its defaults (reading /data and calling a local endpoint).
test_on_samples.__test__ = False
if __name__ == "__main__":
    # Command-line entry point: parse options, then run the smoke test.
    import argparse

    cli = argparse.ArgumentParser(description='Test pseudo labeling with verification')
    cli.add_argument(
        '--n-samples',
        type=int,
        default=5,
        help='Number of samples to test',
    )
    cli.add_argument(
        '--disable-verification',
        action='store_true',
        help='Disable verification',
    )
    options = cli.parse_args()

    # The flag is phrased negatively on the CLI; the function takes the positive form.
    verify = not options.disable_verification
    test_on_samples(n_samples=options.n_samples, enable_verification=verify)