Skip to content

Commit e9a4d1a

Browse files
committed
Refactor
1 parent 55bbcd5 commit e9a4d1a

File tree

22 files changed

+85
-52
lines changed

22 files changed

+85
-52
lines changed

cinema/classification/eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@ def classification_eval_dataset( # pylint:disable=too-many-statements
8383
model = get_classification_or_regression_model(config)
8484
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
8585
model.load_state_dict(checkpoint["model"])
86-
model.to(device)
8786
model.eval()
87+
model.to(device)
8888

8989
# inference
9090
pred_labels = []

cinema/examples/inference/classification_cvd.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def run(trained_dataset: str, view: str, seed: int, device: torch.device, dtype:
2828
model_filename=f"finetuned/classification_cvd/{trained_dataset}_{view}/{trained_dataset}_{view}_{seed}.safetensors",
2929
config_filename=f"finetuned/classification_cvd/{trained_dataset}_{view}/config.yaml",
3030
)
31+
model.eval()
3132
model.to(device)
3233

3334
# load sample data from mnms2 of class HCM and form a batch of size 1

cinema/examples/inference/classification_sex.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def run(seed: int, device: torch.device, dtype: torch.dtype) -> None:
2929
model_filename=f"finetuned/classification_sex/{trained_dataset}_{view}/{trained_dataset}_{view}_{seed}.safetensors",
3030
config_filename=f"finetuned/classification_sex/{trained_dataset}_{view}/config.yaml",
3131
)
32+
model.eval()
3233
model.to(device)
3334

3435
# load sample data from mnms2 of class HCM and form a batch of size 1

cinema/examples/inference/classification_vendor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def run(view: str, seed: int, device: torch.device, dtype: torch.dtype) -> None:
2929
model_filename=f"finetuned/classification_vendor/{trained_dataset}_{view}/{trained_dataset}_{view}_{seed}.safetensors",
3030
config_filename=f"finetuned/classification_vendor/{trained_dataset}_{view}/config.yaml",
3131
)
32+
model.eval()
3233
model.to(device)
3334

3435
# load sample data from mnms2 of class HCM and form a batch of size 1

cinema/examples/inference/landmark_coordinate.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def run(view: str, seed: int, device: torch.device, dtype: torch.dtype) -> None:
2020
model_filename=f"finetuned/landmark_coordinate/{view}/{view}_{seed}.safetensors",
2121
config_filename=f"finetuned/landmark_coordinate/{view}/config.yaml",
2222
)
23+
model.eval()
2324
model.to(device)
2425

2526
# load sample data and form a batch of size 1

cinema/examples/inference/landmark_heatmap.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def run(view: str, seed: int, device: torch.device, dtype: torch.dtype) -> None:
2020
model_filename=f"finetuned/landmark_heatmap/{view}/{view}_{seed}.safetensors",
2121
config_filename=f"finetuned/landmark_heatmap/{view}/config.yaml",
2222
)
23+
model.eval()
2324
model.to(device)
2425

2526
# load sample data and form a batch of size 1

cinema/examples/inference/mae.py

Lines changed: 64 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,72 @@
1111
from cinema import CineMA, patchify, unpatchify
1212

1313

14+
def plot_mae_reconstruction(
15+
model: CineMA,
16+
batch: dict[str, torch.Tensor],
17+
pred_dict: dict[str, torch.Tensor],
18+
enc_mask_dict: dict[str, torch.Tensor],
19+
sax_slices: int,
20+
) -> plt.Figure:
21+
"""Plot MAE reconstruction."""
22+
n_rows = sax_slices + 3
23+
n_cols = 4
24+
fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols * 2, n_rows * 2), dpi=300)
25+
for i, view in enumerate(["lax_2c", "lax_3c", "lax_4c", "sax"]):
26+
patches = patchify(image=batch[view], patch_size=model.dec_patch_size_dict[view])
27+
patches[enc_mask_dict[view]] = pred_dict[view]
28+
masks = torch.zeros_like(patches)
29+
masks[enc_mask_dict[view]] = 1
30+
masks = unpatchify(
31+
masks, patch_size=model.dec_patch_size_dict[view], grid_size=model.enc_down_dict[view].patch_embed.grid_size
32+
)
33+
masks = masks[0, 0]
34+
reconstructed = unpatchify(
35+
patches,
36+
patch_size=model.dec_patch_size_dict[view],
37+
grid_size=model.enc_down_dict[view].patch_embed.grid_size,
38+
)
39+
reconstructed = reconstructed[0, 0].detach().cpu().numpy()
40+
image = batch[view][0, 0].detach().cpu().numpy()
41+
error = np.abs(reconstructed - image)
42+
43+
if view == "sax":
44+
reconstructed = reconstructed[..., :sax_slices]
45+
for j in range(sax_slices):
46+
axs[3 + j, 0].set_ylabel(f"SAX slice {j}")
47+
axs[3 + j, 0].imshow(image[..., j], cmap="gray")
48+
axs[3 + j, 1].imshow(masks[..., j], cmap="gray")
49+
axs[3 + j, 2].imshow(reconstructed[..., j], cmap="gray")
50+
axs[3 + j, 3].imshow(error[..., j], cmap="gray")
51+
else:
52+
axs[i, 0].imshow(image, cmap="gray")
53+
axs[i, 1].imshow(masks, cmap="gray")
54+
axs[i, 2].imshow(reconstructed, cmap="gray")
55+
axs[i, 3].imshow(error, cmap="gray")
56+
axs[i, 0].set_ylabel({"lax_2c": "LAX 2C", "lax_3c": "LAX 3C", "lax_4c": "LAX 4C"}[view])
57+
if i == 0:
58+
axs[i, 0].set_title("Original")
59+
axs[i, 1].set_title("Mask")
60+
axs[i, 2].set_title("Reconstructed")
61+
axs[i, 3].set_title("Error")
62+
# remove the x and y ticks
63+
for i in range(n_rows):
64+
for j in range(n_cols):
65+
axs[i, j].set_xticks([])
66+
axs[i, j].set_yticks([])
67+
fig.tight_layout()
68+
fig.subplots_adjust(wspace=0, hspace=0)
69+
return fig
70+
71+
1472
def run(device: torch.device, dtype: torch.dtype) -> None:
1573
"""Run MAE reconstruction."""
74+
t = 25 # which time frame to use
75+
1676
# load model
1777
model = CineMA.from_pretrained()
18-
model.to(device)
1978
model.eval()
79+
model.to(device)
2080

