Skip to content

Commit 8b763a8

Browse files
committed
[+] add get_results.py
1 parent a940751 commit 8b763a8

File tree

2 files changed

+336
-1
lines changed

2 files changed

+336
-1
lines changed

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@ torch
22
pandas
33
numpy
44
matplotlib
5-
nnunetv2
5+
nnunetv2
6+
scipy

scripts/get_results.py

Lines changed: 334 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,334 @@
1+
#!/usr/bin/env python3
2+
import os
3+
import argparse
4+
import logging
5+
import pandas as pd
6+
import numpy as np
7+
from scipy.stats import mannwhitneyu
8+
9+
def setup_logging():
    """Set up root logging at INFO level with a compact level/message format."""
    fmt = '%(levelname)s: %(message)s'
    logging.basicConfig(level=logging.INFO, format=fmt)
12+
13+
def escape_latex(s: str) -> str:
    """Return *s* with every underscore escaped for safe LaTeX output."""
    return r"\_".join(s.split("_"))
16+
17+
def is_baseline(trainer: str) -> bool:
    """
    Return True when *trainer* is a baseline model.

    A trainer counts as a baseline when its name mentions either
    "TotalSegmentator" or "ARTPLAN".
    """
    return any(tag in trainer for tag in ("TotalSegmentator", "ARTPLAN"))
23+
24+
def parse_arguments():
    """Build the CLI parser for this script and return the parsed arguments."""
    description = "Generate LaTeX table for Dice comparison across datasets."
    datasets_help = "Comma-separated list of dataset numbers (e.g., '500,67,297') or 'all' to include all."
    cli = argparse.ArgumentParser(description=description)
    cli.add_argument("--datasets", type=str, default="all", help=datasets_help)
    return cli.parse_args()
34+
35+
def get_dataset_folders(base_dir: str, selected_datasets: set, dataset_map: dict) -> list:
    """
    List dataset folder paths under *base_dir*, optionally filtered.

    Parameters:
        base_dir (str): Directory that holds one sub-folder per dataset.
        selected_datasets (set): Dataset numbers to keep, or None/empty for all.
        dataset_map (dict): Maps dataset numbers to their folder names.

    Returns:
        list: Paths of the dataset folders that passed the filter.
    """
    folders = []
    for entry in os.listdir(base_dir):
        candidate = os.path.join(base_dir, entry)
        if os.path.isdir(candidate):
            folders.append(candidate)
    if not selected_datasets:
        return folders
    # Only numbers present in the mapping can be resolved to folder names.
    wanted = {dataset_map[num] for num in selected_datasets if num in dataset_map}
    return [path for path in folders if os.path.basename(path) in wanted]
56+
57+
def process_dataset_folder(dataset_path: str, results: dict) -> str:
    """
    Collect per-trainer Dice values from one dataset folder.

    Each trainer sub-folder is expected to contain a
    "patient_wise_metrics.csv" file; its "dice-*" columns are read and the
    non-NaN values stored in *results* under results[trainer][roi][dataset].

    Parameters:
        dataset_path (str): Path to the dataset folder.
        results (dict): Nested dictionary, updated in place.

    Returns:
        str: Basename of the dataset folder.
    """
    dataset_name = os.path.basename(dataset_path)
    logging.info(f"Processing dataset: {dataset_name}")

    for entry in os.listdir(dataset_path):
        trainer_path = os.path.join(dataset_path, entry)
        if not os.path.isdir(trainer_path):
            continue
        trainer_name = os.path.basename(trainer_path)
        metrics_csv = os.path.join(trainer_path, "patient_wise_metrics.csv")
        if not os.path.exists(metrics_csv):
            logging.warning(f"{metrics_csv} not found. Skipping trainer {trainer_name}.")
            continue

        df = pd.read_csv(metrics_csv)
        for dice_col in df.columns:
            # Only "dice-<roi>" columns hold values of interest.
            if not dice_col.startswith("dice-"):
                continue
            # Crop the ROI name to its first 4 characters for brevity.
            roi = dice_col.split("-")[1][:4]
            dice_values = df[dice_col].dropna().values
            results.setdefault(trainer_name, {}).setdefault(roi, {})[dataset_name] = {
                "dice": dice_values
            }
    return dataset_name
