Skip to content

Commit 5aeee1d

Browse files
committed
feat: Added 3D bbox support
Adds a napari reader plugin for .mla files, enabling the loading of image and labels data, as well as bounding box annotations. The reader utilizes `napari-bbox-fix` to correctly display bounding boxes with dimension 3 or higher, leveraging it's custom layer type. For 2D bounding boxes, it converts them to napari shapes. Fixes an issue where bounding boxes were not correctly displayed in napari.
1 parent b43d6a1 commit 5aeee1d

File tree

2 files changed

+124
-102
lines changed

2 files changed

+124
-102
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ requires-python = ">=3.10"
3030
# See best practices: https://napari.org/stable/plugins/building_a_plugin/best_practices.html
3131
dependencies = [
3232
"mlarray",
33-
"numpy"
33+
"numpy",
34+
"napari-bbox-fix"
3435
]
3536

3637
[project.optional-dependencies]

src/napari_mlarray/_reader.py

Lines changed: 122 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -9,68 +9,26 @@
99
from pathlib import Path
1010
import numpy as np
1111

12+
# Ensure napari-bbox registers its custom layer type.
13+
import napari_bbox # noqa: F401
14+
1215

1316
def napari_get_reader(path):
14-
"""A basic implementation of a Reader contribution.
15-
16-
Parameters
17-
----------
18-
path : str or list of str
19-
Path to file, or list of paths.
20-
21-
Returns
22-
-------
23-
function or None
24-
If the path is a recognized format, return a function that accepts the
25-
same path or list of paths, and returns a list of layer data tuples.
26-
"""
17+
"""A basic implementation of a Reader contribution."""
2718
if isinstance(path, list):
28-
# reader plugins may be handed single path, or a list of paths.
29-
# if it is a list, it is assumed to be an image stack...
30-
# so we are only going to look at the first file.
3119
path = path[0]
3220

33-
# the get_reader function should make as many checks as possible
34-
# (without loading the full file) to determine if it can read
35-
# the path. Here, we check the dtype of the array by loading
36-
# it with memmap, so that we don't actually load the full array into memory.
37-
# We pretend that this reader can only read integer arrays.
3821
try:
3922
if not str(path).endswith(".mla"):
4023
return None
41-
# napari_get_reader should never raise an exception, because napari
42-
# raises its own specific errors depending on what plugins are
43-
# available for the given path, so we catch
44-
# the OSError that np.load might raise if the file is malformed
4524
except OSError:
4625
return None
4726

48-
# otherwise we return the *function* that can read ``path``.
4927
return reader_function
5028

5129

