Skip to content

Commit 0176048

Browse files
authored
Merge branch 'develop' into bug-fix-numpy-upgrade
2 parents 0b7e023 + d5c1995 commit 0176048

File tree

5 files changed

+159
-36
lines changed

5 files changed

+159
-36
lines changed

.github/workflows/python-package.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ jobs:
3030
sudo apt update
3131
sudo apt-get install -y libopenjp2-7 libopenjp2-tools
3232
python -m pip install --upgrade pip
33-
python -m pip install ruff==0.12.2 pytest pytest-cov pytest-runner
33+
python -m pip install ruff==0.12.7 pytest pytest-cov pytest-runner
3434
pip install -r requirements/requirements.txt
3535
- name: Cache tiatoolbox static assets
3636
uses: actions/cache@v3

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ repos:
6060
- id: rst-inline-touching-normal # Detect mistake of inline code touching normal text in rst.
6161
- repo: https://github.com/astral-sh/ruff-pre-commit
6262
# Ruff version.
63-
rev: v0.12.2
63+
rev: v0.12.7
6464
hooks:
6565
- id: ruff
6666
args: [--fix, --exit-non-zero-on-fix]

requirements/requirements_dev.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ pytest>=7.2.0
1010
pytest-cov>=4.0.0
1111
pytest-runner>=6.0
1212
pytest-xdist[psutil]
13-
ruff==0.12.2 # This will be updated by pre-commit bot to latest version
13+
ruff==0.12.7 # This will be updated by pre-commit bot to latest version
1414
toml>=0.10.2
1515
twine>=4.0.1
1616
wheel>=0.37.1

tiatoolbox/models/dataset/classification.py