2181
# load sample data and form a batch of size 1
2282
transform = Compose(
@@ -46,7 +106,7 @@ def run(device: torch.device, dtype: torch.dtype) -> None:
46106
lax_4c_image = torch.from_numpy(
47107
np.transpose(sitk.GetArrayFromImage(sitk.ReadImage(exp_dir / "data/ukb/1/1_lax_4c.nii.gz")))
48108
)
49-
t = 25 # which time frame to use
109+
sax_slices = sax_image.shape[-2]
50110
batch = {
51111
"sax": sax_image[None, ..., t],
52112
"lax_2c": lax_2c_image[None, ..., 0, t],
@@ -62,45 +122,8 @@ def run(device: torch.device, dtype: torch.dtype) -> None:
62122
_, pred_dict, enc_mask_dict, _ = model(batch, enc_mask_ratio=0.75)
63123

64124
# visualize
65-
_, axs = plt.subplots(6, 4, figsize=(12, 18))
66-
for i, view in enumerate(["lax_2c", "lax_3c", "lax_4c", "sax"]):
67-
patches = patchify(image=batch[view], patch_size=model.dec_patch_size_dict[view])
68-
patches[enc_mask_dict[view]] = pred_dict[view]
69-
masks = torch.zeros_like(patches)
70-
masks[enc_mask_dict[view]] = 1
71-
masks = unpatchify(
72-
masks, patch_size=model.dec_patch_size_dict[view], grid_size=model.enc_down_dict[view].patch_embed.grid_size
73-
)
74-
masks = masks[0, 0]
75-
reconstructed = unpatchify(
76-
patches,
77-
patch_size=model.dec_patch_size_dict[view],
78-
grid_size=model.enc_down_dict[view].patch_embed.grid_size,
79-
)
80-
reconstructed = reconstructed[0, 0].detach().cpu().numpy()
81-
image = batch[view][0, 0].detach().cpu().numpy()
82-
error = np.abs(reconstructed - image)
83-
84-
if view == "sax":
85-
for j in range(3):
86-
z = j * 3
87-
axs[3 + j, 0].set_ylabel(f"{view} slice {z}")
88-
axs[3 + j, 0].imshow(image[..., z], cmap="gray")
89-
axs[3 + j, 1].imshow(masks[..., z], cmap="gray")
90-
axs[3 + j, 2].imshow(reconstructed[..., z], cmap="gray")
91-
axs[3 + j, 3].imshow(error[..., z], cmap="gray")
92-
else:
93-
axs[i, 0].imshow(image, cmap="gray")
94-
axs[i, 1].imshow(masks, cmap="gray")
95-
axs[i, 2].imshow(reconstructed, cmap="gray")
96-
axs[i, 3].imshow(error, cmap="gray")
97-
axs[i, 0].set_ylabel(view)
98-
if i == 0:
99-
axs[i, 0].set_title("Original")
100-
axs[i, 1].set_title("Mask")
101-
axs[i, 2].set_title("Reconstructed")
102-
axs[i, 3].set_title("Error")
103-
plt.savefig("mae_reconstruction.png", dpi=300, bbox_inches="tight")
125+
fig = plot_mae_reconstruction(model, batch, pred_dict, enc_mask_dict, sax_slices)
126+
fig.savefig("mae_reconstruction.png", dpi=300, bbox_inches="tight")
104127
plt.show(block=False)
105128

106129

cinema/examples/inference/mae_feature_extraction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ def run(device: torch.device, dtype: torch.dtype) -> None:
1414
"""Run MAE feature extraction."""
1515
# load model
1616
model = CineMA.from_pretrained()
17-
model.to(device)
1817
model.eval()
18+
model.to(device)
1919

2020
# load sample data and form a batch of size 1
2121
transform = Compose(

cinema/examples/inference/regression_age.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def run(seed: int, device: torch.device, dtype: torch.dtype) -> None:
3030
model_filename=f"finetuned/regression_age/{trained_dataset}_{view}/{trained_dataset}_{view}_{seed}.safetensors",
3131
config_filename=f"finetuned/regression_age/{trained_dataset}_{view}/config.yaml",
3232
)
33+
model.eval()
3334
model.to(device)
3435

3536
# load sample data

cinema/examples/inference/regression_bmi.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def run(seed: int, device: torch.device, dtype: torch.dtype) -> None:
3030
model_filename=f"finetuned/regression_bmi/{trained_dataset}_{view}/{trained_dataset}_{view}_{seed}.safetensors",
3131
config_filename=f"finetuned/regression_bmi/{trained_dataset}_{view}/config.yaml",
3232
)
33+
model.eval()
3334
model.to(device)
3435

3536
# load sample data

0 commit comments

Comments
 (0)