Skip to content

Commit caf0602

Browse files
authored
Use colorcet for categorical colormaps (#1365)
1 parent b3bcdfe commit caf0602

File tree

3 files changed

+10
-6
lines changed

3 files changed

+10
-6
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ dependencies = [
4949

5050
[project.optional-dependencies]
5151
csv = ["pandas>=1"]
52-
plot = ["distinctipy>=1.3.4", "matplotlib>=3.4"]
52+
plot = ["colorcet", "matplotlib>=3.4"]
5353
video = ["ffmpeg-python>=0.2.0"]
5454
sklearn = [
5555
"scikit-learn>=1.6.1",

src/torchio/external/imports.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ def get_pandas() -> ModuleType:
2626
return _check_and_import(module='pandas', extra='csv')
2727

2828

29-
def get_distinctipy() -> ModuleType:
30-
return _check_and_import(module='distinctipy', extra='plot')
29+
def get_colorcet() -> ModuleType:
30+
return _check_and_import(module='colorcet', extra='plot')
3131

3232

3333
def get_ffmpeg() -> ModuleType:

src/torchio/visualization.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import warnings
4+
from itertools import cycle
45
from pathlib import Path
56
from typing import TYPE_CHECKING
67
from typing import Any
@@ -47,6 +48,7 @@ def rotate(image: np.ndarray, *, radiological: bool = True, n: int = -1) -> np.n
4748

4849
def _create_categorical_colormap(
4950
data: torch.Tensor,
51+
cmap_name: str = 'glasbey_category10',
5052
) -> tuple[ListedColormap, BoundaryNorm]:
5153
num_classes = int(data.max())
5254
mpl, _ = import_mpl_plt()
@@ -56,10 +58,12 @@ def _create_categorical_colormap(
5658
(1, 1, 1), # white for class 1
5759
]
5860
if num_classes > 1:
59-
from .external.imports import get_distinctipy
61+
from .external.imports import get_colorcet
6062

61-
distinctipy = get_distinctipy()
62-
distinct_colors = distinctipy.get_colors(num_classes - 1, rng=0)
63+
colorcet = get_colorcet()
64+
cmap = getattr(colorcet.cm, cmap_name)
65+
color_cycle = cycle(cmap.colors)
66+
distinct_colors = [next(color_cycle) for _ in range(num_classes - 1)]
6367
colors.extend(distinct_colors)
6468
boundaries = np.arange(-0.5, num_classes + 1.5, 1)
6569
colormap = mpl.colors.ListedColormap(colors)

0 commit comments

Comments
 (0)