2222from .types import TypePath
2323
2424if TYPE_CHECKING :
25+ from matplotlib .colors import BoundaryNorm
2526 from matplotlib .colors import ListedColormap
2627
2728
@@ -42,19 +43,26 @@ def rotate(image: np.ndarray, *, radiological: bool = True, n: int = -1) -> np.n
4243 return image
4344
4445
45- def _create_categorical_colormap (data : torch .Tensor ) -> ListedColormap :
46+ def _create_categorical_colormap (
47+ data : torch .Tensor ,
48+ ) -> tuple [ListedColormap , BoundaryNorm ]:
4649 num_classes = int (data .max ())
4750 mpl , _ = import_mpl_plt ()
4851
49- if num_classes == 1 : # just do white
50- distinct_colors = [(1 , 1 , 1 )]
51- else :
52+ colors = [
53+ (0 , 0 , 0 ), # black for background
54+ (1 , 1 , 1 ), # white for class 1
55+ ]
56+ if num_classes > 1 :
5257 from .external .imports import get_distinctipy
5358
5459 distinctipy = get_distinctipy ()
55- distinct_colors = distinctipy .get_colors (num_classes , rng = 0 )
56- colors = [(0 , 0 , 0 ), * distinct_colors ] # prepend black
57- return mpl .colors .ListedColormap (colors )
60+ distinct_colors = distinctipy .get_colors (num_classes - 1 , rng = 0 )
61+ colors .extend (distinct_colors )
62+ boundaries = np .arange (- 0.5 , num_classes + 1.5 , 1 )
63+ colormap = mpl .colors .ListedColormap (colors )
64+ boundary_norm = mpl .colors .BoundaryNorm (boundaries , ncolors = colormap .N )
65+ return colormap , boundary_norm
5866
5967
6068def plot_volume (
@@ -105,9 +113,14 @@ def plot_volume(
105113 slices = slice_x , slice_y , slice_z
106114 slice_x , slice_y , slice_z = color_labels (slices , cmap )
107115 else :
116+ boundary_norm = None
108117 if cmap is None :
109- cmap = _create_categorical_colormap (data ) if is_label else 'gray'
118+ if is_label :
119+ cmap , boundary_norm = _create_categorical_colormap (data )
120+ else :
121+ cmap = 'gray'
110122 imshow_kwargs ['cmap' ] = cmap
123+ imshow_kwargs ['norm' ] = boundary_norm
111124
112125 if is_label :
113126 imshow_kwargs ['interpolation' ] = 'none'
@@ -155,14 +168,14 @@ def plot_volume(
155168 },
156169 }
157170
158- for title , info in slices_dict .items ():
171+ for axis_title , info in slices_dict .items ():
159172 axis = info ['axis' ]
160173 axis .imshow (info ['slice' ], aspect = info ['aspect' ], ** imshow_kwargs )
161174 if xlabels :
162175 axis .set_xlabel (info ['xlabel' ])
163176 axis .set_ylabel (info ['ylabel' ])
164177 axis .invert_xaxis ()
165- axis .set_title (title )
178+ axis .set_title (axis_title )
166179
167180 plt .tight_layout ()
168181 if title is not None :
0 commit comments