#!/usr/bin/env python3
import os
import argparse
import logging

import pandas as pd
import numpy as np
from scipy.stats import mannwhitneyu

def setup_logging():
    """Configure logging output."""
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

def escape_latex(s: str) -> str:
    """Escape underscores for LaTeX output."""
    return s.replace("_", r"\_")

def is_baseline(trainer: str) -> bool:
    """
    Determine if a trainer is considered a baseline model.
    Baseline trainers contain "TotalSegmentator" or "ARTPLAN" in their name.
    """
    return "TotalSegmentator" in trainer or "ARTPLAN" in trainer

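# Illustrative checks (these trainer names are hypothetical, not taken from the
# actual experiments):
#   is_baseline("TotalSegmentator_pretrained")  -> True
#   is_baseline("nnUNetTrainer__ARTPLAN")       -> True
#   is_baseline("nnUNetTrainer_$T_aT_o$")       -> False
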
def parse_arguments():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(
        description="Generate LaTeX table for Dice comparison across datasets."
    )
    parser.add_argument(
        "--datasets", type=str, default="all",
        help="Comma-separated list of dataset numbers (e.g., '500,67,297') or 'all' to include all."
    )
    return parser.parse_args()

def get_dataset_folders(base_dir: str, selected_datasets: set, dataset_map: dict) -> list:
    """
    Retrieve and filter dataset folders based on selected datasets.

    Parameters:
        base_dir (str): The directory containing dataset folders.
        selected_datasets (set): Set of dataset numbers to include (or None for all).
        dataset_map (dict): Mapping of dataset numbers to folder names.

    Returns:
        list: List of dataset folder paths.
    """
    all_folders = [
        os.path.join(base_dir, d)
        for d in os.listdir(base_dir)
        if os.path.isdir(os.path.join(base_dir, d))
    ]
    if selected_datasets:
        valid_folders = {dataset_map[ds] for ds in selected_datasets if ds in dataset_map}
        return [d for d in all_folders if os.path.basename(d) in valid_folders]
    return all_folders

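# Assumed on-disk layout, inferred from get_dataset_folders() and
# process_dataset_folder() below (not verified against the actual data):
#
#   nnUNet_predict/
#       Dataset297_TotalSegmentator/
#           <trainer_name>/
#               patient_wise_metrics.csv
#       Dataset500_TCIA/
#           ...
#       Dataset067_Pediatric_Internal/
#           ...
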
def process_dataset_folder(dataset_path: str, results: dict) -> str:
    """
    Process a single dataset folder: read trainer subfolders and update the results dictionary.

    Parameters:
        dataset_path (str): Path to the dataset folder.
        results (dict): Nested dictionary to store dice values and later metrics.

    Returns:
        str: The dataset name.
    """
    dataset_name = os.path.basename(dataset_path)
    logging.info(f"Processing dataset: {dataset_name}")

    trainer_folders = [
        os.path.join(dataset_path, d)
        for d in os.listdir(dataset_path)
        if os.path.isdir(os.path.join(dataset_path, d))
    ]

    for trainer_path in trainer_folders:
        trainer_name = os.path.basename(trainer_path)
        metrics_csv = os.path.join(trainer_path, "patient_wise_metrics.csv")
        if not os.path.exists(metrics_csv):
            logging.warning(f"{metrics_csv} not found. Skipping trainer {trainer_name}.")
            continue

        df = pd.read_csv(metrics_csv)
        # Identify dice columns (columns starting with "dice-")
        dice_columns = [col for col in df.columns if col.startswith("dice-")]
        if not dice_columns:
            continue

        for dice_col in dice_columns:
            # Crop ROI name to the first 4 characters for brevity
            roi = dice_col.split("-")[1][:4]
            dice_values = df[dice_col].dropna().values

            # Initialize the nested dictionary structure as needed
            results.setdefault(trainer_name, {}).setdefault(roi, {})[dataset_name] = {
                "dice": dice_values
            }
    return dataset_name

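# Sketch of the data this function expects and produces. The column names are
# illustrative; only the "dice-" prefix is required by the code above:
#
#   patient_wise_metrics.csv:   patient_id, dice-Bladder, dice-Spleen, ...
#
#   results["<trainer>"]["Blad"]["Dataset500_TCIA"] == {"dice": np.array([...])}
#   (compute_statistical_tests() below later adds "pvalue" and "max_baseline".)
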
def compute_statistical_tests(results: dict):
    """
    Compute p-values comparing each non-baseline trainer against the baseline for each ROI and dataset,
    using the Mann–Whitney U test for non-normal data.
    Additionally, store the maximum baseline mean Dice score (i.e. the best performing baseline)
    among all baselines.
    Updates the results dictionary in place.
    """
    for trainer, rois in results.items():
        if is_baseline(trainer):
            continue
        for roi, dataset_metrics in rois.items():
            for dataset, metrics in dataset_metrics.items():
                baseline_dice = []
                baseline_means = []
                for baseline_trainer, baseline_rois in results.items():
                    if is_baseline(baseline_trainer) and roi in baseline_rois and dataset in baseline_rois[roi]:
                        b_dice = baseline_rois[roi][dataset]["dice"]
                        baseline_dice.extend(b_dice)
                        baseline_means.append(np.mean(b_dice))
                # Skip the test when either sample is empty (mannwhitneyu would raise).
                if len(metrics["dice"]) > 0 and baseline_dice and baseline_means:
                    # Compute p-value using the Mann–Whitney U test
                    stat, pvalue = mannwhitneyu(metrics["dice"], baseline_dice, alternative='two-sided')
                    # Get the best performing baseline mean Dice score
                    max_baseline = max(baseline_means)
                    results[trainer][roi][dataset]["pvalue"] = pvalue
                    results[trainer][roi][dataset]["max_baseline"] = max_baseline

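# Minimal sketch of the comparison performed above, with made-up Dice values:
#
#   from scipy.stats import mannwhitneyu
#   model_dice    = [0.91, 0.88, 0.93]        # one trainer, one ROI, one dataset
#   baseline_dice = [0.85, 0.84, 0.90, 0.87]  # pooled over all baseline trainers
#   _, p = mannwhitneyu(model_dice, baseline_dice, alternative='two-sided')
#   # p < 0.05 is the threshold format_cell() later uses for the dagger marker.
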
def determine_best_scores(results: dict, dataset_names: list, trainers: list, all_rois: list) -> dict:
    """
    For each ROI and dataset, determine which trainer achieved the highest mean Dice score.

    Returns:
        dict: A nested dictionary mapping each ROI and dataset to the best trainer.
    """
    best_scores = {roi: {} for roi in all_rois}
    for roi in all_rois:
        for dataset in dataset_names:
            best_score = -1
            best_trainer = None
            for trainer in trainers:
                trainer_data = results.get(trainer, {})
                if roi in trainer_data and dataset in trainer_data[roi]:
                    score = np.mean(trainer_data[roi][dataset]["dice"])
                    if score > best_score:
                        best_score = score
                        best_trainer = trainer
            best_scores[roi][dataset] = best_trainer
    return best_scores

def format_cell(trainer: str, roi: str, dataset_names: list, results: dict, best_scores: dict) -> str:
    """
    Format a table cell for a given trainer and ROI.

    Displays mean Dice percentages per dataset and adds a dagger if all of the following hold:
    - The trainer is not a baseline,
    - The p-value is below 0.05,
    - The mean Dice score (in percentage) exceeds the maximum baseline mean (also in percentage),
      i.e. the model improves over every baseline, and
    - The dataset is one of ["Dataset067_Pediatric_Internal", "Dataset500_TCIA"].

    The best score is bolded.
    """
    cell_values = []
    for dname in dataset_names:
        trainer_data = results.get(trainer, {})
        if roi in trainer_data and dname in trainer_data[roi]:
            rec = trainer_data[roi][dname]
            # Calculate the mean dice and convert to percentage
            if len(rec["dice"]) > 0:
                dice_mean = np.nanmean(rec["dice"]) * 100
            else:
                dice_mean = float('nan')
            # Round the values to the same precision as displayed (nearest integer)
            if not np.isnan(dice_mean):
                rounded_model = round(dice_mean)
                rounded_baseline = round(rec.get("max_baseline", 0) * 100)
                dice_text = f"{rounded_model:.0f}"
                # Add dagger only if the rounded new score is strictly greater than the rounded baseline score.
                if (not is_baseline(trainer) and
                    rec.get("pvalue") is not None and
                    rec.get("pvalue") < 0.05 and
                    rounded_model > rounded_baseline and
                    dname in ["Dataset067_Pediatric_Internal", "Dataset500_TCIA"]):
                    dice_text += r"$^{\dagger}$"
                # Bold the best score for this ROI and dataset
                if trainer == best_scores.get(roi, {}).get(dname):
                    dice_text = r"\textbf{" + dice_text + "}"
                cell_values.append(dice_text)
            else:
                cell_values.append("-")
        else:
            # If the model does not segment this ROI, indicate with an asterisk
            if "ARTPLAN" in trainer and roi in {"Panc", "Gall"}:
                cell_values.append("*")
            else:
                cell_values.append("-")
    return "/".join(cell_values)

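# Illustrative cell strings produced by format_cell() (numbers are made up):
#   r"92/\textbf{90$^{\dagger}$}/88"  -> one mean per dataset; bold = best trainer
#                                        for this ROI/dataset, dagger = significant
#                                        improvement over the best baseline
#   "-" in a slot -> physician reference unavailable for that dataset
#   "*" in a slot -> the model does not segment this ROI (ARTPLAN, Panc/Gall)
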
def build_latex_table(results: dict, dataset_names: list, all_rois: list,
                      trainers: list, best_scores: dict, latex_path: str):
    """
    Build the LaTeX table lines and write them to the specified file.
    The table groups trainers into categories and highlights best scores.
    """
    # Reorder datasets as 297, 500, 67 using the mapped folder names
    desired_order = ["Dataset297_TotalSegmentator", "Dataset500_TCIA", "Dataset067_Pediatric_Internal"]
    dataset_names = [d for d in desired_order if d in dataset_names]

    # Remove the Liver ROI (abbreviated as "Live")
    all_rois = [roi for roi in all_rois if roi.lower() != "live"]

    # Set the caption as requested.
    latex_lines = [
        r"\begin{sidewaystable}[htbp]",
        r"\centering",
        r"\caption{Dice coefficient (\%) comparison across datasets (adult / pediatric / pediatric internal).}"
    ]

    # One label column plus one column per ROI.
    num_columns = 1 + len(all_rois)
    col_spec = "l" + "c" * len(all_rois)
    latex_lines.append(r"\begin{tabular}{" + col_spec + "}")
    latex_lines.append(r"\toprule")

    # Insert an extra header row for ROI columns using the same formatting as for training types
    header_cells = [r"\rowcolor{gray!30} Trainer"] + [escape_latex(roi) for roi in all_rois]
    latex_lines.append(" & ".join(header_cells) + r" \\")
    latex_lines.append(r"\midrule")

    # Categorize trainers based on naming heuristics
    direct_learning = []
    hybrid_learning = []
    transfer_learning = []
    baseline = []
    for trainer in trainers:
        # The categorization logic below is heuristic, based on name patterns.
        if trainer.endswith("T_o$") and len(trainer) > 6 and trainer[3] == trainer[6]:
            direct_learning.append(trainer)
        elif trainer.endswith("T_o$") and len(trainer) > 6 and trainer[3] != trainer[6]:
            hybrid_learning.append(trainer)
        elif is_baseline(trainer):
            baseline.append(trainer)
        else:
            transfer_learning.append(trainer)

    sorted_categories = [
        ("Baseline", baseline),
        ("Direct Learning", direct_learning),
        ("Hybrid Learning", hybrid_learning),
        ("Transfer Learning", transfer_learning)
    ]

    # For alternating row colors in trainer rows.
    row_counter = 0
    for category, trainer_list in sorted_categories:
        if trainer_list:
            # Insert a formatted category header row spanning all table columns.
            latex_lines.append(r"\midrule")
            latex_lines.append(r"\rowcolor{gray!30}")
            latex_lines.append(r"\multicolumn{" + str(num_columns) + r"}{c}{\textbf{" + category + r"}} \\")
            latex_lines.append(r"\midrule")
            for trainer in trainer_list:
                row_prefix = ""
                if row_counter % 2 == 0:
                    row_prefix = r"\rowcolor{gray!10} "
                row_cells = [row_prefix + trainer]  # Trainer names are not escaped
                for roi in all_rois:
                    cell_text = format_cell(trainer, roi, dataset_names, results, best_scores)
                    row_cells.append(cell_text)
                latex_lines.append(" & ".join(row_cells) + r" \\")
                row_counter += 1

    latex_lines.extend([
        r"\bottomrule",
        r"\end{tabular}",
        r"\begin{tablenotes}",
        r"\footnotesize",
        r"\footnotesize \textbf{Bold} marks the best DSC for each ROI/dataset. ROI names are abbreviated: Blad (Bladder), Duod (Duodenum), Esop (Esophagus), Gall (Gallbladder), Kidn (Kidney), Panc (Pancreas), Pros (Prostate), Smal (Small Intestine), Spin (Spinal Canal), Sple (Spleen), Stom (Stomach). $^{\dagger}$ indicates a statistically significant improvement over the best performing baseline ($p < 0.05$). “-” indicates that the physician reference is unavailable and “*” that the model does not segment this ROI.",
        r"\end{tablenotes}",
        r"\end{sidewaystable}"
    ])

    with open(latex_path, "w") as f:
        f.write("\n".join(latex_lines))
    logging.info(f"LaTeX dice table saved to {latex_path}")

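# The emitted LaTeX assumes the surrounding document loads (at minimum):
#   \usepackage{rotating}        % sidewaystable environment
#   \usepackage{booktabs}        % \toprule / \midrule / \bottomrule
#   \usepackage[table]{xcolor}   % \rowcolor
#   \usepackage{threeparttable}  % tablenotes environment
# This is an assumption about the manuscript preamble; the script itself does
# not enforce it.
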
def main():
    setup_logging()
    args = parse_arguments()

    # Determine which datasets to include
    selected_datasets = None if args.datasets.lower() == "all" else set(args.datasets.split(","))

    base_dir = "nnUNet_predict"
    output_folder = "analysis_output"
    latex_folder = os.path.join(output_folder, "latex_tables")
    os.makedirs(latex_folder, exist_ok=True)
    tex_filename = os.path.join(latex_folder, "dice_comparison.tex")

    # Mapping dataset numbers to folder names
    dataset_map = {
        "297": "Dataset297_TotalSegmentator",
        "500": "Dataset500_TCIA",
        "67": "Dataset067_Pediatric_Internal"
    }

    dataset_folders = get_dataset_folders(base_dir, selected_datasets, dataset_map)

    # Collect dataset names in processing order; duplicates are already skipped here.
    results = {}
    dataset_names = []
    for dataset_path in dataset_folders:
        dataset_name = process_dataset_folder(dataset_path, results)
        if dataset_name not in dataset_names:
            dataset_names.append(dataset_name)

    # Determine the union of all ROIs across trainers
    all_rois = sorted({roi for trainer_data in results.values() for roi in trainer_data.keys()})

    # Sorted list of trainer names
    trainers = sorted(results.keys())

    # Compute statistical tests comparing each trainer to the baseline
    compute_statistical_tests(results)

    # Determine the best trainer (highest mean Dice) for each ROI and dataset
    best_scores = determine_best_scores(results, dataset_names, trainers, all_rois)

    # Build and save the LaTeX table
    build_latex_table(results, dataset_names, all_rois, trainers, best_scores, tex_filename)

if __name__ == "__main__":
    main()
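
# Example invocation (the script filename is illustrative; run from the folder
# that contains nnUNet_predict/):
#
#   python make_dice_table.py --datasets all
#   python make_dice_table.py --datasets 500,67
#
# The table is written to analysis_output/latex_tables/dice_comparison.tex.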