1212from cinema import ConvUNetR , heatmap_soft_argmax
1313
1414
15+ def plot_heatmaps (images : np .ndarray , probs : np .ndarray , n_cols : int = 5 ) -> plt .Figure :
16+ """Plot heatmaps.
17+
18+ Args:
19+ images: (x, y, t)
20+ probs: (3, x, y, t)
21+ n_cols: number of columns
22+
23+ Returns:
24+ figure
25+ """
26+ n_frames = probs .shape [- 1 ]
27+ n_rows = n_frames // n_cols
28+ fig , axs = plt .subplots (n_rows , n_cols , figsize = (n_cols , n_rows ), dpi = 300 )
29+ for i in range (n_rows ):
30+ for j in range (n_cols ):
31+ t = i * n_cols + j
32+ axs [i , j ].imshow (images [..., 0 , t ], cmap = "gray" )
33+ axs [i , j ].imshow (probs [0 , ..., t , None ] * np .array ([1.0 , 0.0 , 0.0 , 1.0 ]))
34+ axs [i , j ].imshow (probs [1 , ..., t , None ] * np .array ([1.0 , 0.0 , 0.0 , 1.0 ]))
35+ axs [i , j ].imshow (probs [2 , ..., t , None ] * np .array ([1.0 , 0.0 , 0.0 , 1.0 ]))
36+ axs [i , j ].set_xticks ([])
37+ axs [i , j ].set_yticks ([])
38+ if j == 0 :
39+ axs [i , j ].set_ylabel (f"t = { t } " )
40+ fig .tight_layout ()
41+ fig .subplots_adjust (wspace = 0 , hspace = 0 )
42+ return fig
43+
44+
45+ def plot_landmarks (images : np .ndarray , coords : np .ndarray , n_cols : int = 5 ) -> plt .Figure :
46+ """Plot landmarks.
47+
48+ Args:
49+ images: (x, y, t)
50+ coords: (6, t)
51+ n_cols: number of columns
52+
53+ Returns:
54+ figure
55+ """
56+ n_frames = images .shape [- 1 ]
57+ n_rows = n_frames // n_cols
58+ fig , axs = plt .subplots (n_rows , n_cols , figsize = (n_cols , n_rows ), dpi = 300 )
59+ for i in range (n_rows ):
60+ for j in range (n_cols ):
61+ t = i * n_cols + j
62+
63+ # draw predictions with cross
64+ preds = images [..., t ] * np .array ([1 , 1 , 1 ])[None , None , :]
65+ preds = preds .clip (0 , 255 ).astype (np .uint8 )
66+ for k in range (3 ):
67+ pred_x , pred_y = coords [2 * k , t ], coords [2 * k + 1 , t ]
68+ x1 , x2 = max (0 , pred_x - 9 ), min (preds .shape [0 ], pred_x + 10 )
69+ y1 , y2 = max (0 , pred_y - 9 ), min (preds .shape [1 ], pred_y + 10 )
70+ preds [pred_x , y1 :y2 ] = [255 , 0 , 0 ]
71+ preds [x1 :x2 , pred_y ] = [255 , 0 , 0 ]
72+
73+ axs [i , j ].imshow (preds )
74+ axs [i , j ].set_xticks ([])
75+ axs [i , j ].set_yticks ([])
76+ if j == 0 :
77+ axs [i , j ].set_ylabel (f"t = { t } " )
78+ fig .tight_layout ()
79+ fig .subplots_adjust (wspace = 0 , hspace = 0 )
80+ return fig
81+
82+
83+ def plot_lv (coords : np .ndarray ) -> plt .Figure :
84+ """Plot GL shortening.
85+
86+ Args:
87+ coords: (6, t)
88+
89+ Returns:
90+ figure
91+ """
92+ # GL shortening
93+ x1 , y1 = coords [0 ], coords [1 ]
94+ x2 , y2 = coords [2 ], coords [3 ]
95+ x3 , y3 = coords [4 ], coords [5 ]
96+ lv_lengths = (((x1 + x2 ) / 2 - x3 ) ** 2 + ((y1 + y2 ) / 2 - y3 ) ** 2 ) ** 0.5
97+ gls = (max (lv_lengths ) - min (lv_lengths )) / max (lv_lengths ) * 100
98+
99+ # MAPSE
100+ ed_idx = np .argmin (lv_lengths )
101+ es_idx = np .argmax (lv_lengths )
102+ x1_ed , y1_ed = coords [0 , ed_idx ], coords [1 , ed_idx ]
103+ x2_ed , y2_ed = coords [2 , ed_idx ], coords [3 , ed_idx ]
104+ x1_es , y1_es = coords [0 , es_idx ], coords [1 , es_idx ]
105+ x2_es , y2_es = coords [2 , es_idx ], coords [3 , es_idx ]
106+ mapse = (
107+ ((x1_ed - x1_es ) ** 2 + (y1_ed - y1_es ) ** 2 ) ** 0.5 + ((x2_ed - x2_es ) ** 2 + (y2_ed - y2_es ) ** 2 ) ** 0.5
108+ ) / 2
109+
110+ fig = plt .figure (figsize = (4 , 4 ), dpi = 120 )
111+ plt .plot (lv_lengths , color = "#82B366" , label = "LV" )
112+ plt .xlabel ("Frame" )
113+ plt .ylabel ("Length (mm)" )
114+ plt .title (f"GLS = { gls :.2f} %, MAPSE = { mapse :.2f} mm" )
115+ plt .legend (loc = "lower right" )
116+ return fig
117+
118+
15119def run (view : str , seed : int , device : torch .device , dtype : torch .dtype ) -> None :
16120 """Run landmark localization on LAX images using fine-tuned checkpoint."""
17121 # load model
@@ -31,8 +135,7 @@ def run(view: str, seed: int, device: torch.device, dtype: torch.dtype) -> None:
31135 images = np .transpose (sitk .GetArrayFromImage (sitk .ReadImage (exp_dir / f"data/ukb/1/1_{ view } .nii.gz" )))
32136 n_frames = images .shape [- 1 ]
33137 probs_list = []
34- preds_list = []
35- lv_lengths = []
138+ coords_list = []
36139 for t in tqdm (range (n_frames ), total = n_frames ):
37140 batch = transform ({view : torch .from_numpy (images [None , ..., 0 , t ])})
38141 batch = {k : v [None , ...].to (device = device , dtype = dtype ) for k , v in batch .items ()}
@@ -42,68 +145,23 @@ def run(view: str, seed: int, device: torch.device, dtype: torch.dtype) -> None:
42145 probs_list .append (probs [0 ].detach ().to (torch .float32 ).cpu ().numpy ())
43146 coords = heatmap_soft_argmax (probs )[0 ].numpy ()
44147 coords = [int (x ) for x in coords ]
45-
46- # draw predictions with cross
47- preds = images [..., t ] * np .array ([1 , 1 , 1 ])[None , None , :]
48- preds = preds .clip (0 , 255 ).astype (np .uint8 )
49- for i in range (3 ):
50- pred_x , pred_y = coords [2 * i ], coords [2 * i + 1 ]
51- x1 , x2 = max (0 , pred_x - 9 ), min (preds .shape [0 ], pred_x + 10 )
52- y1 , y2 = max (0 , pred_y - 9 ), min (preds .shape [1 ], pred_y + 10 )
53- preds [pred_x , y1 :y2 ] = [255 , 0 , 0 ]
54- preds [x1 :x2 , pred_y ] = [255 , 0 , 0 ]
55- preds_list .append (preds )
56-
57- # record LV length
58- x1 , y1 , x2 , y2 , x3 , y3 = coords
59- lv_len = (((x1 + x2 ) / 2 - x3 ) ** 2 + ((y1 + y2 ) / 2 - y3 ) ** 2 ) ** 0.5
60- lv_lengths .append (lv_len )
148+ coords_list .append (coords )
61149 probs = np .stack (probs_list , axis = - 1 ) # (3, x, y, t)
62- preds = np .stack (preds_list , axis = - 1 ) # (3, x, y , t)
150+ coords = np .stack (coords_list , axis = - 1 ) # (6 , t)
63151
64152 # visualise heatmaps
65- _ , axs = plt .subplots (10 , 5 , figsize = (10 , 20 ))
66- for i in range (10 ):
67- for j in range (5 ):
68- t = i * 5 + j
69- axs [i , j ].imshow (images [..., 0 , t ], cmap = "gray" )
70- axs [i , j ].imshow ((probs [0 , ..., t , None ]) * np .array ([108 / 255 , 142 / 255 , 191 / 255 , 1.0 ]))
71- axs [i , j ].imshow ((probs [1 , ..., t , None ]) * np .array ([214 / 255 , 182 / 255 , 86 / 255 , 1.0 ]))
72- axs [i , j ].imshow ((probs [2 , ..., t , None ]) * np .array ([130 / 255 , 179 / 255 , 102 / 255 , 1.0 ]))
73- axs [i , j ].set_xticks ([])
74- axs [i , j ].set_yticks ([])
75- if j == 0 :
76- axs [i , j ].set_ylabel (f"t = { t } " )
77- plt .subplots_adjust (wspace = 0.02 , hspace = 0.02 )
78- plt .savefig (f"landmark_heatmap_{ view } _{ seed } .png" , dpi = 300 , bbox_inches = "tight" )
153+ fig = plot_heatmaps (images , probs )
154+ fig .savefig (f"landmark_heatmap_probs_{ view } _{ seed } .png" , dpi = 300 , bbox_inches = "tight" )
79155 plt .show (block = False )
80156
81157 # visualise landmarks
82- _ , axs = plt .subplots (10 , 5 , figsize = (10 , 20 ))
83- for i in range (10 ):
84- for j in range (5 ):
85- t = i * 5 + j
86- axs [i , j ].imshow (preds [..., t ])
87- axs [i , j ].set_xticks ([])
88- axs [i , j ].set_yticks ([])
89- if j == 0 :
90- axs [i , j ].set_ylabel (f"t = { t } " )
91- plt .subplots_adjust (wspace = 0.02 , hspace = 0.02 )
92- plt .savefig (f"landmark_heatmap_landmark_{ view } _{ seed } .png" , dpi = 300 , bbox_inches = "tight" )
158+ fig = plot_landmarks (images , coords )
159+ fig .savefig (f"landmark_heatmap_landmark_{ view } _{ seed } .png" , dpi = 300 , bbox_inches = "tight" )
93160 plt .show (block = False )
94161
95162 # visualise LV length changes
96- plt .figure (figsize = (4 , 3 ))
97- if view == "lax_2c" :
98- # first frame is empty for this particular example
99- lv_lengths = lv_lengths [1 :]
100- lvef = (max (lv_lengths ) - min (lv_lengths )) / max (lv_lengths ) * 100
101- plt .plot (lv_lengths , color = "#82B366" , label = "LV" )
102- plt .xlabel ("Frame" )
103- plt .ylabel ("Length (mm)" )
104- plt .title (f"LVEF = { lvef :.2f} %" )
105- plt .legend (loc = "lower right" )
106- plt .savefig (f"landmark_heatmap_lv_length_{ view } _{ seed } .png" , dpi = 300 , bbox_inches = "tight" )
163+ fig = plot_lv (coords )
164+ plt .savefig (f"landmark_heatmap_gls_{ view } _{ seed } .png" , dpi = 300 , bbox_inches = "tight" )
107165 plt .show (block = False )
108166
109167
0 commit comments