Skip to content

Commit cdbda05

Browse files
authored
Improve colors in label map plots (#1362)
1 parent e4f62d4 commit cdbda05

File tree

1 file changed

+23
-10
lines changed

1 file changed

+23
-10
lines changed

src/torchio/visualization.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from .types import TypePath
2323

2424
if TYPE_CHECKING:
25+
from matplotlib.colors import BoundaryNorm
2526
from matplotlib.colors import ListedColormap
2627

2728

@@ -42,19 +43,26 @@ def rotate(image: np.ndarray, *, radiological: bool = True, n: int = -1) -> np.n
4243
return image
4344

4445

45-
def _create_categorical_colormap(data: torch.Tensor) -> ListedColormap:
46+
def _create_categorical_colormap(
47+
data: torch.Tensor,
48+
) -> tuple[ListedColormap, BoundaryNorm]:
4649
num_classes = int(data.max())
4750
mpl, _ = import_mpl_plt()
4851

49-
if num_classes == 1: # just do white
50-
distinct_colors = [(1, 1, 1)]
51-
else:
52+
colors = [
53+
(0, 0, 0), # black for background
54+
(1, 1, 1), # white for class 1
55+
]
56+
if num_classes > 1:
5257
from .external.imports import get_distinctipy
5358

5459
distinctipy = get_distinctipy()
55-
distinct_colors = distinctipy.get_colors(num_classes, rng=0)
56-
colors = [(0, 0, 0), *distinct_colors] # prepend black
57-
return mpl.colors.ListedColormap(colors)
60+
distinct_colors = distinctipy.get_colors(num_classes - 1, rng=0)
61+
colors.extend(distinct_colors)
62+
boundaries = np.arange(-0.5, num_classes + 1.5, 1)
63+
colormap = mpl.colors.ListedColormap(colors)
64+
boundary_norm = mpl.colors.BoundaryNorm(boundaries, ncolors=colormap.N)
65+
return colormap, boundary_norm
5866

5967

6068
def plot_volume(
@@ -105,9 +113,14 @@ def plot_volume(
105113
slices = slice_x, slice_y, slice_z
106114
slice_x, slice_y, slice_z = color_labels(slices, cmap)
107115
else:
116+
boundary_norm = None
108117
if cmap is None:
109-
cmap = _create_categorical_colormap(data) if is_label else 'gray'
118+
if is_label:
119+
cmap, boundary_norm = _create_categorical_colormap(data)
120+
else:
121+
cmap = 'gray'
110122
imshow_kwargs['cmap'] = cmap
123+
imshow_kwargs['norm'] = boundary_norm
111124

112125
if is_label:
113126
imshow_kwargs['interpolation'] = 'none'
@@ -155,14 +168,14 @@ def plot_volume(
155168
},
156169
}
157170

158-
for title, info in slices_dict.items():
171+
for axis_title, info in slices_dict.items():
159172
axis = info['axis']
160173
axis.imshow(info['slice'], aspect=info['aspect'], **imshow_kwargs)
161174
if xlabels:
162175
axis.set_xlabel(info['xlabel'])
163176
axis.set_ylabel(info['ylabel'])
164177
axis.invert_xaxis()
165-
axis.set_title(title)
178+
axis.set_title(axis_title)
166179

167180
plt.tight_layout()
168181
if title is not None:

0 commit comments

Comments
 (0)