100+
101+
def compute_statistical_tests(results: dict):
    """
    Compare every non-baseline trainer against the pooled baselines.

    For each (trainer, ROI, dataset) a two-sided Mann–Whitney U test is run
    against the Dice values of all baseline trainers for the same ROI and
    dataset pooled together (the test is appropriate for non-normal data).
    The maximum per-baseline mean Dice (i.e. the best performing baseline)
    is stored alongside the p-value.

    Updates *results* in place, adding "pvalue" and "max_baseline" keys to
    each non-baseline record that has at least one matching baseline.
    """
    # Pool the baseline dice values once per (roi, dataset) instead of
    # rescanning every baseline trainer for each individual comparison.
    pooled_dice = {}   # (roi, dataset) -> list of pooled baseline dice values
    pooled_means = {}  # (roi, dataset) -> list of per-baseline mean dice
    for trainer, rois in results.items():
        if not is_baseline(trainer):
            continue
        for roi, dataset_metrics in rois.items():
            for dataset, metrics in dataset_metrics.items():
                key = (roi, dataset)
                b_dice = metrics["dice"]
                pooled_dice.setdefault(key, []).extend(b_dice)
                pooled_means.setdefault(key, []).append(np.mean(b_dice))

    for trainer, rois in results.items():
        if is_baseline(trainer):
            continue
        for roi, dataset_metrics in rois.items():
            for dataset, metrics in dataset_metrics.items():
                key = (roi, dataset)
                baseline_dice = pooled_dice.get(key)
                baseline_means = pooled_means.get(key)
                # Skip when no baseline data exists, and guard against an
                # empty sample: mannwhitneyu raises ValueError on empty input.
                if not baseline_dice or not baseline_means or len(metrics["dice"]) == 0:
                    continue
                # Two-sided Mann–Whitney U test against the pooled baselines.
                _, pvalue = mannwhitneyu(
                    metrics["dice"], baseline_dice, alternative='two-sided'
                )
                metrics["pvalue"] = pvalue
                # Best performing baseline mean Dice for this ROI/dataset.
                metrics["max_baseline"] = max(baseline_means)
128+
129+
def determine_best_scores(results: dict, dataset_names: list, trainers: list, all_rois: list) -> dict:
    """
    Find, per ROI and dataset, the trainer with the highest mean Dice.

    Ties keep the earliest trainer in *trainers* order (only a strictly
    greater mean replaces the current best).

    Returns:
        dict: best_scores[roi][dataset] -> best trainer name, or None when
        no trainer has data for that ROI/dataset combination.
    """
    best_scores = {}
    for roi in all_rois:
        per_dataset = {}
        for dataset in dataset_names:
            top_mean = -1
            top_trainer = None
            for trainer in trainers:
                record = results.get(trainer, {}).get(roi, {}).get(dataset)
                if record is None:
                    continue
                mean_dice = np.mean(record["dice"])
                if mean_dice > top_mean:
                    top_mean = mean_dice
                    top_trainer = trainer
            per_dataset[dataset] = top_trainer
        best_scores[roi] = per_dataset
    return best_scores
