22
33from pathlib import Path
44
5+ import imageio
56import matplotlib .pyplot as plt
67import numpy as np
78import SimpleITK as sitk # noqa: N813
1213from cinema import ConvUNetR , heatmap_soft_argmax
1314
1415
15- def plot_heatmaps (images : np .ndarray , probs : np .ndarray , n_cols : int = 5 ) -> plt . Figure :
16- """Plot heatmaps.
16+ def plot_heatmaps (images : np .ndarray , probs : np .ndarray , filepath : Path ) -> None :
17+ """Plot heatmaps as animated GIF .
1718
1819 Args:
19- images: (x, y, t)
20+ images: (x, y, 1, t)
2021 probs: (3, x, y, t)
21- n_cols: number of columns
22-
23- Returns:
24- figure
22+ filepath: path to save the GIF file.
2523 """
2624 n_frames = probs .shape [- 1 ]
27- n_rows = n_frames // n_cols
28- fig , axs = plt .subplots (n_rows , n_cols , figsize = (n_cols , n_rows ), dpi = 300 )
29- for i in range (n_rows ):
30- for j in range (n_cols ):
31- t = i * n_cols + j
32- axs [i , j ].imshow (images [..., 0 , t ], cmap = "gray" )
33- axs [i , j ].imshow (probs [0 , ..., t , None ] * np .array ([1.0 , 0.0 , 0.0 , 1.0 ]))
34- axs [i , j ].imshow (probs [1 , ..., t , None ] * np .array ([1.0 , 0.0 , 0.0 , 1.0 ]))
35- axs [i , j ].imshow (probs [2 , ..., t , None ] * np .array ([1.0 , 0.0 , 0.0 , 1.0 ]))
36- axs [i , j ].set_xticks ([])
37- axs [i , j ].set_yticks ([])
38- if j == 0 :
39- axs [i , j ].set_ylabel (f"t = { t } " )
40- fig .tight_layout ()
41- fig .subplots_adjust (wspace = 0 , hspace = 0 )
42- return fig
43-
44-
45- def plot_landmarks (images : np .ndarray , coords : np .ndarray , n_cols : int = 5 ) -> plt .Figure :
46- """Plot landmarks.
25+ temp_frame_paths = []
26+
27+ for t in tqdm (range (n_frames ), desc = "Creating heatmap GIF frames" ):
28+ # Create individual frame
29+ fig , ax = plt .subplots (figsize = (5 , 5 ), dpi = 300 )
30+
31+ # Plot image
32+ ax .imshow (images [..., 0 , t ], cmap = "gray" )
33+
34+ # Plot heatmap overlays
35+ ax .imshow (probs [0 , ..., t , None ] * np .array ([1.0 , 0.0 , 0.0 , 1.0 ]))
36+ ax .imshow (probs [1 , ..., t , None ] * np .array ([1.0 , 0.0 , 0.0 , 1.0 ]))
37+ ax .imshow (probs [2 , ..., t , None ] * np .array ([1.0 , 0.0 , 0.0 , 1.0 ]))
38+
39+ # Remove axes
40+ ax .set_xticks ([])
41+ ax .set_yticks ([])
42+
43+ # Save frame
44+ frame_path = f"_tmp_heatmap_frame_{ t :03d} .png"
45+ plt .savefig (frame_path , bbox_inches = "tight" , pad_inches = 0 , dpi = 300 )
46+ plt .close (fig )
47+ temp_frame_paths .append (frame_path )
48+
49+ # Create GIF
50+ with imageio .get_writer (filepath , mode = "I" , duration = 100 , loop = 0 ) as writer :
51+ for frame_path in tqdm (temp_frame_paths , desc = "Creating heatmap GIF" ):
52+ image = imageio .v2 .imread (frame_path )
53+ writer .append_data (image )
54+ # Clean up temporary file
55+ Path (frame_path ).unlink ()
56+
57+
58+ def plot_landmarks (images : np .ndarray , coords : np .ndarray , filepath : Path ) -> None :
59+ """Plot landmarks as animated GIF.
4760
4861 Args:
49- images: (x, y, t)
62+ images: (x, y, 1, t)
5063 coords: (6, t)
51- n_cols: number of columns
52-
53- Returns:
54- figure
64+ filepath: path to save the GIF file.
5565 """
5666 n_frames = images .shape [- 1 ]
57- n_rows = n_frames // n_cols
58- fig , axs = plt .subplots (n_rows , n_cols , figsize = (n_cols , n_rows ), dpi = 300 )
59- for i in range (n_rows ):
60- for j in range (n_cols ):
61- t = i * n_cols + j
62-
63- # draw predictions with cross
64- preds = images [..., t ] * np .array ([1 , 1 , 1 ])[None , None , :]
65- preds = preds .clip (0 , 255 ).astype (np .uint8 )
66- for k in range (3 ):
67- pred_x , pred_y = coords [2 * k , t ], coords [2 * k + 1 , t ]
68- x1 , x2 = max (0 , pred_x - 9 ), min (preds .shape [0 ], pred_x + 10 )
69- y1 , y2 = max (0 , pred_y - 9 ), min (preds .shape [1 ], pred_y + 10 )
70- preds [pred_x , y1 :y2 ] = [255 , 0 , 0 ]
71- preds [x1 :x2 , pred_y ] = [255 , 0 , 0 ]
72-
73- axs [i , j ].imshow (preds )
74- axs [i , j ].set_xticks ([])
75- axs [i , j ].set_yticks ([])
76- if j == 0 :
77- axs [i , j ].set_ylabel (f"t = { t } " )
78- fig .tight_layout ()
79- fig .subplots_adjust (wspace = 0 , hspace = 0 )
80- return fig
81-
82-
83- def plot_lv (coords : np .ndarray ) -> plt .Figure :
67+ temp_frame_paths = []
68+
69+ for t in tqdm (range (n_frames ), desc = "Creating landmark GIF frames" ):
70+ # Create individual frame
71+ fig , ax = plt .subplots (figsize = (5 , 5 ), dpi = 300 )
72+
73+ # draw predictions with cross
74+ preds = images [..., t ] * np .array ([1 , 1 , 1 ])[None , None , :]
75+ preds = preds .clip (0 , 255 ).astype (np .uint8 )
76+ for k in range (3 ):
77+ pred_x , pred_y = coords [2 * k , t ], coords [2 * k + 1 , t ]
78+ x1 , x2 = max (0 , pred_x - 9 ), min (preds .shape [0 ], pred_x + 10 )
79+ y1 , y2 = max (0 , pred_y - 9 ), min (preds .shape [1 ], pred_y + 10 )
80+ preds [pred_x , y1 :y2 ] = [255 , 0 , 0 ]
81+ preds [x1 :x2 , pred_y ] = [255 , 0 , 0 ]
82+
83+ ax .imshow (preds )
84+ ax .set_xticks ([])
85+ ax .set_yticks ([])
86+
87+ # Save frame
88+ frame_path = f"_tmp_landmark_frame_{ t :03d} .png"
89+ plt .savefig (frame_path , bbox_inches = "tight" , pad_inches = 0 , dpi = 300 )
90+ plt .close (fig )
91+ temp_frame_paths .append (frame_path )
92+
93+ # Create GIF
94+ with imageio .get_writer (filepath , mode = "I" , duration = 100 , loop = 0 ) as writer :
95+ for frame_path in tqdm (temp_frame_paths , desc = "Creating landmark GIF" ):
96+ image = imageio .v2 .imread (frame_path )
97+ writer .append_data (image )
98+ # Clean up temporary file
99+ Path (frame_path ).unlink ()
100+
101+
102+ def plot_lv (coords : np .ndarray , filepath : Path ) -> None :
84103 """Plot GL shortening.
85104
86105 Args:
87106 coords: (6, t)
88-
89- Returns:
90- figure
107+ filepath: path to save the PNG file.
91108 """
92109 # GL shortening
93110 x1 , y1 = coords [0 ], coords [1 ]
@@ -108,12 +125,13 @@ def plot_lv(coords: np.ndarray) -> plt.Figure:
108125 ) / 2
109126
110127 fig = plt .figure (figsize = (4 , 4 ), dpi = 120 )
111- plt .plot (lv_lengths , color = "#82B366" , label = "LV " )
128+ plt .plot (lv_lengths , color = "#82B366" , label = "Left Ventricle " )
112129 plt .xlabel ("Frame" )
113130 plt .ylabel ("Length (mm)" )
114- plt .title (f"GLS = { gls :.2f} %, MAPSE = { mapse :.2f} mm" )
131+ plt .title (f"GLS = { gls :.2f} %\n MAPSE = { mapse :.2f} mm" )
115132 plt .legend (loc = "lower right" )
116- return fig
133+ fig .savefig (filepath , dpi = 300 , bbox_inches = "tight" )
134+ plt .close (fig )
117135
118136
119137def run (view : str , seed : int , device : torch .device , dtype : torch .dtype ) -> None :
@@ -150,19 +168,13 @@ def run(view: str, seed: int, device: torch.device, dtype: torch.dtype) -> None:
150168 coords = np .stack (coords_list , axis = - 1 ) # (6, t)
151169
152170 # visualise heatmaps
153- fig = plot_heatmaps (images , probs )
154- fig .savefig (f"landmark_heatmap_probs_{ view } _{ seed } .png" , dpi = 300 , bbox_inches = "tight" )
155- plt .show (block = False )
171+ plot_heatmaps (images , probs , Path (f"landmark_heatmap_probs_{ view } _{ seed } .gif" ))
156172
157173 # visualise landmarks
158- fig = plot_landmarks (images , coords )
159- fig .savefig (f"landmark_heatmap_landmark_{ view } _{ seed } .png" , dpi = 300 , bbox_inches = "tight" )
160- plt .show (block = False )
174+ plot_landmarks (images , coords , Path (f"landmark_heatmap_landmark_{ view } _{ seed } .gif" ))
161175
162176 # visualise LV length changes
163- fig = plot_lv (coords )
164- plt .savefig (f"landmark_heatmap_gls_{ view } _{ seed } .png" , dpi = 300 , bbox_inches = "tight" )
165- plt .show (block = False )
177+ plot_lv (coords , Path (f"landmark_heatmap_gls_{ view } _{ seed } .png" ))
166178
167179
168180if __name__ == "__main__" :
0 commit comments