Skip to content

Commit eef5839

Browse files
committed
Update scripts
1 parent 74aa033 commit eef5839

13 files changed

+53
-44
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,5 +149,5 @@ wandb/
149149
**/*.safetensors
150150

151151
# docs
152-
examples/*.png
153152
out/
153+
**/*.png

README.md

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ python examples/inference/landmark_coordinate.py
7373
| Landmark localization by heatmap regression | LAX 2C or LAX 4C | 1 | [landmark_heatmap.py](examples/inference/landmark_heatmap.py) |
7474
| Landmark localization by coordinates regression | LAX 2C or LAX 4C | 1 | [landmark_coordinate.py](examples/inference/landmark_coordinate.py) |
7575

76-
### Use pre-trained models for fine-tuning
76+
### Use pre-trained models
7777

7878
The pre-trained CineMA model backbone is available at https://huggingface.co/mathpluscode/CineMA. Following scripts
7979
demonstrated how to fine-tune this backbone using
@@ -91,10 +91,17 @@ python examples/train/regression.py
9191
| Cardiovascular disease classification | [classification.py](examples/train/classification.py) |
9292
| Ejection fraction regression | [regression.py](examples/train/regression.py) |
9393

94-
For other datasets, pre-process can be performed using the provided scripts following the documentations. Note that it
95-
is recommended to download the data under `~/.cache/cinema_datasets` as the integration tests uses this path. For
96-
instance, the mnms preprocessed data would be `~/.cache/cinema_datasets/mnms/processed`. Otherwise define the path using
97-
environment variable `CINEMA_DATA_DIR`.
94+
Another two scripts demonstrated the masking and prediction process of MAE and the feature extraction from MAE.
95+
96+
```bash
97+
python examples/inference/mae.py
98+
python examples/inference/mae_feature_extraction.py
99+
```
100+
101+
For fine-tuning CineMA on other datasets, pre-process can be performed using the provided scripts following the
102+
documentations. Note that it is recommended to download the data under `~/.cache/cinema_datasets` as the integration
103+
tests uses this path. For instance, the mnms preprocessed data would be `~/.cache/cinema_datasets/mnms/processed`.
104+
Otherwise define the path using environment variable `CINEMA_DATA_DIR`.
98105

99106
| Training Data | Documentations |
100107
| ------------- | -------------------------------------------- |

cinema/mae/mae.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -629,7 +629,7 @@ def from_pretrained(cls, **kwargs) -> CineMA: # type: ignore[no-untyped-def]
629629
# download config
630630
config_path = hf_hub_download(
631631
repo_id="mathpluscode/CineMA",
632-
filename="pretrained/cinema.yaml",
632+
filename="pretrained/config.yaml",
633633
**kwargs,
634634
)
635635
logger.info(f"Cached model config to {config_path}.")

examples/inference/classification_cvd.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,16 @@ def run(trained_dataset: str, view: str, seed: int) -> None:
1717
# load config to get class names
1818
config_path = hf_hub_download(
1919
repo_id="mathpluscode/CineMA",
20-
filename=f"finetuned/classification_cvd/{trained_dataset}_{view}.yaml",
20+
filename=f"finetuned/classification_cvd/{trained_dataset}_{view}/config.yaml",
2121
)
2222
config = OmegaConf.load(config_path)
2323
classes = list(config.data[config.data.class_column])
2424

2525
# load model
2626
model = ConvViT.from_finetuned(
2727
repo_id="mathpluscode/CineMA",
28-
model_filename=f"finetuned/classification_cvd/{trained_dataset}_{view}_{seed}.safetensors",
29-
config_filename=f"finetuned/classification_cvd/{trained_dataset}_{view}.yaml",
28+
model_filename=f"finetuned/classification_cvd/{trained_dataset}_{view}/{trained_dataset}_{view}_{seed}.safetensors",
29+
config_filename=f"finetuned/classification_cvd/{trained_dataset}_{view}/config.yaml",
3030
)
3131

3232
# load sample data from mnms2 of class HCM and form a batch of size 1

examples/inference/classification_sex.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,16 @@ def run(seed: int) -> None:
1818
# load config to get class names
1919
config_path = hf_hub_download(
2020
repo_id="mathpluscode/CineMA",
21-
filename=f"finetuned/classification_sex/{trained_dataset}_{view}.yaml",
21+
filename=f"finetuned/classification_sex/{trained_dataset}_{view}/config.yaml",
2222
)
2323
config = OmegaConf.load(config_path)
2424
classes = list(config.data[config.data.class_column])
2525

2626
# load model
2727
model = ConvViT.from_finetuned(
2828
repo_id="mathpluscode/CineMA",
29-
model_filename=f"finetuned/classification_sex/{trained_dataset}_{view}_{seed}.safetensors",
30-
config_filename=f"finetuned/classification_sex/{trained_dataset}_{view}.yaml",
29+
model_filename=f"finetuned/classification_sex/{trained_dataset}_{view}/{trained_dataset}_{view}_{seed}.safetensors",
30+
config_filename=f"finetuned/classification_sex/{trained_dataset}_{view}/config.yaml",
3131
)
3232