150+
151+
def format_cell(trainer: str, roi: str, dataset_names: list, results: dict, best_scores: dict) -> str:
    """
    Render the table cell for one (trainer, ROI) pair.

    The cell joins one entry per dataset with "/".  Each entry is the mean
    Dice in percent, rounded to an integer, decorated as follows:
      - a dagger when the trainer is not a baseline, its p-value is below
        0.05, its rounded score strictly beats the rounded best-baseline
        score, and the dataset is one of
        ["Dataset067_Pediatric_Internal", "Dataset500_TCIA"];
      - \\textbf{...} when this trainer is the best for the ROI/dataset;
      - "-" when no (non-NaN) value is available;
      - "*" when an ARTPLAN model does not segment the ROI at all
        (pancreas "Panc" / gallbladder "Gall").
    """
    dagger_datasets = ["Dataset067_Pediatric_Internal", "Dataset500_TCIA"]
    pieces = []
    for dname in dataset_names:
        rec = results.get(trainer, {}).get(roi, {}).get(dname)
        if rec is None:
            # ARTPLAN never segments pancreas/gallbladder: mark with "*";
            # any other missing physician reference is "-".
            pieces.append("*" if "ARTPLAN" in trainer and roi in {"Panc", "Gall"} else "-")
            continue
        dice = rec["dice"]
        dice_mean = np.nanmean(dice) * 100 if len(dice) > 0 else float('nan')
        if np.isnan(dice_mean):
            pieces.append("-")
            continue
        # Compare at display precision (nearest integer) so the dagger never
        # contradicts the printed, rounded numbers.
        rounded_model = round(dice_mean)
        rounded_baseline = round(rec.get("max_baseline", 0) * 100)
        dice_text = f"{rounded_model:.0f}"
        pvalue = rec.get("pvalue")
        if (not is_baseline(trainer)
                and pvalue is not None
                and pvalue < 0.05
                and rounded_model > rounded_baseline
                and dname in dagger_datasets):
            dice_text += r"$^{\dagger}$"
        # Bold the best score for this ROI and dataset.
        if trainer == best_scores.get(roi, {}).get(dname):
            dice_text = r"\textbf{" + dice_text + "}"
        pieces.append(dice_text)
    return "/".join(pieces)
199+
200+
def build_latex_table(results: dict, dataset_names: list, all_rois: list,
                      trainers: list, best_scores: dict, latex_path: str):
    """
    Assemble the LaTeX Dice-comparison table and write it to *latex_path*.

    Trainers are grouped into Baseline / Direct / Hybrid / Transfer
    learning sections (by name heuristics), rows use alternating shading,
    and cell content (best-score bolding, significance daggers) is
    delegated to format_cell().
    """
    # Reorder datasets as 297, 500, 67 using the mapped folder names.
    # NOTE(fix): this list must match the folder names used in dataset_map
    # (see main()); the previous "Dataset067_pediatric" entry did not match
    # "Dataset067_Pediatric_Internal" and silently dropped that dataset.
    desired_order = ["Dataset297_TotalSegmentator", "Dataset500_TCIA", "Dataset067_Pediatric_Internal"]
    dataset_names = [d for d in desired_order if d in dataset_names]

    # Remove the Liver ROI (abbreviated as "Live")
    all_rois = [roi for roi in all_rois if roi.lower() != "live"]

    latex_lines = [
        r"\begin{sidewaystable}[htbp]",
        r"\centering",
        r"\caption{Dice coefficient (\%) comparison across datasets (adult / pediatric / pediatric internal).}"
    ]

    # One label column plus one column per ROI.
    n_columns = 1 + len(all_rois)
    col_spec = "l" + "c" * len(all_rois)
    latex_lines.append(r"\begin{tabular}{" + col_spec + "}")
    latex_lines.append(r"\toprule")

    # Header row for the ROI columns, styled like the category rows.
    header_cells = [r"\rowcolor{gray!30} Trainer"] + [escape_latex(roi) for roi in all_rois]
    latex_lines.append(" & ".join(header_cells) + r" \\")
    latex_lines.append(r"\midrule")

    # Categorize trainers based on naming heuristics.
    direct_learning = []
    hybrid_learning = []
    transfer_learning = []
    baseline = []
    for trainer in trainers:
        # Heuristic: "...T_o$" trainers are "direct" when the characters at
        # positions 3 and 6 match (same source/target), "hybrid" otherwise.
        if trainer.endswith("T_o$") and len(trainer) > 6 and trainer[3] == trainer[6]:
            direct_learning.append(trainer)
        elif trainer.endswith("T_o$") and len(trainer) > 6 and trainer[3] != trainer[6]:
            hybrid_learning.append(trainer)
        elif is_baseline(trainer):
            baseline.append(trainer)
        else:
            transfer_learning.append(trainer)

    sorted_categories = [
        ("Baseline", baseline),
        ("Direct Learning", direct_learning),
        ("Hybrid Learning", hybrid_learning),
        ("Transfer Learning", transfer_learning)
    ]

    # For alternating row colors in trainer rows.
    row_counter = 0
    for category, trainer_list in sorted_categories:
        if trainer_list:
            latex_lines.append(r"\midrule")
            latex_lines.append(r"\rowcolor{gray!30}")
            # NOTE(fix): span the actual table width instead of the
            # previously hard-coded 12 columns.
            latex_lines.append(r"\multicolumn{" + str(n_columns) + r"}{c}{\textbf{" + category + r"}} \\")
            latex_lines.append(r"\midrule")
            for trainer in trainer_list:
                row_prefix = ""
                if row_counter % 2 == 0:
                    row_prefix = r"\rowcolor{gray!10} "
                row_cells = [row_prefix + trainer]  # Trainer names are not escaped
                for roi in all_rois:
                    cell_text = format_cell(trainer, roi, dataset_names, results, best_scores)
                    row_cells.append(cell_text)
                latex_lines.append(" & ".join(row_cells) + r" \\")
                row_counter += 1

    latex_lines.extend([
        r"\bottomrule",
        r"\end{tabular}",
        r"\begin{tablenotes}",
        r"\footnotesize",
        r"\footnotesize \textbf{Bold} marks the best DSC for each ROI/dataset. ROI names are abbreviated: Blad (Bladder), Duod (Duodenum), Esop (Esophagus), Gall (Gallbladder), Kidn (Kidney), Panc (Pancreas), Pros (Prostate), Smal (Small Intestine), Spin (Spinal Canal), Sple (Spleen), Stom (Stomach). $^{\dagger}$ indicates a statistically significant improvement over the best performing baseline ($p < 0.05$). “-” indicates that the physician reference is unavailable and “*” that the model does not segment this ROI.",
        r"\end{tablenotes}",
        r"\end{sidewaystable}"
    ])

    with open(latex_path, "w") as f:
        f.write("\n".join(latex_lines))
    logging.info(f"LaTeX dice table saved to {latex_path}")
