1- from typing import Optional , Union
1+ from typing import Optional , Sequence , Union
22
33import numpy as np
44
99except ImportError :
1010 _has_matplotlib = False
1111
12+ from matplotlib import colormaps
1213from PIL import Image , ImageDraw , ImageFont
1314
1415from .tiles import _divide_xywh
@@ -34,6 +35,10 @@ def get_annotated_image(
3435 text_proportion : float = 0.75 ,
3536 text_font : str = "monospace" ,
3637 alpha : float = 0.0 ,
38+ cmap : str = None ,
39+ values : np .ndarray = None ,
40+ breaks : Sequence [float ] = None ,
41+ n_bins : int = 30 ,
3742) -> Image .Image :
3843 """Function to draw tiles to an image. Useful for visualising tiles/predictions.
3944
@@ -65,6 +70,14 @@ def get_annotated_image(
6570 Passed to matplotlib's `fontManager.find_font` function.
6671 alpha (float, default=0.0):
6772 Alpha value for blending the original image and drawn image.
73+ cmap (str, default=None):
74+ Colormap to use for the tiles. E.g. "viridis", "plasma", "inferno".
75+ values (np.ndarray, default=None):
76+ Values to map to the colormap. Must be same length as `coordinates`.
77+ breaks (Sequence[float], default=None):
78+ Breakpoints for the colormap. If not provided, will be computed from `values`.
79+ n_bins (int, default=30):
80+ Number of bins to use for the colormap.
6881
6982 Raises:
7083 ValueError: Text item length does not match length of coordinates.
@@ -91,14 +104,33 @@ def get_annotated_image(
91104 font = None
92105 annotated = image .copy ()
93106 draw = ImageDraw .Draw (annotated )
107+
108+ pal = None
109+ if cmap is not None and values is not None :
110+ if values .dtype .kind in ["U" , "S" , "O" , "i" ]:
111+ unique_vals = np .unique (values )
112+ values = np .searchsorted (unique_vals , values )
113+ breaks = np .arange (len (unique_vals ))
114+ elif breaks is None :
115+ min_val = np .round ((values .min ()), 1 )
116+ max_val = np .round ((values .max ()), 1 )
117+ step = np .round (((max_val - min_val ) / n_bins ), 3 )
118+ breaks = np .arange (min_val , max_val , step )
119+
120+ pal = colormaps .get_cmap (cmap ).resampled (len (breaks ))
121+
94122 for idx , (xywh , text ) in enumerate (zip (coordinates , text_items )):
95123 # Downscale coordinates.
96124 x , y , w , h = _divide_xywh (xywh , downsample )
97125 # Draw rectangle.
126+ rgb_uint8 = None
127+ if pal is not None :
128+ rgb_uint8 = tuple ((np .array (pal (values [idx ])) * 255 ).astype (np .uint8 ))
129+
98130 draw .rectangle (
99131 ((x , y ), (x + w , y + h )),
100132 fill = rectangle_fill ,
101- outline = rectangle_outline ,
133+ outline = rectangle_outline if rgb_uint8 is None else rgb_uint8 ,
102134 width = rectangle_width ,
103135 )
104136 if text is not None :
0 commit comments