Commit 7418ccf

Refactor MAE example
1 parent a631b67 commit 7418ccf

File tree
  • cinema/examples/inference

1 file changed: +64 -45 lines changed
cinema/examples/inference/mae.py

Lines changed: 64 additions & 45 deletions
@@ -12,35 +12,22 @@


 def plot_mae_reconstruction(
-    batch: dict[str, torch.Tensor],
-    pred_dict: dict[str, torch.Tensor],
-    enc_mask_dict: dict[str, torch.Tensor],
-    patch_size_dict: dict[str, tuple[int, ...]],
-    grid_size_dict: dict[str, tuple[int, ...]],
-    sax_slices: int,
+    image_dict: dict[str, torch.Tensor],
+    reconstructed_dict: dict[str, torch.Tensor],
+    masks_dict: dict[str, torch.Tensor],
 ) -> plt.Figure:
     """Plot MAE reconstruction."""
+    sax_slices = image_dict["sax"].shape[-1]
     n_rows = sax_slices + 3
     n_cols = 4
     fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols * 2, n_rows * 2), dpi=300)
     for i, view in enumerate(["lax_2c", "lax_3c", "lax_4c", "sax"]):
-        patches = patchify(image=batch[view], patch_size=patch_size_dict[view])
-        patches[enc_mask_dict[view]] = pred_dict[view]
-        masks = torch.zeros_like(patches)
-        masks[enc_mask_dict[view]] = 1
-        masks = unpatchify(masks, patch_size=patch_size_dict[view], grid_size=grid_size_dict[view])
-        masks = masks[0, 0]
-        reconstructed = unpatchify(
-            patches,
-            patch_size=patch_size_dict[view],
-            grid_size=grid_size_dict[view],
-        )
-        reconstructed = reconstructed[0, 0].numpy()
-        image = batch[view][0, 0].numpy()
+        masks = masks_dict[view]
+        reconstructed = reconstructed_dict[view]
+        image = image_dict[view]
         error = np.abs(reconstructed - image)

         if view == "sax":
-            reconstructed = reconstructed[..., :sax_slices]
             for j in range(sax_slices):
                 axs[3 + j, 0].set_ylabel(f"SAX slice {j}")
                 axs[3 + j, 0].imshow(image[..., j], cmap="gray")
@@ -68,15 +55,42 @@ def plot_mae_reconstruction(
     return fig


+def reconstruct_images(
+    batch: dict[str, torch.Tensor],
+    pred_dict: dict[str, torch.Tensor],
+    enc_mask_dict: dict[str, torch.Tensor],
+    patch_size_dict: dict[str, tuple[int, ...]],
+    grid_size_dict: dict[str, tuple[int, ...]],
+    sax_slices: int,
+) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
+    """Reconstruct images from predicted patches."""
+    reconstructed_dict = {}
+    masks_dict = {}
+    for view in ["lax_2c", "lax_3c", "lax_4c", "sax"]:
+        patches = patchify(image=batch[view], patch_size=patch_size_dict[view])
+        patches[enc_mask_dict[view]] = pred_dict[view]
+        masks = torch.zeros_like(patches)
+        masks[enc_mask_dict[view]] = 1
+        masks = unpatchify(masks, patch_size=patch_size_dict[view], grid_size=grid_size_dict[view])
+        reconstructed = unpatchify(
+            patches,
+            patch_size=patch_size_dict[view],
+            grid_size=grid_size_dict[view],
+        )
+        reconstructed_dict[view] = reconstructed.detach().cpu().numpy()[0, 0]
+        masks_dict[view] = masks.detach().cpu().numpy()[0, 0]
+    reconstructed_dict["sax"] = reconstructed_dict["sax"][..., :sax_slices]
+    masks_dict["sax"] = masks_dict["sax"][..., :sax_slices]
+    return reconstructed_dict, masks_dict
+
+
 def run(device: torch.device, dtype: torch.dtype) -> None:
     """Run MAE reconstruction."""
     t = 25  # which time frame to use

     # load model
     model = CineMA.from_pretrained()
     model.eval()
-    patch_size_dict = model.dec_patch_size_dict
-    grid_size_dict = {k: v.patch_embed.grid_size for k, v in model.enc_down_dict.items()}
     model.to(device)

     # load sample data and form a batch of size 1
@@ -95,36 +109,41 @@ def run(device: torch.device, dtype: torch.dtype) -> None:
     )
     # (x, y, z, t) for SAX and (x, y, 1, t) for LAX
     exp_dir = Path(__file__).parent.parent.resolve()
-    sax_image = torch.from_numpy(
-        np.transpose(sitk.GetArrayFromImage(sitk.ReadImage(exp_dir / "data/ukb/1/1_sax.nii.gz")))
-    )
-    lax_2c_image = torch.from_numpy(
-        np.transpose(sitk.GetArrayFromImage(sitk.ReadImage(exp_dir / "data/ukb/1/1_lax_2c.nii.gz")))
-    )
-    lax_3c_image = torch.from_numpy(
-        np.transpose(sitk.GetArrayFromImage(sitk.ReadImage(exp_dir / "data/ukb/1/1_lax_3c.nii.gz")))
-    )
-    lax_4c_image = torch.from_numpy(
-        np.transpose(sitk.GetArrayFromImage(sitk.ReadImage(exp_dir / "data/ukb/1/1_lax_4c.nii.gz")))
-    )
-    sax_slices = sax_image.shape[-2]
-    batch = {
-        "sax": sax_image[None, ..., t],
-        "lax_2c": lax_2c_image[None, ..., 0, t],
-        "lax_3c": lax_3c_image[None, ..., 0, t],
-        "lax_4c": lax_4c_image[None, ..., 0, t],
+    sax_image = np.transpose(sitk.GetArrayFromImage(sitk.ReadImage(exp_dir / "data/ukb/1/1_sax.nii.gz")))
+    lax_2c_image = np.transpose(sitk.GetArrayFromImage(sitk.ReadImage(exp_dir / "data/ukb/1/1_lax_2c.nii.gz")))
+    lax_3c_image = np.transpose(sitk.GetArrayFromImage(sitk.ReadImage(exp_dir / "data/ukb/1/1_lax_3c.nii.gz")))
+    lax_4c_image = np.transpose(sitk.GetArrayFromImage(sitk.ReadImage(exp_dir / "data/ukb/1/1_lax_4c.nii.gz")))
+
+    image_dict = {
+        "sax": sax_image[..., t],
+        "lax_2c": lax_2c_image[..., 0, t],
+        "lax_3c": lax_3c_image[..., 0, t],
+        "lax_4c": lax_4c_image[..., 0, t],
     }
-    batch = transform(batch)
-    print(f"SAX view had originally {sax_image.shape[-2]} slices, now zero-padded to {batch['sax'].shape[-1]} slices.")  # noqa: T201
-    batch = {k: v[None, ...].to(device=device, dtype=dtype) for k, v in batch.items()}
+    batch = {k: torch.from_numpy(v[None, ...]) for k, v in image_dict.items()}

     # forward
+    sax_slices = batch["sax"].shape[-1]
+    batch = transform(batch)
+    batch = {k: v[None, ...].to(device=device, dtype=dtype) for k, v in batch.items()}
     with torch.no_grad(), torch.autocast("cuda", dtype=dtype, enabled=torch.cuda.is_available()):
         _, pred_dict, enc_mask_dict, _ = model(batch, enc_mask_ratio=0.75)
+    grid_size_dict = {k: v.patch_embed.grid_size for k, v in model.enc_down_dict.items()}
+    reconstructed_dict, masks_dict = reconstruct_images(
+        batch,
+        pred_dict,
+        enc_mask_dict,
+        model.dec_patch_size_dict,
+        grid_size_dict,
+        sax_slices,
+    )

     # visualize
-    batch = {k: v.detach().cpu() for k, v in batch.items()}
-    fig = plot_mae_reconstruction(batch, pred_dict, enc_mask_dict, patch_size_dict, grid_size_dict, sax_slices)
+    fig = plot_mae_reconstruction(
+        image_dict,
+        reconstructed_dict,
+        masks_dict,
+    )
     fig.savefig("mae_reconstruction.png", dpi=300, bbox_inches="tight")
     plt.show(block=False)

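For context on what the new reconstruct_images helper does: the decoder predictions in pred_dict are pasted into the patch slots selected by enc_mask_dict, a binary mask is built the same way, and both are unpatchified back to image space (with the zero-padded SAX slices trimmed back to sax_slices). The sketch below illustrates that idea on a toy 8x8 frame; toy_patchify and toy_unpatchify are hypothetical stand-ins written for this illustration, not the cinema library's patchify/unpatchify, and the shapes are simplified to a single-channel 2D image.

import torch


def toy_patchify(image: torch.Tensor, patch: int) -> torch.Tensor:
    """Split a (1, 1, H, W) image into (1, n_patches, patch*patch) flattened patches."""
    _, _, h, w = image.shape
    x = image.reshape(1, 1, h // patch, patch, w // patch, patch)
    return x.permute(0, 2, 4, 3, 5, 1).reshape(1, (h // patch) * (w // patch), patch * patch)


def toy_unpatchify(patches: torch.Tensor, patch: int, grid: tuple[int, int]) -> torch.Tensor:
    """Inverse of toy_patchify: (1, n_patches, patch*patch) -> (1, 1, H, W)."""
    gh, gw = grid
    x = patches.reshape(1, gh, gw, patch, patch, 1)
    return x.permute(0, 5, 1, 3, 2, 4).reshape(1, 1, gh * patch, gw * patch)


image = torch.arange(64.0).reshape(1, 1, 8, 8)         # toy single-channel frame
patches = toy_patchify(image, patch=4)                 # (1, 4, 16)

enc_mask = torch.tensor([[True, False, True, False]])  # patches hidden from the encoder
pred = torch.zeros(int(enc_mask.sum()), 16)            # stand-in for decoder predictions

patches[enc_mask] = pred                               # paste predictions into the masked slots
masks = torch.zeros_like(patches)
masks[enc_mask] = 1                                    # 1 wherever the model had to reconstruct

reconstructed = toy_unpatchify(patches, patch=4, grid=(2, 2))[0, 0]  # back to 8x8
mask_image = toy_unpatchify(masks, patch=4, grid=(2, 2))[0, 0]       # mask in image space

In mae.py this per-view result is what reconstruct_images returns as reconstructed_dict and masks_dict, already detached and converted to NumPy, so plot_mae_reconstruction only has to draw.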