286+
287+
def main():
    """Entry point: collect metrics, run stats, and emit the LaTeX table."""
    setup_logging()
    args = parse_arguments()

    # Determine which datasets to include ("all" disables filtering).
    selected_datasets = None if args.datasets.lower() == "all" else set(args.datasets.split(","))

    base_dir = "nnUNet_predict"
    output_folder = "analysis_output"
    latex_folder = os.path.join(output_folder, "latex_tables")
    os.makedirs(latex_folder, exist_ok=True)
    tex_filename = os.path.join(latex_folder, "dice_comparison.tex")

    # Mapping dataset numbers to folder names
    dataset_map = {
        "297": "Dataset297_TotalSegmentator",
        "500": "Dataset500_TCIA",
        "67": "Dataset067_Pediatric_Internal"
    }

    dataset_folders = get_dataset_folders(base_dir, selected_datasets, dataset_map)

    results = {}
    dataset_names = []
    for dataset_path in dataset_folders:
        dataset_name = process_dataset_folder(dataset_path, results)
        if dataset_name not in dataset_names:
            dataset_names.append(dataset_name)
    # NOTE(fix): the loop above already deduplicates while preserving the
    # discovery order; the previous list(set(dataset_names)) pass was
    # redundant and made the ordering nondeterministic.

    # Determine the union of all ROIs across trainers
    all_rois = sorted({roi for trainer_data in results.values() for roi in trainer_data.keys()})

    # Sorted list of trainer names
    trainers = sorted(results.keys())

    # Compute statistical tests comparing each trainer to the baseline
    compute_statistical_tests(results)

    # Determine the best trainer (highest mean Dice) for each ROI and dataset
    best_scores = determine_best_scores(results, dataset_names, trainers, all_rois)

    # Build and save the LaTeX table
    build_latex_table(results, dataset_names, all_rois, trainers, best_scores, tex_filename)
332+
333+
# Script entry point: run only when executed directly, not on import.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)