5230
def reader_function(path):
53-
"""Take a path or list of paths and return a list of LayerData tuples.
54-
55-
Readers are expected to return data as a list of tuples, where each tuple
56-
is (data, [add_kwargs, [layer_type]]), "add_kwargs" and "layer_type" are
57-
both optional.
58-
59-
Parameters
60-
----------
61-
path : str or list of str
62-
Path to file, or list of paths.
63-
64-
Returns
65-
-------
66-
layer_data : list of tuples
67-
A list of LayerData tuples where each tuple in the list contains
68-
(data, metadata, layer_type), where data is a numpy array, metadata is
69-
a dict of keyword arguments for the corresponding viewer.add_* method
70-
in napari, and layer_type is a lower-case string naming the type of
71-
layer. Both "meta", and "layer_type" are optional. napari will
72-
default to layer_type=="image" if not provided
73-
"""
31+
"""Take a path or list of paths and return a list of LayerData tuples."""
7432
paths = [path] if isinstance(path, str) else path
7533
layer_data = []
7634
for path in paths:
@@ -82,30 +40,58 @@ def reader_function(path):
8240
layer_type = "labels" if mlarray.meta.is_seg.is_seg == True else "image"
8341
layer_data.append((data, metadata, layer_type))
8442
if mlarray.meta.bbox.bboxes is not None:
85-
data = bboxes_minmax_to_napari_rectangles_2d(mlarray.meta.bbox.bboxes)
86-
edge_color = _napari_bbox_edge_colors(
87-
data,
88-
labels=getattr(mlarray.meta.bbox, "labels", None),
89-
)
90-
text = _napari_bbox_score_text(
91-
scores=getattr(mlarray.meta.bbox, "scores", None),
92-
labels=getattr(mlarray.meta.bbox, "labels", None),
93-
count=len(data),
94-
edge_color=edge_color,
95-
rectangles=data,
96-
)
97-
metadata = {
98-
"name": f"{name} (BBoxes)",
99-
"shape_type": "rectangle",
100-
"affine": mlarray.affine,
101-
"metadata": mlarray.meta.to_mapping(),
102-
"face_color": "transparent",
103-
"edge_color": edge_color,
104-
}
105-
if text is not None:
106-
metadata["text"] = text
107-
layer_type = "shapes"
108-
layer_data.append((data, metadata, layer_type))
43+
bboxes = np.asarray(mlarray.meta.bbox.bboxes)
44+
45+
# MLArray bboxes are always (N, D, 2)
46+
if bboxes.ndim != 3 or bboxes.shape[2] != 2:
47+
raise ValueError(f"Unsupported bbox shape: {bboxes.shape}")
48+
49+
dims = bboxes.shape[1]
50+
51+
# 2D -> keep shapes rectangles (original behavior)
52+
if dims == 2:
53+
data = bboxes_minmax_to_napari_rectangles_2d(bboxes)
54+
edge_color = _napari_bbox_edge_colors(
55+
data,
56+
labels=getattr(mlarray.meta.bbox, "labels", None),
57+
)
58+
text = _napari_bbox_score_text(
59+
scores=getattr(mlarray.meta.bbox, "scores", None),
60+
labels=getattr(mlarray.meta.bbox, "labels", None),
61+
count=len(data),
62+
edge_color=edge_color,
63+
rectangles=data,
64+
)
65+
metadata = {
66+
"name": f"{name} (BBoxes)",
67+
"shape_type": "rectangle",
68+
"affine": mlarray.affine,
69+
"metadata": mlarray.meta.to_mapping(),
70+
"face_color": "transparent",
71+
"edge_color": edge_color,
72+
}
73+
if text is not None:
74+
metadata["text"] = text
75+
layer_type = "shapes"
76+
layer_data.append((data, metadata, layer_type))
77+
78+
# 3D+ -> napari-bbox layer
79+
elif dims >= 3:
80+
data = bboxes_minmax_to_napari_bboxes_nd(bboxes)
81+
edge_color = _napari_bbox_edge_colors_count(
82+
count=len(data),
83+
labels=getattr(mlarray.meta.bbox, "labels", None),
84+
)
85+
metadata = {
86+
"name": f"{name} (BBoxes)",
87+
"affine": mlarray.affine,
88+
"metadata": mlarray.meta.to_mapping(),
89+
"face_color": "transparent",
90+
"edge_color": edge_color,
91+
# "edge_width": 2,
92+
}
93+
layer_type = "boundingboxlayer"
94+
layer_data.append((data, metadata, layer_type))
10995
return layer_data
11096

11197

@@ -115,34 +101,14 @@ def bboxes_minmax_to_napari_rectangles_2d(
115101
dtype=np.float32,
116102
validate: bool = True,
117103
) -> np.ndarray:
118-
"""
119-
Convert 2D axis-aligned bounding boxes from min/max format to napari Shapes rectangles.
120-
121-
Accepted input formats (both mean the same thing):
122-
1) (N, 2, 2): [[min_dim0, max_dim0], [min_dim1, max_dim1]]
123-
Example (dim order is whatever you use, e.g. (y, x)):
124-
[[[ymin, ymax], [xmin, xmax]], ...]
125-
126-
2) (N, 4): [min_dim0, min_dim1, max_dim0, max_dim1]
127-
Example:
128-
[[ymin, xmin, ymax, xmax], ...]
129-
130-
Output format (napari Shapes rectangle vertices):
131-
(N, 4, 2) with vertices in non-twisting cyclic order:
132-
(min0, min1) -> (min0, max1) -> (max0, max1) -> (max0, min1)
133-
134-
Raises:
135-
ValueError if bboxes are not 2D (i.e., D != 2) or shapes are invalid.
136-
"""
104+
"""Convert 2D axis-aligned bounding boxes from min/max format to napari Shapes rectangles."""
137105
arr = np.asarray(bboxes)
138106

139-
# Normalize input to shape (N, 2, 2)
140107
if arr.ndim == 2 and arr.shape[1] == 4:
141-
# (N, 4) -> (N, 2, 2)
142108
arr = np.stack(
143109
[
144-
arr[:, [0, 2]], # dim0: [min0, max0]
145-
arr[:, [1, 3]], # dim1: [min1, max1]
110+
arr[:, [0, 2]],
111+
arr[:, [1, 3]],
146112
],
147113
axis=1,
148114
)
@@ -153,13 +119,20 @@ def bboxes_minmax_to_napari_rectangles_2d(
153119
f"Expected bboxes of shape (N, 2, 2) or (N, 4). Got {arr.shape}."
154120
)
155121

