11from __future__ import annotations
22
33import warnings
4+ from itertools import cycle
45from pathlib import Path
56from typing import TYPE_CHECKING
67from typing import Any
@@ -47,6 +48,7 @@ def rotate(image: np.ndarray, *, radiological: bool = True, n: int = -1) -> np.n
4748
4849def _create_categorical_colormap (
4950 data : torch .Tensor ,
51+ cmap_name : str = 'glasbey_category10' ,
5052) -> tuple [ListedColormap , BoundaryNorm ]:
5153 num_classes = int (data .max ())
5254 mpl , _ = import_mpl_plt ()
@@ -56,10 +58,12 @@ def _create_categorical_colormap(
5658 (1 , 1 , 1 ), # white for class 1
5759 ]
5860 if num_classes > 1 :
59- from .external .imports import get_distinctipy
61+ from .external .imports import get_colorcet
6062
61- distinctipy = get_distinctipy ()
62- distinct_colors = distinctipy .get_colors (num_classes - 1 , rng = 0 )
63+ colorcet = get_colorcet ()
64+ cmap = getattr (colorcet .cm , cmap_name )
65+ color_cycle = cycle (cmap .colors )
66+ distinct_colors = [next (color_cycle ) for _ in range (num_classes - 1 )]
6367 colors .extend (distinct_colors )
6468 boundaries = np .arange (- 0.5 , num_classes + 1.5 , 1 )
6569 colormap = mpl .colors .ListedColormap (colors )
0 commit comments