@@ -22,27 +22,27 @@ def plot_mae_reconstruction(
2222 n_cols = 4
2323 fig , axs = plt .subplots (n_rows , n_cols , figsize = (n_cols * 2 , n_rows * 2 ), dpi = 300 )
2424 for i , view in enumerate (["lax_2c" , "lax_3c" , "lax_4c" , "sax" ]):
25- masks = masks_dict [view ]
2625 reconstructed = reconstructed_dict [view ]
2726 image = image_dict [view ]
27+ masked = (1 - masks_dict [view ]) * image
2828 error = np .abs (reconstructed - image )
2929
3030 if view == "sax" :
3131 for j in range (sax_slices ):
3232 axs [3 + j , 0 ].set_ylabel (f"SAX slice { j } " )
3333 axs [3 + j , 0 ].imshow (image [..., j ], cmap = "gray" )
34- axs [3 + j , 1 ].imshow (masks [..., j ], cmap = "gray" )
34+ axs [3 + j , 1 ].imshow (masked [..., j ], cmap = "gray" )
3535 axs [3 + j , 2 ].imshow (reconstructed [..., j ], cmap = "gray" )
3636 axs [3 + j , 3 ].imshow (error [..., j ], cmap = "gray" )
3737 else :
3838 axs [i , 0 ].imshow (image , cmap = "gray" )
39- axs [i , 1 ].imshow (masks , cmap = "gray" )
39+ axs [i , 1 ].imshow (masked , cmap = "gray" )
4040 axs [i , 2 ].imshow (reconstructed , cmap = "gray" )
4141 axs [i , 3 ].imshow (error , cmap = "gray" )
4242 axs [i , 0 ].set_ylabel ({"lax_2c" : "LAX 2C" , "lax_3c" : "LAX 3C" , "lax_4c" : "LAX 4C" }[view ])
4343 if i == 0 :
4444 axs [i , 0 ].set_title ("Original" )
45- axs [i , 1 ].set_title ("Mask " )
45+ axs [i , 1 ].set_title ("Masked " )
4646 axs [i , 2 ].set_title ("Reconstructed" )
4747 axs [i , 3 ].set_title ("Error" )
4848 # remove the x and y ticks
0 commit comments