Skip to content

Commit 352fb5f

Browse files
committed
[~] small refacto
1 parent 27817c9 commit 352fb5f

File tree

3 files changed

+144
-5
lines changed

3 files changed

+144
-5
lines changed

README.md

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# PSAT
22

3-
This repository contains the code and configuration files for PSAT (Pediatric Segmentation Approaches via Adult Augmentations and Transfer).
3+
This repository contains the code and configuration files for PSAT (Pediatric Segmentation Approaches via Adult Augmentations and Transfer Learning).
44

55
## Overview
66

@@ -12,6 +12,42 @@ PSAT addresses pediatric segmentation challenges by combining:
1212

1313
<img src="resources/images/PSAT_overview.png" alt="PSAT Overview" style="width:80%; max-width:1000px; display:block; margin: 0 auto;">
1414

15+
## Citation
16+
If you use this code, please cite our paper:
17+
18+
```
19+
@article{kirscher2025psat,
20+
title={PSAT: Pediatric Segmentation Approaches via Adult Augmentations and Transfer Learning},
21+
author={T. Kirscher et al},
22+
journal={MICCAI},
23+
year={2025},
24+
note={arXiv:xxxx.xxxxx}
25+
}
26+
```
27+
28+
## Quickstart
29+
30+
Install dependencies:
31+
```bash
32+
pip install -r requirements.txt
33+
```
34+
35+
Run metrics evaluation (example):
36+
```bash
37+
python scripts/compute_metrics.py <ground_truth_dir> <predictions_dir>
38+
```
39+
Replace `<ground_truth_dir>` and `<predictions_dir>` with your folder paths containing NIfTI files.
40+
41+
## Dependencies
42+
- nibabel
43+
- numpy
44+
- pandas
45+
- p_tqdm
46+
- scipy
47+
- surface-distance
48+
49+
(See `requirements.txt` for full list.)
50+
1551
## Documentation
1652

1753
- [nnUNet](nnUNet/nnUNet.md)
@@ -28,4 +64,4 @@ Install dependencies listed in `requirements.txt` and run:
2864

