diff --git a/tests/test_utils.py b/tests/test_utils.py index 677e3f788..1b07eba87 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -12,10 +12,14 @@ import numpy as np import pandas as pd import pytest +import tifffile import torch +import zarr +from defusedxml import ElementTree as ET # noqa: N817 from PIL import Image from requests import HTTPError from shapely.geometry import Polygon +from tifffile import TiffFile from tests.test_annotation_stores import cell_polygon from tiatoolbox import rcParam, utils @@ -1858,3 +1862,123 @@ def test_torch_compile_compatibility(caplog: pytest.LogCaptureFixture) -> None: is_torch_compile_compatible() assert "torch.compile" in caplog.text + + +# Tests for OME tiff writer + + +def get_ome_metadata(tiff_path: Path) -> str | None: + """Extracts the OME metadata string from a TIFF file.""" + with TiffFile(tiff_path) as tif: + if tif.ome_metadata: + return tif.ome_metadata + return None + + +def assert_ome_metadata_value( + ome_xml: ET.Element, tag: str, expected_value: str +) -> None: + """Asserts the value of a specific OME metadata tag (as an attribute).""" + namespace = "{http://www.openmicroscopy.org/Schemas/OME/2016-06}" + image_elements = ome_xml.findall(f".//{namespace}Image") + if image_elements: + pixels_elements = image_elements[0].findall(f"./{namespace}Pixels") + if pixels_elements: + actual_value = pixels_elements[0].get(tag) + assert actual_value == expected_value, ( + f"Expected attribute '{tag}' to be '{expected_value}', " + f"but got '{actual_value}'." + ) + return + + # If we reach here, the tag or attribute was not found + pytest.fail(f"Attribute '{tag}' not found in OME metadata.") + + +def test_iwrite_probability_heatmap_as_ome_tiff_errors(tmp_path: Path) -> None: + """Test expected errors in `write_probability_heatmap_as_ome_tiff`.""" + probability = np.zeros(shape=(256, 256, 3)) + + # Input image must have 2 (CY) dimensions. + with pytest.raises(ValueError, match=r".*must have 2 \(YX\).*"): + misc.write_probability_heatmap_as_ome_tiff( + image_path=tmp_path / "failed_test.tif", + probability=probability, + ) + + probability = np.zeros(shape=(256, 256, 3)) + probability = torch.from_numpy(probability) + + # Input image must be a NumPy array or a Zarr array. + with pytest.raises(TypeError, match=r".*must be a NumPy array or a Zarr.*"): + misc.write_probability_heatmap_as_ome_tiff( + image_path=tmp_path / "failed_test.tif", + probability=probability, + ) + + +def test_save_numpy_array_proability_ome_tiff( + tmp_path: Path, source_image: Path +) -> None: + """Tests saving a basic NumPy array.""" + image_path = tmp_path / "numpy_image.ome.tif" + probability = utils.imread(source_image) + probability_0 = probability[:, :, 0] + misc.write_probability_heatmap_as_ome_tiff( + image_path=image_path, + probability=probability_0, + tile_size=(64, 64), + mpp=(0.5, 0.5), + levels=2, + colormap=cv2.COLORMAP_JET, + ) + assert image_path.is_file() + saved_img = tifffile.imread(image_path) + assert probability.shape == saved_img.shape + assert probability.dtype == saved_img.dtype + ome_xml = ET.fromstring(get_ome_metadata(image_path)) + assert ome_xml is not None + + assert_ome_metadata_value(ome_xml, "SizeY", str(probability.shape[0])) + assert_ome_metadata_value(ome_xml, "SizeX", str(probability.shape[1])) + assert_ome_metadata_value(ome_xml, "SizeC", str(3)) + assert_ome_metadata_value(ome_xml, "DimensionOrder", "XYCZT") + assert_ome_metadata_value(ome_xml, "PhysicalSizeX", "0.5") + assert_ome_metadata_value(ome_xml, "PhysicalSizeY", "0.5") + assert_ome_metadata_value(ome_xml, "PhysicalSizeXUnit", "µm") + assert_ome_metadata_value(ome_xml, "PhysicalSizeYUnit", "µm") + + +def test_save_zarr_array_probability_ome_tiff( + tmp_path: Path, source_image: Path +) -> None: + """Tests saving a Zarr array with uint8 dtype.""" + image_path = tmp_path / "zarr_uint8_image.ome.tif" + + img = utils.imread(source_image) + probability = img[:, 0:200, 0] + img_zarr = zarr.zeros(shape=probability.shape, dtype=np.uint8) + img_zarr[:] = probability + + misc.write_probability_heatmap_as_ome_tiff( + image_path, + img_zarr, + tile_size=(32, 32), + levels=2, + colormap=cv2.COLORMAP_INFERNO, + ) + assert image_path.is_file() + saved_img = tifffile.imread(image_path, squeeze=True) + assert img_zarr.shape == saved_img.shape[0:2] + assert img_zarr.dtype == saved_img.dtype + ome_xml = ET.fromstring(get_ome_metadata(image_path)) + assert ome_xml is not None + + assert_ome_metadata_value(ome_xml, "SizeY", str(img_zarr.shape[0])) + assert_ome_metadata_value(ome_xml, "SizeX", str(img_zarr.shape[1])) + assert_ome_metadata_value(ome_xml, "SizeC", str(3)) + assert_ome_metadata_value(ome_xml, "DimensionOrder", "XYCZT") + assert_ome_metadata_value(ome_xml, "PhysicalSizeX", "0.25") + assert_ome_metadata_value(ome_xml, "PhysicalSizeY", "0.25") + assert_ome_metadata_value(ome_xml, "PhysicalSizeXUnit", "µm") + assert_ome_metadata_value(ome_xml, "PhysicalSizeYUnit", "µm") diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index a12dd96b2..5fb77ae9a 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -16,6 +16,7 @@ import numpy as np import pandas as pd import requests +import tifffile import yaml import zarr from filelock import FileLock @@ -23,12 +24,14 @@ from shapely.geometry import Polygon from shapely.geometry import shape as feature2geometry from skimage import exposure +from tqdm import trange from tiatoolbox import logger from tiatoolbox.annotation.storage import Annotation, AnnotationStore, SQLiteStore from tiatoolbox.utils.exceptions import FileNotSupportedError if TYPE_CHECKING: # pragma: no cover + from collections.abc import Iterator from os import PathLike from shapely import geometry @@ -160,7 +163,7 @@ def imwrite(image_path: PathLike, img: np.ndarray) -> None: def imread(image_path: PathLike, as_uint8: bool | None = None) -> np.ndarray: - """Read an image as numpy array. + """Read an image as a NumPy array. Args: image_path (PathLike): @@ -1283,6 +1286,117 @@ def dict_to_store( return store +def _tiles( + in_img: np.ndarray | zarr.core.Array, + tile_size: tuple[int, int], + colormap: int = cv2.COLORMAP_JET, + level: int = 0, +) -> Iterator[np.ndarray]: + for y in trange(0, in_img.shape[0], tile_size[0]): + for x in range(0, in_img.shape[1], tile_size[1]): + in_img_ = in_img[ + y : y + tile_size[0] : 2**level, x : x + tile_size[1] : 2**level + ] + yield cv2.applyColorMap(in_img_, colormap) + + +def write_probability_heatmap_as_ome_tiff( + image_path: Path, + probability: np.ndarray | zarr.core.Array, + tile_size: tuple[int, int] = (64, 64), + levels: int = 2, + mpp: tuple[float, float] = (0.25, 0.25), + colormap: int = cv2.COLORMAP_JET, +) -> None: + """Saves output probability maps from segmentation models as heatmaps. + + This function converts the probability maps from individual classes to heatmaps + and saves them as pyramidal ome tiffs. + + Args: + image_path (Path): + File path (including extension) to save image to. + probability (np.ndarray or zarr.core.Array): + The input image data in YXC (Height, Width, Channels) format. + tile_size (tuple): + Tile/Chunk size (YX/HW) for writing the tiff file. + Only allows tile shapes allowed by tifffile. Default is (64, 64). + levels (int): + Number of levels for saving pyramidal ome tiffs. Default is 2. + mpp (tuple[float, float]): + Tuple of mpp values in y and x (YX/HW). Default is (0.25, 0.25). + colormap (int): + Colormap to save the heatmaps. Default is 2 (cv2.COLORMAP_JET). + + Raises: + TypeError: + If the input `img` is not a NumPy or Zarr array or does not have 3 + dimensions. + ValueError: + If input dimensions is not 3 (HWC) dimensions. + + Examples: + >>> probability_map = imread("path/to/probability_map") + >>> write_probability_heatmap_as_ome_tiff( + ... image_path=image_path, + ... probability=probability_map, + ... tile_size=(64, 64), + ... class_name="tumor", + ... levels=2, + ... mpp=(0.5, 0.5), + ... colormap=cv2.COLORMAP_JET, + ... ) + + """ + if not isinstance(probability, (zarr.core.Array, np.ndarray)): + msg = "Input 'probability' must be a NumPy array or a Zarr array." + raise TypeError(msg) + + if probability.ndim != 2: # noqa: PLR2004 + msg = "Input 'probability' must have 2 (YX) dimensions." + raise ValueError(msg) + + ome_metadata = { + "axes": "YXC", + "PhysicalSizeX": mpp[1], + "PhysicalSizeXUnit": "µm", + "PhysicalSizeY": mpp[0], + "PhysicalSizeYUnit": "µm", + } + + h = probability.shape[0] + w = probability.shape[1] + + with tifffile.TiffWriter(image_path, bigtiff=True, ome=True) as tif: + tif.write( + _tiles(in_img=probability, tile_size=tile_size, colormap=colormap), + dtype="uint8", + shape=(h, w, 3), + tile=tile_size, + compression="jpeg", + metadata=ome_metadata, + subifds=levels - 1, + ) + + for level_ in range(1, levels): + tif.write( + _tiles( + in_img=probability, + tile_size=tile_size, + colormap=colormap, + level=level_, + ), + dtype="uint8", + shape=(h // 2**level_, w // 2**level_, 3), + tile=(tile_size[0] // 2**level_, tile_size[1] // 2**level_), + compression="jpeg", + subfiletype=0, + ) + + msg = f"Image saved as OME-TIFF to {image_path}." + logger.info(msg) + + def dict_to_zarr( raw_predictions: dict, save_path: Path,