88from numpy .typing import NDArray
99from mpl_toolkits .axes_grid1 .axes_divider import Size , make_axes_locatable
1010from skimage .measure import find_contours
11- from matplotlib .cm import ScalarMappable
11+ from matplotlib .cm import ScalarMappable , _colormaps
1212
1313
1414def get_coolgraywarm (thresh : float = 3 , max : float = 7 ) -> matplotlib .colorbar .Colorbar :
@@ -154,6 +154,7 @@ def plot_frames_activ(
154154 z_thresh : float = 3 ,
155155 z_max : float = 11 ,
156156 bg_cmap : str = "gray" ,
157+ roi_colors : list [str ] | None = None ,
157158) -> tuple [plt .Axes , matplotlib .image .AxesImage ]:
158159 """Plot activation maps and background.
159160
@@ -192,11 +193,13 @@ def plot_frames_activ(
192193 origin = "lower" ,
193194 )
194195 if rois is not None :
195- for roi in rois :
196+ if roi_colors is None :
197+ roi_colors = ["r" , "g" , "b" , "y" ][: len (rois )]
198+ for roi , color in zip (rois , roi_colors ):
196199 roi_cut = roi [slices ][bbox ].squeeze ()
197200 contours = find_contours (roi_cut )
198201 for c in contours :
199- ax .plot (c [:, 1 ], c [:, 0 ], c = "cyan" , label = "ground-truth " , linewidth = 1 )
202+ ax .plot (c [:, 1 ], c [:, 0 ], c = color , label = "roi " , linewidth = 1 )
200203 ax .set_xticks ([])
201204 ax .set_yticks ([])
202205 return ax , im
@@ -217,6 +220,7 @@ def axis3dcut(
217220 z_thresh : float = 3 ,
218221 z_max : float = 11 ,
219222 tight_crop : bool = False ,
223+ roi_colors : list [str ] | None = None ,
220224) -> tuple [plt .Figure , plt .Axes , tuple [int , ...]]:
221225 """Display a 3D image with zscore and ground truth ROI.
222226
@@ -321,6 +325,7 @@ def axis3dcut(
321325 bg_cmap = bg_cmap ,
322326 z_thresh = z_thresh ,
323327 z_max = z_max ,
328+ roi_colors = roi_colors ,
324329 )
325330
326331 if cbar :
0 commit comments