diff --git a/src/omero_zarr/cli.py b/src/omero_zarr/cli.py index 98bb79a..034ba41 100644 --- a/src/omero_zarr/cli.py +++ b/src/omero_zarr/cli.py @@ -369,6 +369,9 @@ def _configure(self, parser: Parser) -> None: "Only applies when importing OME/METADATA.ome.xml." ), ) + import_cmd.add_argument( + "--labels", action="store_true", help="Also import labels if present" + ) @gateway_required def masks(self, args: argparse.Namespace) -> None: @@ -418,6 +421,7 @@ def import_cmd(self, args: argparse.Namespace) -> None: target=args.target, target_by_name=args.target_by_name, wait=args.wait, + labels=args.labels, ) def _lookup( diff --git a/src/omero_zarr/import_labels.py b/src/omero_zarr/import_labels.py new file mode 100644 index 0000000..b2a1afe --- /dev/null +++ b/src/omero_zarr/import_labels.py @@ -0,0 +1,196 @@ +#!/usr/bin/env python + +# Copyright (C) 2025 University of Dundee & Open Microscopy Environment. +# All rights reserved. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from typing import Optional + +import numpy as np +from omero.gateway import BlitzGateway, ColorHolder +from omero.model import ImageI, MaskI, RoiI +from omero.rtypes import rdouble, rint, rstring +from zarr.creation import open_array +from zarr.errors import GroupNotFoundError +from zarr.hierarchy import open_group +from zarr.storage import Store, StoreLike + + +def load_attrs(store: StoreLike, path: Optional[str] = None) -> dict: + """ + Load the attrs from the root group or path subgroup + """ + root = open_group(store=store, mode="r", path=path) + attrs = root.attrs.asdict() + if "ome" in attrs: + attrs = attrs["ome"] + return attrs + + +def masks_from_labels_nd( + labels_nd: np.ndarray, axes: list[str], label_props: dict +) -> dict: + rois = {} + + colors_by_value = {} + if "colors" in label_props: + for color in label_props["colors"]: + pixel_value = color.get("label-value", None) + rgba = color.get("rgba", None) + if pixel_value and rgba and len(rgba) == 4: + colors_by_value[pixel_value] = rgba + + text_by_value = {} + if "properties" in label_props: + for props in label_props["properties"]: + pixel_value = props.get("label-value", None) + text = props.get("omero:text", None) + if pixel_value and text: + text_by_value[pixel_value] = text + + # For each label value, we create an ROI that + # contains 2D masks for each time point, channel, and z-slice. + for i in range(1, int(labels_nd.max()) + 1): + if not np.any(labels_nd == i): + continue + + masks = [] + bin_img = labels_nd == i + + sizes = {dim: labels_nd.shape[axes.index(dim)] for dim in axes} + size_t = sizes.get("t", 1) + size_c = sizes.get("c", 1) + size_z = sizes.get("z", 1) + + for t in range(size_t): + for c in range(size_c): + for z in range(size_z): + + indices = [] + if "t" in axes: + indices.append(t) + if "c" in axes: + indices.append(c) + if "z" in axes: + indices.append(z) + + # indices.append(np.s_[::]) + # indices.append(np.s_[x:x_max:]) + + # slice down to 2D plane + plane = bin_img[tuple(indices)] + + if not np.any(plane): + continue + + # plane = plane.compute() + + # Find bounding box to minimise size of mask + xmask = plane.sum(0).nonzero()[0] + ymask = plane.sum(1).nonzero()[0] + # if any(xmask) and any(ymask): + x0 = min(xmask) + w = max(xmask) - x0 + 1 + y0 = min(ymask) + h = max(ymask) - y0 + 1 + submask = plane[y0 : (y0 + h), x0 : (x0 + w)] + + mask = MaskI() + mask.setBytes(np.packbits(np.asarray(submask, dtype=int))) + mask.setWidth(rdouble(w)) + mask.setHeight(rdouble(h)) + mask.setX(rdouble(x0)) + mask.setY(rdouble(y0)) + + if i in colors_by_value: + ch = ColorHolder.fromRGBA(*colors_by_value[i]) + mask.setFillColor(rint(ch.getInt())) + if "z" in axes: + mask.setTheZ(rint(z)) + if "c" in axes: + mask.setTheC(rint(c)) + if "t" in axes: + mask.setTheT(rint(t)) + if i in text_by_value: + mask.setTextValue(rstring(text_by_value[i])) + + masks.append(mask) + + rois[i] = masks + + return rois + + +def rois_from_labels_nd( + conn: BlitzGateway, + image_id: int, + labels_nd: np.ndarray, + axes: list[str], + label_props: dict, +) -> None: + # Text is set on Mask shapes, not ROIs + rois = masks_from_labels_nd(labels_nd, axes, label_props) + + for label, masks in rois.items(): + if len(masks) > 0: + create_roi(conn, image_id, shapes=masks) + + +def create_roi(conn: BlitzGateway, image_id: int, shapes: list) -> RoiI: + # create an ROI, link it to Image + roi = RoiI() + roi.setImage(ImageI(image_id, False)) + for shape in shapes: + roi.addShape(shape) + # Save the ROI (saves any linked shapes too) + print(f"Save ROI for image: {image_id}") + return conn.getUpdateService().saveAndReturnObject(roi) + + +def create_labels( + conn: BlitzGateway, store: Store, image_id: int, image_path: Optional[str] = None +) -> None: + """ + Create labels for the image + """ + if image_path is None: + image_path = "" + labels_path = image_path + "/labels" + try: + labels_attrs = load_attrs(store, labels_path) + except GroupNotFoundError: + print("No zarr group at", labels_path) + return + if "labels" not in labels_attrs: + print("No labels found at", labels_path) + return + for name in labels_attrs["labels"]: + print("Found label:", name) + label_path = f"{labels_path}/{name}" + print("Loading label from:", label_path) + + label_image = load_attrs(store, label_path) + + axes = label_image["multiscales"][0]["axes"] + axes_names = [axis["name"] for axis in axes] + label_props = label_image.get("image-label", {}) + + ds_path = label_image["multiscales"][0]["datasets"][0]["path"] + array_path = f"{label_path}/{ds_path}/" + labels_nd = open_array(store=store, mode="r", path=array_path) + labels_data = labels_nd[slice(None)] + + # Create ROIs from the labels + rois_from_labels_nd(conn, image_id, labels_data, axes_names, label_props) diff --git a/src/omero_zarr/zarr_import.py b/src/omero_zarr/zarr_import.py index b02aadc..ef46312 100644 --- a/src/omero_zarr/zarr_import.py +++ b/src/omero_zarr/zarr_import.py @@ -44,6 +44,7 @@ from zarr.hierarchy import open_group from zarr.storage import FSStore +from .import_labels import create_labels from .import_xml import full_import # TODO: support Zarr v3 - imports for get_omexml_bytes() @@ -185,6 +186,9 @@ def create_image( img_obj = image._obj set_external_info(img_obj, kwargs, image_path) + if "labels" in kwargs and kwargs["labels"]: + print("Importing labels for image:", img_obj.id.val) + create_labels(conn, store, img_obj.id.val, image_path) return img_obj, rnd_def @@ -553,6 +557,9 @@ def import_zarr( if rnd_def is not None: conn.getUpdateService().saveAndReturnObject(rnd_def) set_external_info(image._obj, kwargs, image_path=image_path) + if "labels" in kwargs and kwargs["labels"]: + print("Importing labels for series:", series) + create_labels(conn, store, image.id, image_path) # default name is METADATA.ome.xml [series], based on clientPath? new_name = image.name.replace("METADATA.ome.xml", zarr_name) print("Imported Image:", image.id) diff --git a/test/integration/clitest/test_import.py b/test/integration/clitest/test_import.py index b58dc28..2b45bfc 100644 --- a/test/integration/clitest/test_import.py +++ b/test/integration/clitest/test_import.py @@ -23,6 +23,7 @@ import pytest from omero.gateway import BlitzGateway +from omero.model import MaskI from omero.testlib.cli import AbstractCLITest from omero_zarr.cli import ZarrControl from omero_zarr.zarr_import import import_zarr @@ -31,6 +32,7 @@ "6001240.zarr": { "url": "https://uk1s3.embassy.ebi.ac.uk/idr/zarr/v0.4/idr0062A/6001240.zarr", "dataset_name": "Test Import 6001240", + "args": "--labels", }, "13457227.zarr": { "url": "https://uk1s3.embassy.ebi.ac.uk/idr/zarr/v0.4/idr0101A/13457227.zarr", @@ -130,6 +132,8 @@ def test_register_images( kwargs["endpoint"] = url_args[1] if "--nosignrequest" in url_args: kwargs["nosignrequest"] = True + if "--labels" in url_args: + kwargs["labels"] = True if "dataset_name" in sample: kwargs["target_by_name"] = ds_name else: @@ -164,3 +168,18 @@ def test_register_images( f"Image {img_id} sizeX {size_x} physSizeX {phys_size_x} != " f"expected {exp_size_x}" ) + + if "labels" in sample.get("args", ""): + # check we have labels + for img_id in image_ids: + roi_service = conn.getRoiService() + result = roi_service.findByImage(img_id, None) + assert len(result.rois) > 0, f"No ROIs found for image {img_id}" + for roi in result.rois: + assert ( + len(roi.copyShapes()) > 0 + ), f"No shapes found for ROI {roi.id.val}" + for s in roi.copyShapes(): + assert type(s) is MaskI + assert s.getWidth().getValue() > 0 + assert s.getHeight().getValue() > 0