
Commit 57904c8
Update scripts
1 parent 74aa033, commit 57904c8

13 files changed, +55 -46 lines

.gitignore (1 addition, 1 deletion)

```diff
@@ -149,5 +149,5 @@ wandb/
 **/*.safetensors
 
 # docs
-examples/*.png
 out/
+**/*.png
```

README.md (14 additions, 7 deletions)

```diff
@@ -4,7 +4,7 @@
   <img alt="CineMA logo" src="logo_light.svg" height="256">
 </picture>
 
-# CineMA: A Foundation Model for Cine Cardiac MRI 🎥🫀
+# CineMA: A Vision Foundation Model for Cine Cardiac MRI 🎥🫀
 
 ![python](https://img.shields.io/badge/Python-3.11-3776AB.svg?style=flat&logo=python&logoColor=white)
 ![pytorch](https://img.shields.io/badge/PyTorch-EE4C2C?style=flat&logo=pytorch&logoColor=white)
@@ -15,7 +15,7 @@
 
 ## Overview
 
-**CineMA** is a foundation model for **Cine** cardiac magnetic resonance (CMR) imaging based on
+**CineMA** is a vision foundation model for **Cine** cardiac magnetic resonance (CMR) imaging based on
 **M**asked-**A**utoencoder. CineMA has been pre-trained on UK Biobank data and fine-tuned on multiple clinically
 relevant tasks such as ventricle and myocaridum segmentation, ejection fraction (EF) regression, cardiovascular disease
 (CVD) detection and classification, and mid-valve plane and apical landmark localization. The model has been evaluated
@@ -73,7 +73,7 @@ python examples/inference/landmark_coordinate.py
 | Landmark localization by heatmap regression | LAX 2C or LAX 4C | 1 | [landmark_heatmap.py](examples/inference/landmark_heatmap.py) |
 | Landmark localization by coordinates regression | LAX 2C or LAX 4C | 1 | [landmark_coordinate.py](examples/inference/landmark_coordinate.py) |
 
-### Use pre-trained models for fine-tuning
+### Use pre-trained models
 
 The pre-trained CineMA model backbone is available at https://huggingface.co/mathpluscode/CineMA. Following scripts
 demonstrated how to fine-tune this backbone using
@@ -91,10 +91,17 @@ python examples/train/regression.py
 | Cardiovascular disease classification | [classification.py](examples/train/classification.py) |
 | Ejection fraction regression | [regression.py](examples/train/regression.py) |
 
-For other datasets, pre-process can be performed using the provided scripts following the documentations. Note that it
-is recommended to download the data under `~/.cache/cinema_datasets` as the integration tests uses this path. For
-instance, the mnms preprocessed data would be `~/.cache/cinema_datasets/mnms/processed`. Otherwise define the path using
-environment variable `CINEMA_DATA_DIR`.
+Another two scripts demonstrated the masking and prediction process of MAE and the feature extraction from MAE.
+
+```bash
+python examples/inference/mae.py
+python examples/inference/mae_feature_extraction.py
+```
+
+For fine-tuning CineMA on other datasets, pre-process can be performed using the provided scripts following the
+documentations. Note that it is recommended to download the data under `~/.cache/cinema_datasets` as the integration
+tests uses this path. For instance, the mnms preprocessed data would be `~/.cache/cinema_datasets/mnms/processed`.
+Otherwise define the path using environment variable `CINEMA_DATA_DIR`.
 
 | Training Data | Documentations |
 | ------------- | -------------------------------------------- |
```
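
As a side note on the dataset path convention described in the added README text, here is a minimal sketch of how a script might resolve that root. It assumes only what the README states; the helper name below is hypothetical and not part of this commit.

```python
# Hypothetical helper, for illustration only: resolve the dataset root the
# README describes, preferring the CINEMA_DATA_DIR environment variable and
# falling back to ~/.cache/cinema_datasets.
import os
from pathlib import Path


def resolve_data_dir() -> Path:
    default = Path.home() / ".cache" / "cinema_datasets"
    return Path(os.environ.get("CINEMA_DATA_DIR", default))


# e.g. the preprocessed mnms data would then be expected at
# resolve_data_dir() / "mnms" / "processed"
print(resolve_data_dir() / "mnms" / "processed")
```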

cinema/mae/mae.py (1 addition, 1 deletion)

```diff
@@ -629,7 +629,7 @@ def from_pretrained(cls, **kwargs) -> CineMA:  # type: ignore[no-untyped-def]
         # download config
         config_path = hf_hub_download(
             repo_id="mathpluscode/CineMA",
-            filename="pretrained/cinema.yaml",
+            filename="pretrained/config.yaml",
             **kwargs,
         )
         logger.info(f"Cached model config to {config_path}.")
```
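
For orientation, a minimal usage sketch of the classmethod touched above. The import path is an assumption not shown in this diff; the renamed `pretrained/config.yaml` is fetched internally, so callers do not pass filenames, and any extra keyword arguments are forwarded to `hf_hub_download` as the signature indicates.

```python
# Minimal sketch, assuming the package exposes CineMA at the top level
# (import path not confirmed by this diff). from_pretrained downloads the
# renamed pretrained/config.yaml plus the weights from the Hub, so no
# filenames are passed explicitly.
from cinema import CineMA  # assumed import path

model = CineMA.from_pretrained()
model.eval()  # assumes the returned object is a torch.nn.Module
```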

examples/inference/classification_cvd.py (3 additions, 3 deletions)

```diff
@@ -17,16 +17,16 @@ def run(trained_dataset: str, view: str, seed: int) -> None:
     # load config to get class names
     config_path = hf_hub_download(
         repo_id="mathpluscode/CineMA",
-        filename=f"finetuned/classification_cvd/{trained_dataset}_{view}.yaml",
+        filename=f"finetuned/classification_cvd/{trained_dataset}_{view}/config.yaml",
     )
     config = OmegaConf.load(config_path)
     classes = list(config.data[config.data.class_column])
 
     # load model
     model = ConvViT.from_finetuned(
         repo_id="mathpluscode/CineMA",
-        model_filename=f"finetuned/classification_cvd/{trained_dataset}_{view}_{seed}.safetensors",
-        config_filename=f"finetuned/classification_cvd/{trained_dataset}_{view}.yaml",
+        model_filename=f"finetuned/classification_cvd/{trained_dataset}_{view}/{trained_dataset}_{view}_{seed}.safetensors",
+        config_filename=f"finetuned/classification_cvd/{trained_dataset}_{view}/config.yaml",
     )
 
     # load sample data from mnms2 of class HCM and form a batch of size 1
```
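
The same per-task folder layout applies to the other classification scripts below. A sketch of the new paths with hypothetical argument values; the actual dataset/view/seed combinations are whichever checkpoints are published on the Hub.

```python
# Sketch of the new Hub layout used by the classification scripts after this
# commit. The dataset/view/seed values are placeholders for illustration, and
# the import path is assumed.
from cinema import ConvViT  # assumed import path

trained_dataset, view, seed = "mnms2", "sax", 0  # hypothetical example values
prefix = f"finetuned/classification_cvd/{trained_dataset}_{view}"
model = ConvViT.from_finetuned(
    repo_id="mathpluscode/CineMA",
    model_filename=f"{prefix}/{trained_dataset}_{view}_{seed}.safetensors",
    config_filename=f"{prefix}/config.yaml",
)
```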

examples/inference/classification_sex.py (3 additions, 3 deletions)

```diff
@@ -18,16 +18,16 @@ def run(seed: int) -> None:
     # load config to get class names
     config_path = hf_hub_download(
         repo_id="mathpluscode/CineMA",
-        filename=f"finetuned/classification_sex/{trained_dataset}_{view}.yaml",
+        filename=f"finetuned/classification_sex/{trained_dataset}_{view}/config.yaml",
     )
     config = OmegaConf.load(config_path)
     classes = list(config.data[config.data.class_column])
 
     # load model
     model = ConvViT.from_finetuned(
         repo_id="mathpluscode/CineMA",
-        model_filename=f"finetuned/classification_sex/{trained_dataset}_{view}_{seed}.safetensors",
-        config_filename=f"finetuned/classification_sex/{trained_dataset}_{view}.yaml",
+        model_filename=f"finetuned/classification_sex/{trained_dataset}_{view}/{trained_dataset}_{view}_{seed}.safetensors",
+        config_filename=f"finetuned/classification_sex/{trained_dataset}_{view}/config.yaml",
     )
 
     # load sample data from mnms2 of class HCM and form a batch of size 1
```

examples/inference/classification_vendor.py (3 additions, 3 deletions)

```diff
@@ -18,16 +18,16 @@ def run(view: str, seed: int) -> None:
     # load config to get class names
     config_path = hf_hub_download(
         repo_id="mathpluscode/CineMA",
-        filename=f"finetuned/classification_vendor/{trained_dataset}_{view}.yaml",
+        filename=f"finetuned/classification_vendor/{trained_dataset}_{view}/config.yaml",
     )
     config = OmegaConf.load(config_path)
     classes = list(config.data[config.data.class_column])
 
     # load model
     model = ConvViT.from_finetuned(
         repo_id="mathpluscode/CineMA",
-        model_filename=f"finetuned/classification_vendor/{trained_dataset}_{view}_{seed}.safetensors",
-        config_filename=f"finetuned/classification_vendor/{trained_dataset}_{view}.yaml",
+        model_filename=f"finetuned/classification_vendor/{trained_dataset}_{view}/{trained_dataset}_{view}_{seed}.safetensors",
+        config_filename=f"finetuned/classification_vendor/{trained_dataset}_{view}/config.yaml",
     )
 
     # load sample data from mnms2 of class HCM and form a batch of size 1
```

examples/inference/landmark_coordinate.py (2 additions, 2 deletions)

```diff
@@ -17,8 +17,8 @@ def run(view: str, seed: int) -> None:
     # load model
     model = ConvViT.from_finetuned(
         repo_id="mathpluscode/CineMA",
-        model_filename=f"finetuned/landmark_coordinate/{view}_{seed}.safetensors",
-        config_filename=f"finetuned/landmark_coordinate/{view}.yaml",
+        model_filename=f"finetuned/landmark_coordinate/{view}/{view}_{seed}.safetensors",
+        config_filename=f"finetuned/landmark_coordinate/{view}/config.yaml",
     )
 
     # load sample data and form a batch of size 1
```

examples/inference/landmark_heatmap.py (2 additions, 2 deletions)

```diff
@@ -17,8 +17,8 @@ def run(view: str, seed: int) -> None:
     # load model
     model = ConvUNetR.from_finetuned(
         repo_id="mathpluscode/CineMA",
-        model_filename=f"finetuned/landmark_heatmap/{view}_{seed}.safetensors",
-        config_filename=f"finetuned/landmark_heatmap/{view}.yaml",
+        model_filename=f"finetuned/landmark_heatmap/{view}/{view}_{seed}.safetensors",
+        config_filename=f"finetuned/landmark_heatmap/{view}/config.yaml",
    )
 
     # load sample data and form a batch of size 1
```
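
Both landmark scripts follow the same pattern, one regressing coordinates with ConvViT and one regressing heatmaps with ConvUNetR. A combined sketch with hypothetical view and seed values; import paths are assumed.

```python
# Sketch only: the two landmark variants share the same per-view folder layout
# on the Hub after this commit. view and seed are placeholders, and the import
# paths are assumed.
from cinema import ConvUNetR, ConvViT  # assumed import paths

view, seed = "lax_4c", 0  # hypothetical example values
coordinate_model = ConvViT.from_finetuned(
    repo_id="mathpluscode/CineMA",
    model_filename=f"finetuned/landmark_coordinate/{view}/{view}_{seed}.safetensors",
    config_filename=f"finetuned/landmark_coordinate/{view}/config.yaml",
)
heatmap_model = ConvUNetR.from_finetuned(
    repo_id="mathpluscode/CineMA",
    model_filename=f"finetuned/landmark_heatmap/{view}/{view}_{seed}.safetensors",
    config_filename=f"finetuned/landmark_heatmap/{view}/config.yaml",
)
```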

examples/inference/regression_age.py (3 additions, 3 deletions)

```diff
@@ -18,7 +18,7 @@ def run(seed: int) -> None:
     # load config to get class names
     config_path = hf_hub_download(
         repo_id="mathpluscode/CineMA",
-        filename=f"finetuned/regression_age/{trained_dataset}_{view}.yaml",
+        filename=f"finetuned/regression_age/{trained_dataset}_{view}/config.yaml",
     )
     config = OmegaConf.load(config_path)
     mean = config.data[config.data.regression_column].mean
@@ -27,8 +27,8 @@ def run(seed: int) -> None:
     # load model
     model = ConvViT.from_finetuned(
         repo_id="mathpluscode/CineMA",
-        model_filename=f"finetuned/regression_age/{trained_dataset}_{view}_{seed}.safetensors",
-        config_filename=f"finetuned/regression_age/{trained_dataset}_{view}.yaml",
+        model_filename=f"finetuned/regression_age/{trained_dataset}_{view}/{trained_dataset}_{view}_{seed}.safetensors",
+        config_filename=f"finetuned/regression_age/{trained_dataset}_{view}/config.yaml",
    )
 
     # load sample data
```

examples/inference/regression_bmi.py (3 additions, 3 deletions)

```diff
@@ -18,7 +18,7 @@ def run(seed: int) -> None:
     # load config to get class names
     config_path = hf_hub_download(
         repo_id="mathpluscode/CineMA",
-        filename=f"finetuned/regression_bmi/{trained_dataset}_{view}.yaml",
+        filename=f"finetuned/regression_bmi/{trained_dataset}_{view}/config.yaml",
     )
     config = OmegaConf.load(config_path)
     mean = config.data[config.data.regression_column].mean
@@ -27,8 +27,8 @@ def run(seed: int) -> None:
     # load model
     model = ConvViT.from_finetuned(
         repo_id="mathpluscode/CineMA",
-        model_filename=f"finetuned/regression_bmi/{trained_dataset}_{view}_{seed}.safetensors",
-        config_filename=f"finetuned/regression_bmi/{trained_dataset}_{view}.yaml",
+        model_filename=f"finetuned/regression_bmi/{trained_dataset}_{view}/{trained_dataset}_{view}_{seed}.safetensors",
+        config_filename=f"finetuned/regression_bmi/{trained_dataset}_{view}/config.yaml",
    )
 
     # load sample data
```
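
The two regression scripts additionally read normalization statistics from the relocated config. Below is a sketch of how a prediction might be mapped back to the original scale; the `std` field and the de-normalization step are assumptions inferred from the `mean` read shown above, and the dataset/view values are hypothetical placeholders.

```python
# Sketch: after this commit the regression configs live under
# finetuned/regression_*/{trained_dataset}_{view}/config.yaml. The std field
# and the de-normalization below are assumptions for illustration only.
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf

trained_dataset, view = "ukb", "sax"  # hypothetical placeholders
config_path = hf_hub_download(
    repo_id="mathpluscode/CineMA",
    filename=f"finetuned/regression_age/{trained_dataset}_{view}/config.yaml",
)
config = OmegaConf.load(config_path)
mean = config.data[config.data.regression_column].mean  # as in the script
std = config.data[config.data.regression_column].std  # assumed field name

y_hat = 0.0  # placeholder for a normalized model output
age_years = y_hat * std + mean  # assumed de-normalization
print(age_years)
```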