156-
N, D, two = arr.shape
122+
# MLArray uses (N, D, 2) -> convert to (N, 2, 2)
123+
if arr.shape == (arr.shape[0], 2, 2):
124+
arr2 = arr
125+
else:
126+
arr2 = np.transpose(arr, (0, 2, 1))
127+
128+
N, D, two = arr2.shape
157129
if D != 2 or two != 2:
158-
# Defensive; should never hit because of checks above.
159130
raise ValueError(f"Only 2D bboxes are supported. Got (N, {D}, {two}).")
160131

161-
mins = arr[:, :, 0]
162-
maxs = arr[:, :, 1]
132+
mins = arr2[:, 0, :]
133+
maxs = arr2[:, 1, :]
134+
# Ensure proper min/max ordering even if input is flipped
135+
mins, maxs = np.minimum(mins, maxs), np.maximum(mins, maxs)
163136

164137
if validate and np.any(maxs < mins):
165138
bad = np.argwhere(maxs < mins)
@@ -171,7 +144,6 @@ def bboxes_minmax_to_napari_rectangles_2d(
171144
min0, min1 = mins[:, 0], mins[:, 1]
172145
max0, max1 = maxs[:, 0], maxs[:, 1]
173146

174-
# Cyclic order (no twisting):
175147
rects = np.stack(
176148
[
177149
np.stack([min0, min1], axis=1),
@@ -185,6 +157,40 @@ def bboxes_minmax_to_napari_rectangles_2d(
185157
return rects
186158

187159

160+
def bboxes_minmax_to_napari_bboxes_nd(
161+
bboxes,
162+
*,
163+
dtype=np.float32,
164+
validate: bool = True,
165+
):
166+
"""
167+
Convert N-D axis-aligned bboxes from min/max to napari-bbox format.
168+
Input (MLArray): (N, D, 2) where [:, :, 0] are mins and [:, :, 1] are maxs.
169+
Returns:
170+
- list of (2, D) arrays, one per bbox.
171+
"""
172+
arr = np.asarray(bboxes)
173+
174+
if arr.ndim != 3 or arr.shape[2] != 2:
175+
raise ValueError(
176+
f"Expected bboxes of shape (N, D, 2). Got {arr.shape}."
177+
)
178+
179+
mins = arr[:, :, 0]
180+
maxs = arr[:, :, 1]
181+
# Ensure proper min/max ordering even if input is flipped
182+
mins, maxs = np.minimum(mins, maxs), np.maximum(mins, maxs)
183+
if validate and np.any(maxs < mins):
184+
bad = np.argwhere(maxs < mins)
185+
raise ValueError(
186+
"Found bbox with max < min at indices (bbox_index, dim): "
187+
f"{bad[:10].tolist()}" + (" ..." if len(bad) > 10 else "")
188+
)
189+
190+
arr2 = np.stack([mins, maxs], axis=1).astype(dtype, copy=False)
191+
return [arr2[i] for i in range(arr2.shape[0])]
192+
193+
188194
def _napari_bbox_edge_colors(rectangles, labels):
189195
"""Return RGBA edge colors for each bbox."""
190196
count = len(rectangles)
@@ -203,14 +209,30 @@ def _napari_bbox_edge_colors(rectangles, labels):
203209
return colors
204210

205211

212+
def _napari_bbox_edge_colors_count(count, labels=None):
213+
"""Return RGBA edge colors for each bbox (count-based)."""
214+
if count == 0:
215+
return np.empty((0, 4), dtype=np.float32)
216+
217+
if labels is not None and len(labels) == count:
218+
unique_labels = list(dict.fromkeys(labels))
219+
label_to_color = {
220+
label: _palette_rgba(idx) for idx, label in enumerate(unique_labels)
221+
}
222+
colors = np.array([label_to_color[label] for label in labels], dtype=np.float32)
223+
else:
224+
colors = np.array([_palette_rgba(idx) for idx in range(count)], dtype=np.float32)
225+
226+
return colors
227+
228+
206229
def _napari_bbox_score_text(scores, labels, count, edge_color, rectangles):
207230
"""Return napari Shapes text metadata if scores are provided."""
208231
have_scores = scores is not None and len(scores) == count
209232
have_labels = labels is not None and len(labels) == count
210233
if not have_scores and not have_labels:
211234
return None
212235

213-
# Place text at the top-left corner of each rectangle.
214236
top_left = rectangles[:, 0, :]
215237
top_left = np.maximum(top_left - np.array([4.0, 0.0], dtype=top_left.dtype), 0)
216238

@@ -221,7 +243,6 @@ def _napari_bbox_score_text(scores, labels, count, edge_color, rectangles):
221243
parts.append(f"Label: {labels[idx]}")
222244
if have_scores:
223245
parts.append(f"Score: {scores[idx]:.3f}")
224-
# Add a trailing empty line to create spacing below the score.
225246
parts.append("\n")
226247
strings.append("\n".join(parts))
227248

0 commit comments

Comments
 (0)