Skip to content

Commit 6ccba55

Browse files
committed
feat: add colormap coloring to thumbnail grid overlay annotations
1 parent f8f9239 commit 6ccba55

File tree

1 file changed

+34
-2
lines changed

1 file changed

+34
-2
lines changed

cellseg_models_pytorch/wsi/image.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional, Union
1+
from typing import Optional, Sequence, Union
22

33
import numpy as np
44

@@ -9,6 +9,7 @@
99
except ImportError:
1010
_has_matplotlib = False
1111

12+
from matplotlib import colormaps
1213
from PIL import Image, ImageDraw, ImageFont
1314

1415
from .tiles import _divide_xywh
@@ -34,6 +35,10 @@ def get_annotated_image(
3435
text_proportion: float = 0.75,
3536
text_font: str = "monospace",
3637
alpha: float = 0.0,
38+
cmap: str = None,
39+
values: np.ndarray = None,
40+
breaks: Sequence[float] = None,
41+
n_bins: int = 30,
3742
) -> Image.Image:
3843
"""Function to draw tiles to an image. Useful for visualising tiles/predictions.
3944
@@ -65,6 +70,14 @@ def get_annotated_image(
6570
Passed to matplotlib's `fontManager.find_font` function.
6671
alpha (float, default=0.0):
6772
Alpha value for blending the original image and drawn image.
73+
cmap (str, default=None):
74+
Colormap to use for the tiles. E.g. "viridis", "plasma", "inferno".
75+
values (np.ndarray, default=None):
76+
Values to map to the colormap. Must be same length as `coordinates`.
77+
breaks (Sequence[float], default=None):
78+
Breakpoints for the colormap. If not provided, will be computed from `values`.
79+
n_bins (int, default=30):
80+
Number of bins to use for the colormap.
6881
6982
Raises:
7083
ValueError: Text item length does not match length of coordinates.
@@ -91,14 +104,33 @@ def get_annotated_image(
91104
font = None
92105
annotated = image.copy()
93106
draw = ImageDraw.Draw(annotated)
107+
108+
pal = None
109+
if cmap is not None and values is not None:
110+
if values.dtype.kind in ["U", "S", "O", "i"]:
111+
unique_vals = np.unique(values)
112+
values = np.searchsorted(unique_vals, values)
113+
breaks = np.arange(len(unique_vals))
114+
elif breaks is None:
115+
min_val = np.round((values.min()), 1)
116+
max_val = np.round((values.max()), 1)
117+
step = np.round(((max_val - min_val) / n_bins), 3)
118+
breaks = np.arange(min_val, max_val, step)
119+
120+
pal = colormaps.get_cmap(cmap).resampled(len(breaks))
121+
94122
for idx, (xywh, text) in enumerate(zip(coordinates, text_items)):
95123
# Downscale coordinates.
96124
x, y, w, h = _divide_xywh(xywh, downsample)
97125
# Draw rectangle.
126+
rgb_uint8 = None
127+
if pal is not None:
128+
rgb_uint8 = tuple((np.array(pal(values[idx])) * 255).astype(np.uint8))
129+
98130
draw.rectangle(
99131
((x, y), (x + w, y + h)),
100132
fill=rectangle_fill,
101-
outline=rectangle_outline,
133+
outline=rectangle_outline if rgb_uint8 is None else rgb_uint8,
102134
width=rectangle_width,
103135
)
104136
if text is not None:

0 commit comments

Comments
 (0)