3333
# load sample data from mnms2 of class HCM and form a batch of size 1

examples/inference/classification_vendor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,16 @@ def run(view: str, seed: int) -> None:
1818
# load config to get class names
1919
config_path = hf_hub_download(
2020
repo_id="mathpluscode/CineMA",
21-
filename=f"finetuned/classification_vendor/{trained_dataset}_{view}.yaml",
21+
filename=f"finetuned/classification_vendor/{trained_dataset}_{view}/config.yaml",
2222
)
2323
config = OmegaConf.load(config_path)
2424
classes = list(config.data[config.data.class_column])
2525

2626
# load model
2727
model = ConvViT.from_finetuned(
2828
repo_id="mathpluscode/CineMA",
29-
model_filename=f"finetuned/classification_vendor/{trained_dataset}_{view}_{seed}.safetensors",
30-
config_filename=f"finetuned/classification_vendor/{trained_dataset}_{view}.yaml",
29+
model_filename=f"finetuned/classification_vendor/{trained_dataset}_{view}/{trained_dataset}_{view}_{seed}.safetensors",
30+
config_filename=f"finetuned/classification_vendor/{trained_dataset}_{view}/config.yaml",
3131
)
3232

3333
# load sample data from mnms2 of class HCM and form a batch of size 1

examples/inference/landmark_coordinate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ def run(view: str, seed: int) -> None:
1717
# load model
1818
model = ConvViT.from_finetuned(
1919
repo_id="mathpluscode/CineMA",
20-
model_filename=f"finetuned/landmark_coordinate/{view}_{seed}.safetensors",
21-
config_filename=f"finetuned/landmark_coordinate/{view}.yaml",
20+
model_filename=f"finetuned/landmark_coordinate/{view}/{view}_{seed}.safetensors",
21+
config_filename=f"finetuned/landmark_coordinate/{view}/config.yaml",
2222
)
2323

2424
# load sample data and form a batch of size 1

examples/inference/landmark_heatmap.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ def run(view: str, seed: int) -> None:
1717
# load model
1818
model = ConvUNetR.from_finetuned(
1919
repo_id="mathpluscode/CineMA",
20-
model_filename=f"finetuned/landmark_heatmap/{view}_{seed}.safetensors",
21-
config_filename=f"finetuned/landmark_heatmap/{view}.yaml",
20+
model_filename=f"finetuned/landmark_heatmap/{view}/{view}_{seed}.safetensors",
21+
config_filename=f"finetuned/landmark_heatmap/{view}/config.yaml",
2222
)
2323

2424
# load sample data and form a batch of size 1

examples/inference/regression_age.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def run(seed: int) -> None:
1818
# load config to get class names
1919
config_path = hf_hub_download(
2020
repo_id="mathpluscode/CineMA",
21-
filename=f"finetuned/regression_age/{trained_dataset}_{view}.yaml",
21+
filename=f"finetuned/regression_age/{trained_dataset}_{view}/config.yaml",
2222
)
2323
config = OmegaConf.load(config_path)
2424
mean = config.data[config.data.regression_column].mean
@@ -27,8 +27,8 @@ def run(seed: int) -> None:
2727
# load model
2828
model = ConvViT.from_finetuned(
2929
repo_id="mathpluscode/CineMA",
30-
model_filename=f"finetuned/regression_age/{trained_dataset}_{view}_{seed}.safetensors",
31-
config_filename=f"finetuned/regression_age/{trained_dataset}_{view}.yaml",
30+
model_filename=f"finetuned/regression_age/{trained_dataset}_{view}/{trained_dataset}_{view}_{seed}.safetensors",
31+
config_filename=f"finetuned/regression_age/{trained_dataset}_{view}/config.yaml",
3232
)
3333

3434
# load sample data

examples/inference/regression_bmi.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def run(seed: int) -> None:
1818
# load config to get class names
1919
config_path = hf_hub_download(
2020
repo_id="mathpluscode/CineMA",
21-
filename=f"finetuned/regression_bmi/{trained_dataset}_{view}.yaml",
21+
filename=f"finetuned/regression_bmi/{trained_dataset}_{view}/config.yaml",
2222
)
2323
config = OmegaConf.load(config_path)
2424
mean = config.data[config.data.regression_column].mean
@@ -27,8 +27,8 @@ def run(seed: int) -> None:
2727
# load model
2828
model = ConvViT.from_finetuned(
2929
repo_id="mathpluscode/CineMA",
30-
model_filename=f"finetuned/regression_bmi/{trained_dataset}_{view}_{seed}.safetensors",
31-
config_filename=f"finetuned/regression_bmi/{trained_dataset}_{view}.yaml",
30+
model_filename=f"finetuned/regression_bmi/{trained_dataset}_{view}/{trained_dataset}_{view}_{seed}.safetensors",
31+
config_filename=f"finetuned/regression_bmi/{trained_dataset}_{view}/config.yaml",
3232
)
3333

3434
# load sample data

0 commit comments

Comments
 (0)