Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/omero_zarr/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,9 @@ def _configure(self, parser: Parser) -> None:
"Only applies when importing OME/METADATA.ome.xml."
),
)
import_cmd.add_argument(
"--labels", action="store_true", help="Also import labels if present"
)

@gateway_required
def masks(self, args: argparse.Namespace) -> None:
Expand Down Expand Up @@ -418,6 +421,7 @@ def import_cmd(self, args: argparse.Namespace) -> None:
target=args.target,
target_by_name=args.target_by_name,
wait=args.wait,
labels=args.labels,
)

def _lookup(
Expand Down
196 changes: 196 additions & 0 deletions src/omero_zarr/import_labels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
#!/usr/bin/env python

# Copyright (C) 2025 University of Dundee & Open Microscopy Environment.
# All rights reserved.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from typing import Optional

import numpy as np
from omero.gateway import BlitzGateway, ColorHolder
from omero.model import ImageI, MaskI, RoiI
from omero.rtypes import rdouble, rint, rstring
from zarr.creation import open_array
from zarr.errors import GroupNotFoundError
from zarr.hierarchy import open_group
from zarr.storage import Store, StoreLike


def load_attrs(store: StoreLike, path: Optional[str] = None) -> dict:
"""
Load the attrs from the root group or path subgroup
"""
root = open_group(store=store, mode="r", path=path)
attrs = root.attrs.asdict()
if "ome" in attrs:
attrs = attrs["ome"]
return attrs


def masks_from_labels_nd(
labels_nd: np.ndarray, axes: list[str], label_props: dict
) -> dict:
rois = {}

colors_by_value = {}
if "colors" in label_props:
for color in label_props["colors"]:
pixel_value = color.get("label-value", None)
rgba = color.get("rgba", None)
if pixel_value and rgba and len(rgba) == 4:
colors_by_value[pixel_value] = rgba

text_by_value = {}
if "properties" in label_props:
for props in label_props["properties"]:
pixel_value = props.get("label-value", None)
text = props.get("omero:text", None)
if pixel_value and text:
text_by_value[pixel_value] = text

# For each label value, we create an ROI that
# contains 2D masks for each time point, channel, and z-slice.
for i in range(1, int(labels_nd.max()) + 1):
if not np.any(labels_nd == i):
continue

masks = []
bin_img = labels_nd == i

sizes = {dim: labels_nd.shape[axes.index(dim)] for dim in axes}
size_t = sizes.get("t", 1)
size_c = sizes.get("c", 1)
size_z = sizes.get("z", 1)

for t in range(size_t):
for c in range(size_c):
for z in range(size_z):

indices = []
if "t" in axes:
indices.append(t)
if "c" in axes:
indices.append(c)
if "z" in axes:
indices.append(z)

# indices.append(np.s_[::])
# indices.append(np.s_[x:x_max:])

# slice down to 2D plane
plane = bin_img[tuple(indices)]

if not np.any(plane):
continue

# plane = plane.compute()

# Find bounding box to minimise size of mask
xmask = plane.sum(0).nonzero()[0]
ymask = plane.sum(1).nonzero()[0]
# if any(xmask) and any(ymask):
x0 = min(xmask)
w = max(xmask) - x0 + 1
y0 = min(ymask)
h = max(ymask) - y0 + 1
submask = plane[y0 : (y0 + h), x0 : (x0 + w)]

mask = MaskI()
mask.setBytes(np.packbits(np.asarray(submask, dtype=int)))
mask.setWidth(rdouble(w))
mask.setHeight(rdouble(h))
mask.setX(rdouble(x0))
mask.setY(rdouble(y0))

if i in colors_by_value:
ch = ColorHolder.fromRGBA(*colors_by_value[i])
mask.setFillColor(rint(ch.getInt()))
if "z" in axes:
mask.setTheZ(rint(z))
if "c" in axes:
mask.setTheC(rint(c))
if "t" in axes:
mask.setTheT(rint(t))
if i in text_by_value:
mask.setTextValue(rstring(text_by_value[i]))

masks.append(mask)

rois[i] = masks

return rois


def rois_from_labels_nd(
conn: BlitzGateway,
image_id: int,
labels_nd: np.ndarray,
axes: list[str],
label_props: dict,
) -> None:
# Text is set on Mask shapes, not ROIs
rois = masks_from_labels_nd(labels_nd, axes, label_props)

for label, masks in rois.items():
if len(masks) > 0:
create_roi(conn, image_id, shapes=masks)


def create_roi(conn: BlitzGateway, image_id: int, shapes: list) -> RoiI:
# create an ROI, link it to Image
roi = RoiI()
roi.setImage(ImageI(image_id, False))
for shape in shapes:
roi.addShape(shape)
# Save the ROI (saves any linked shapes too)
print(f"Save ROI for image: {image_id}")
return conn.getUpdateService().saveAndReturnObject(roi)


def create_labels(
conn: BlitzGateway, store: Store, image_id: int, image_path: Optional[str] = None
) -> None:
"""
Create labels for the image
"""
if image_path is None:
image_path = ""
labels_path = image_path + "/labels"
try:
labels_attrs = load_attrs(store, labels_path)
except GroupNotFoundError:
print("No zarr group at", labels_path)
return
if "labels" not in labels_attrs:
print("No labels found at", labels_path)
return
for name in labels_attrs["labels"]:
print("Found label:", name)
label_path = f"{labels_path}/{name}"
print("Loading label from:", label_path)

label_image = load_attrs(store, label_path)

axes = label_image["multiscales"][0]["axes"]
axes_names = [axis["name"] for axis in axes]
label_props = label_image.get("image-label", {})

ds_path = label_image["multiscales"][0]["datasets"][0]["path"]
array_path = f"{label_path}/{ds_path}/"
labels_nd = open_array(store=store, mode="r", path=array_path)
labels_data = labels_nd[slice(None)]

# Create ROIs from the labels
rois_from_labels_nd(conn, image_id, labels_data, axes_names, label_props)
7 changes: 7 additions & 0 deletions src/omero_zarr/zarr_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from zarr.hierarchy import open_group
from zarr.storage import FSStore

from .import_labels import create_labels
from .import_xml import full_import

# TODO: support Zarr v3 - imports for get_omexml_bytes()
Expand Down Expand Up @@ -185,6 +186,9 @@ def create_image(

img_obj = image._obj
set_external_info(img_obj, kwargs, image_path)
if "labels" in kwargs and kwargs["labels"]:
print("Importing labels for image:", img_obj.id.val)
create_labels(conn, store, img_obj.id.val, image_path)

return img_obj, rnd_def

Expand Down Expand Up @@ -553,6 +557,9 @@ def import_zarr(
if rnd_def is not None:
conn.getUpdateService().saveAndReturnObject(rnd_def)
set_external_info(image._obj, kwargs, image_path=image_path)
if "labels" in kwargs and kwargs["labels"]:
print("Importing labels for series:", series)
create_labels(conn, store, image.id, image_path)
# default name is METADATA.ome.xml [series], based on clientPath?
new_name = image.name.replace("METADATA.ome.xml", zarr_name)
print("Imported Image:", image.id)
Expand Down
19 changes: 19 additions & 0 deletions test/integration/clitest/test_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import pytest
from omero.gateway import BlitzGateway
from omero.model import MaskI
from omero.testlib.cli import AbstractCLITest
from omero_zarr.cli import ZarrControl
from omero_zarr.zarr_import import import_zarr
Expand All @@ -31,6 +32,7 @@
"6001240.zarr": {
"url": "https://uk1s3.embassy.ebi.ac.uk/idr/zarr/v0.4/idr0062A/6001240.zarr",
"dataset_name": "Test Import 6001240",
"args": "--labels",
},
"13457227.zarr": {
"url": "https://uk1s3.embassy.ebi.ac.uk/idr/zarr/v0.4/idr0101A/13457227.zarr",
Expand Down Expand Up @@ -130,6 +132,8 @@ def test_register_images(
kwargs["endpoint"] = url_args[1]
if "--nosignrequest" in url_args:
kwargs["nosignrequest"] = True
if "--labels" in url_args:
kwargs["labels"] = True
if "dataset_name" in sample:
kwargs["target_by_name"] = ds_name
else:
Expand Down Expand Up @@ -164,3 +168,18 @@ def test_register_images(
f"Image {img_id} sizeX {size_x} physSizeX {phys_size_x} != "
f"expected {exp_size_x}"
)

if "labels" in sample.get("args", ""):
# check we have labels
for img_id in image_ids:
roi_service = conn.getRoiService()
result = roi_service.findByImage(img_id, None)
assert len(result.rois) > 0, f"No ROIs found for image {img_id}"
for roi in result.rois:
assert (
len(roi.copyShapes()) > 0
), f"No shapes found for ROI {roi.id.val}"
for s in roi.copyShapes():
assert type(s) is MaskI
assert s.getWidth().getValue() > 0
assert s.getHeight().getValue() > 0