2965
```bash
3066
pytest -q
31-
```
67+
```

scripts/compute_metrics.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,25 @@
77
and Hausdorff distances (https://github.com/deepmind/surface-distance).
88
- Dynamic retrieval of pixel spacing to support varying image resolutions.
99
- Reorientation checks to ensure proper alignment and spacing of NIfTI images.
10+
11+
Usage:
12+
python scripts/compute_metrics.py <ground_truth_dir> <predictions_dir>
13+
14+
Arguments:
15+
ground_truth_dir: Directory containing ground truth NIfTI files (*.nii.gz)
16+
predictions_dir: Directory containing predicted NIfTI files (*.nii.gz)
17+
18+
Expected file format:
19+
- Each subject should have a file named <subject>.nii.gz in both directories.
20+
- Files must be 3D or 4D NIfTI images with integer labels.
21+
22+
Dependencies:
23+
- nibabel
24+
- numpy
25+
- pandas
26+
- p_tqdm
27+
- scipy
28+
- surface-distance
1029
"""
1130

1231
import sys
@@ -34,6 +53,11 @@
3453
def dice_score(y_true, y_pred):
3554
"""
3655
Binary Dice score. Same results as sklearn f1 binary.
56+
Args:
57+
y_true (np.ndarray): Binary ground truth mask.
58+
y_pred (np.ndarray): Binary predicted mask.
59+
Returns:
60+
float: Dice coefficient.
3761
"""
3862
intersect = np.sum(y_true * y_pred)
3963
denominator = np.sum(y_true) + np.sum(y_pred)
@@ -42,11 +66,12 @@ def dice_score(y_true, y_pred):
4266

4367

4468
def reorient_to_ras(img):
45-
"""Reorient the image to RAS (Right-Anterior-Superior) orientation."""
69+
"""Reorient the image to RAS (Right-Anterior-Superior) orientation using nibabel."""
4670
return nib.as_closest_canonical(img)
4771

4872

4973
def calc_metrics(subject, gt_dir=None, pred_dir=None, class_map=None):
74+
# Load ground truth and prediction images for a subject
5075
try:
5176
gt_img = nib.load(gt_dir / f"{subject}.nii.gz")
5277
pred_img = nib.load(pred_dir / f"{subject}.nii.gz")
@@ -70,6 +95,7 @@ def calc_metrics(subject, gt_dir=None, pred_dir=None, class_map=None):
7095
gt = gt_all == idx
7196
pred = pred_all == idx
7297

98+
# Handle cases where ground truth or prediction is missing for a class
7399
if gt.max() > 0 and pred.max() == 0:
74100
r[f"dice-{roi_name}"] = 0
75101
r[f"hausdorff-{roi_name}"] = 0
@@ -103,9 +129,12 @@ def calculate_confidence_interval(data, confidence=0.95):
103129
"""
104130
Calculate Dice score and Hausdorff distance for your nnU-Net predictions.
105131
106-
example usage:
107-
python evaluate.py ground_truth_dir predictions_dir
132+
Example usage:
133+
python scripts/compute_metrics.py <ground_truth_dir> <predictions_dir>
134+
135+
See the top-level docstring for more details.
108136
"""
137+
# Parse input arguments
109138
gt_dir = Path(sys.argv[1])
110139
pred_dir = Path(sys.argv[2])
111140

scripts/create_totalseg_subset.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import logging
2+
from pathlib import Path
3+
import shutil
4+
import pandas as pd
5+
import matplotlib.pyplot as plt
6+
from typing import List
7+
8+
def plot_and_save_distribution(df: pd.DataFrame, title: str, filename: str) -> None:
9+
"""Plot age and gender distribution and save to file.
10+
11+
Parameters
12+
----------
13+
df : pandas.DataFrame
14+
DataFrame containing at least ``age`` and ``gender`` columns.
15+
title : str
16+
Title for the plot.
17+
filename : str
18+
Output image file path.
19+
"""
20+
fig, ax = plt.subplots()
21+
df.boxplot(column="age", by="gender", ax=ax)
22+
ax.set_title(title)
23+
ax.set_xlabel("gender")
24+
ax.set_ylabel("age")
25+
plt.suptitle("")
26+
fig.tight_layout()
27+
fig.savefig(filename)
28+
plt.close(fig)
29+
30+
def create_directory_structure(base_path: Path, subdirs: List[str]) -> None:
31+
"""Create the required directory structure for the new dataset."""
32+
base_path.mkdir(parents=True, exist_ok=True)
33+
for subdir in subdirs:
34+
(base_path / subdir).mkdir(parents=True, exist_ok=True)
35+
36+
def copy_selected_files(
37+
original_base: Path, new_base: Path, subdirs: List[str], selected_image_ids: List[str]
38+
) -> None:
39+
"""Copy files matching selected image IDs from the original dataset to the new dataset."""
40+
for subdir in subdirs:
41+
original_dir = original_base / subdir
42+
new_dir = new_base / subdir
43+
44+
if not original_dir.exists():
45+
logging.warning(f"Directory {original_dir} does not exist. Skipping...")
46+
continue
47+
48+
for file_name in original_dir.iterdir():
49+
if any(file_name.name.startswith(image_id) for image_id in selected_image_ids):
50+
shutil.copy(file_name, new_dir / file_name.name)
51+
52+
def main() -> None:
53+
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
54+
55+
# Paths
56+
base_path = Path("nnUNet_raw_data_base/Dataset297_TotalSegmentator")
57+
new_base_path = Path("nnUNet_raw_data_base/Dataset797_TotalSegmentator_plus_TCIA")
58+
subdirs = ["imagesTr", "imagesTs", "labelsTr", "labelsTs"]
59+
60+
# Load the selected subset
61+
subset_csv = Path("resources/TotalSegmentator/dataset_subset.csv")
62+
selected_subset = pd.read_csv(subset_csv)
63+
selected_image_ids = selected_subset["image_id"].tolist()
64+
65+
# Create new dataset directory structure
66+
create_directory_structure(new_base_path, subdirs)
67+
68+
# Copy relevant files
69+
copy_selected_files(base_path, new_base_path, subdirs, selected_image_ids)
70+
71+
logging.info(f"Subset dataset created successfully at: {new_base_path}")
72+
73+
if __name__ == "__main__":
74+
main()

0 commit comments

Comments
 (0)