Skip to content

Commit 460cf08

Browse files
committed
Improve MAE reconstruction example
1 parent 3f79189 commit 460cf08

File tree

1 file changed

+4
-4
lines changed
  • cinema/examples/inference

1 file changed

+4
-4
lines changed

cinema/examples/inference/mae.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,27 +22,27 @@ def plot_mae_reconstruction(
2222
n_cols = 4
2323
fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols * 2, n_rows * 2), dpi=300)
2424
for i, view in enumerate(["lax_2c", "lax_3c", "lax_4c", "sax"]):
25-
masks = masks_dict[view]
2625
reconstructed = reconstructed_dict[view]
2726
image = image_dict[view]
27+
masked = (1 - masks_dict[view]) * image
2828
error = np.abs(reconstructed - image)
2929

3030
if view == "sax":
3131
for j in range(sax_slices):
3232
axs[3 + j, 0].set_ylabel(f"SAX slice {j}")
3333
axs[3 + j, 0].imshow(image[..., j], cmap="gray")
34-
axs[3 + j, 1].imshow(masks[..., j], cmap="gray")
34+
axs[3 + j, 1].imshow(masked[..., j], cmap="gray")
3535
axs[3 + j, 2].imshow(reconstructed[..., j], cmap="gray")
3636
axs[3 + j, 3].imshow(error[..., j], cmap="gray")
3737
else:
3838
axs[i, 0].imshow(image, cmap="gray")
39-
axs[i, 1].imshow(masks, cmap="gray")
39+
axs[i, 1].imshow(masked, cmap="gray")
4040
axs[i, 2].imshow(reconstructed, cmap="gray")
4141
axs[i, 3].imshow(error, cmap="gray")
4242
axs[i, 0].set_ylabel({"lax_2c": "LAX 2C", "lax_3c": "LAX 3C", "lax_4c": "LAX 4C"}[view])
4343
if i == 0:
4444
axs[i, 0].set_title("Original")
45-
axs[i, 1].set_title("Mask")
45+
axs[i, 1].set_title("Masked")
4646
axs[i, 2].set_title("Reconstructed")
4747
axs[i, 3].set_title("Error")
4848
# remove the x and y ticks

0 commit comments

Comments
 (0)