Lines changed: 44 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ class WSIPatchDataset(dataset_abc.PatchDatasetABC):
163163
164164
"""
165165

166-
def __init__( # skipcq: PY-R1000 # noqa: PLR0915
166+
def __init__( # skipcq: PY-R1000
167167
self: WSIPatchDataset,
168168
img_path: str | Path,
169169
mode: str = "wsi",
@@ -262,40 +262,17 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915
262262
raise ValueError(msg)
263263

264264
self.preproc_func = preproc_func
265-
img_path = Path(img_path)
266-
if mode == "wsi":
267-
self.reader = WSIReader.open(img_path)
268-
else:
269-
logger.warning(
270-
"WSIPatchDataset only reads image tile at "
271-
'`units="baseline"` and `resolution=1.0`.',
272-
stacklevel=2,
273-
)
274-
img = imread(img_path)
275-
axes = "YXS"[: len(img.shape)]
276-
# initialise metadata for VirtualWSIReader.
277-
# here, we simulate a whole-slide image, but with a single level.
278-
# ! should we expose this so that use can provide their metadata ?
279-
metadata = WSIMeta(
280-
mpp=np.array([1.0, 1.0]),
281-
axes=axes,
282-
objective_power=10,
283-
slide_dimensions=np.array(img.shape[:2][::-1]),
284-
level_downsamples=[1.0],
285-
level_dimensions=[np.array(img.shape[:2][::-1])],
286-
)
287-
# infer value such that read if mask provided is through
288-
# 'mpp' or 'power' as varying 'baseline' is locked atm
265+
self.img_path = Path(img_path)
266+
self.mode = mode
267+
self.reader = None
268+
reader = self._get_reader(self.img_path)
269+
if mode != "wsi":
289270
units = "mpp"
290271
resolution = 1.0
291-
self.reader = VirtualWSIReader(
292-
img,
293-
info=metadata,
294-
)
295272

296273
# may decouple into misc ?
297274
# the scaling factor will scale base level to requested read resolution/units
298-
wsi_shape = self.reader.slide_dimensions(resolution=resolution, units=units)
275+
wsi_shape = reader.slide_dimensions(resolution=resolution, units=units)
299276

300277
# use all patches, as long as it overlaps source image
301278
self.inputs = PatchExtractor.get_coordinates(
@@ -316,13 +293,13 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915
316293
mask = np.array(mask > 0, dtype=np.uint8)
317294

318295
mask_reader = VirtualWSIReader(mask)
319-
mask_reader.info = self.reader.info
296+
mask_reader.info = reader.info
320297
elif auto_get_mask and mode == "wsi" and mask_path is None:
321298
# if no mask provided and `wsi` mode, generate basic tissue
322299
# mask on the fly
323-
mask_reader = self.reader.tissue_mask(resolution=1.25, units="power")
300+
mask_reader = reader.tissue_mask(resolution=1.25, units="power")
324301
# ? will this mess up ?
325-
mask_reader.info = self.reader.info
302+
mask_reader.info = reader.info
326303

327304
if mask_reader is not None:
328305
selected = PatchExtractor.filter_coordinates(
@@ -344,10 +321,44 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915
344321
# Perform check on the input
345322
self._check_input_integrity(mode="wsi")
346323

324+
def _get_reader(self: WSIPatchDataset, img_path: str | Path) -> WSIReader:
325+
"""Get a reader for the image."""
326+
if self.mode == "wsi":
327+
reader = WSIReader.open(img_path)
328+
else:
329+
logger.warning(
330+
"WSIPatchDataset only reads image tile at "
331+
'`units="baseline"` and `resolution=1.0`.',
332+
stacklevel=2,
333+
)
334+
img = imread(img_path)
335+
axes = "YXS"[: len(img.shape)]
336+
# initialise metadata for VirtualWSIReader.
337+
# here, we simulate a whole-slide image, but with a single level.
338+
# ! should we expose this so that use can provide their metadata ?
339+
metadata = WSIMeta(
340+
mpp=np.array([1.0, 1.0]),
341+
axes=axes,
342+
objective_power=10,
343+
slide_dimensions=np.array(img.shape[:2][::-1]),
344+
level_downsamples=[1.0],
345+
level_dimensions=[np.array(img.shape[:2][::-1])],
346+
)
347+
# infer value such that read if mask provided is through
348+
# 'mpp' or 'power' as varying 'baseline' is locked atm
349+
reader = VirtualWSIReader(
350+
img,
351+
info=metadata,
352+
)
353+
return reader
354+
347355
def __getitem__(self: WSIPatchDataset, idx: int) -> dict:
348356
"""Get an item from the dataset."""
349357
coords = self.inputs[idx]
350358
# Read image patch from the whole-slide image
359+
if self.reader is None:
360+
# only set the reader on first call so that it is initially picklable
361+
self.reader = self._get_reader(self.img_path)
351362
patch = self.reader.read_bounds(
352363
coords,
353364
resolution=self.resolution,

tiatoolbox/utils/visualization.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ def random_colors(num_colors: int, *, bright: bool) -> np.ndarray:
5656
np.ndarray:
5757
Array of (r, g, b) colors.
5858
59+
Examples:
60+
>>> from tiatoolbox.utils.visualization import random_colors
61+
>>> colors = random_colors(10, bright=True)
62+
5963
"""
6064
brightness = 1.0 if bright else 0.7
6165
hsv = [(i / num_colors, 1, brightness) for i in range(num_colors)]
@@ -75,6 +79,15 @@ def colourise_image(img: np.ndarray, cmap: str = "viridis") -> np.ndarray:
7579
7680
Returns:
7781
img(ndarray): An RGB image.
82+
83+
Examples:
84+
>>> from tiatoolbox.utils.visualization import colourise_image
85+
>>> import numpy as np
86+
>>> # Generate a random example; replace with your own data
87+
>>> img = np.random.rand(255, 255)
88+
>>> # Example usage of colourise_image
89+
>>> coloured_image = colourise_image(img, 'viridis')
90+
7891
"""
7992
if len(img.shape) == 2: # noqa: PLR2004
8093
# Single channel, make into rgb with colormap.
@@ -124,6 +137,30 @@ def overlay_prediction_mask(
124137
If return_ax is True, return the matplotlib ax object. Else,
125138
return the overlay array.
126139
140+
Examples:
141+
>>> from tiatoolbox.utils.visualization import overlay_prediction_mask
142+
>>> import numpy as np
143+
>>> from matplotlib import pyplot as plt
144+
>>> # Generate a random example; replace with your own data
145+
>>> img = np.random.randint(0, 256, size=(256, 256, 3), dtype=np.uint8)
146+
>>> prediction = np.random.randint(0, 3, size=(256, 256), dtype=np.uint8)
147+
>>> label_info = {
148+
... 0: ("Background", (0, 0, 0)),
149+
... 1: ("Tumor", (255, 0, 0)),
150+
... 2: ("Stroma", (0, 255, 0))
151+
... }
152+
>>> # Example usage of overlay_prediction_mask
153+
>>> ax = overlay_prediction_mask(
154+
... img=img,
155+
... prediction=prediction,
156+
... alpha=0.5,
157+
... label_info=label_info,
158+
... min_val=0.0,
159+
... ax=None,
160+
... return_ax=True
161+
... )
162+
>>> plt.show()
163+
127164
"""
128165
# Validate inputs
129166
if img.shape[:2] != prediction.shape[:2]:
@@ -310,6 +347,25 @@ def overlay_probability_map(
310347
If return_ax is True, return the matplotlib ax object. Else,
311348
return the overlay array.
312349
350+
Examples:
351+
>>> from tiatoolbox.utils.visualization import overlay_probability_map
352+
>>> import numpy as np
353+
>>> from matplotlib import pyplot as plt
354+
>>> # Generate a random example; replace with your own data
355+
>>> img = np.random.randint(0, 256, size=(256, 256, 3), dtype=np.uint8)
356+
>>> probability_map = np.random.rand(256, 256).astype(np.float32)
357+
>>> # Example usage of overlay_probability_map
358+
>>> ax = overlay_probability_map(
359+
... img=img,
360+
... prediction=probability_map,
361+
... alpha=0.35,
362+
... colour_map="jet",
363+
... min_val=0.0,
364+
... ax=None,
365+
... return_ax=True,
366+
... )
367+
>>> plt.show()
368+
313369
"""
314370
prediction = prediction.astype(np.float32)
315371
img = _validate_overlay_probability_map(img, prediction, min_val)
@@ -455,6 +511,40 @@ def overlay_prediction_contours(
455511
:class:`numpy.ndarray`:
456512
The overlaid image.
457513
514+
Examples:
515+
>>> from tiatoolbox.utils.visualization import overlay_prediction_contours
516+
>>> import numpy as np
517+
>>> from matplotlib import pyplot as plt
518+
>>> # Generate a random example; replace with your own data
519+
>>> canvas = np.zeros((256, 256, 3), dtype=np.uint8)
520+
>>> inst_dict = {
521+
... 1: {
522+
... "type": 0,
523+
... "contour": [[50, 50], [60, 45], [70, 50],
524+
... [70, 60], [60, 65], [50, 60]],
525+
... "centroid": [60, 55]
526+
... },
527+
... 2: {
528+
... "type": 1,
529+
... "contour": [[100, 100], [120, 100], [120, 120], [100, 120]],
530+
... "centroid": [110, 110]
531+
... }
532+
... }
533+
>>> type_colours = {
534+
... 0: ("Type A", (0, 255, 0)),
535+
... 1: ("Type B", (0, 0, 255))
536+
... }
537+
>>> # Example usage of overlay_prediction_contours
538+
>>> overlaid_canvas = overlay_prediction_contours(
539+
... canvas=canvas,
540+
... inst_dict=inst_dict,
541+
... type_colours=type_colours,
542+
... line_thickness=1,
543+
... draw_dot=True
544+
... )
545+
>>> plt.imshow(overlaid_canvas)
546+
>>> plt.show()
547+
458548
"""
459549
overlay = np.copy(canvas)
460550

@@ -531,6 +621,28 @@ def plot_graph(
531621
edge_size (int):
532622
Line width of the edge.
533623
624+
Examples:
625+
>>> from tiatoolbox.utils.visualization import plot_graph
626+
>>> import numpy as np
627+
>>> # Generate a random example; replace with your own data
628+
>>> canvas = np.zeros((256, 256, 3), dtype=np.uint8)
629+
>>> num_nodes = 10
630+
>>> nodes = np.random.randint(0, 255, size=(num_nodes, 2))
631+
>>> num_edges = 15
632+
>>> edges = np.random.randint(0, num_nodes, size=(num_edges, 2))
633+
>>> node_colors = np.random.randint(0, 256, size=(num_nodes, 3))
634+
>>> edge_colors = np.random.randint(0, 256, size=(num_edges, 3))
635+
>>> # Example usage of overlay_prediction_contours
636+
>>> overlaid_canvas = plot_graph(
637+
... canvas=canvas,
638+
... nodes=nodes,
639+
... edges=edges,
640+
... node_colors=node_colors,
641+
... node_size=8,
642+
... edge_colors=edge_colors,
643+
... edge_size=3
644+
... )
645+
534646
"""
535647
if isinstance(node_colors, tuple):
536648
node_colors_list = np.array([node_colors] * len(nodes))

0 commit comments

Comments
 (0)