Skip to content

Commit b43d6a1

Browse files
committed
feat: Added bbox handling
1 parent 905e69a commit b43d6a1

File tree

3 files changed

+201
-6
lines changed

3 files changed

+201
-6
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ requires-python = ">=3.10"
2929
# or any other Qt bindings directly (e.g. PyQt5, PySide2).
3030
# See best practices: https://napari.org/stable/plugins/building_a_plugin/best_practices.html
3131
dependencies = [
32-
"mlarray"
32+
"mlarray",
33+
"numpy"
3334
]
3435

3536
[project.optional-dependencies]

src/napari_mlarray/_reader.py

Lines changed: 198 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
https://napari.org/stable/plugins/building_a_plugin/guides.html#readers
77
"""
88
from mlarray import MLArray
9+
from pathlib import Path
10+
import numpy as np
911

1012

1113
def napari_get_reader(path):
@@ -69,9 +71,201 @@ def reader_function(path):
6971
layer. Both "meta", and "layer_type" are optional. napari will
7072
default to layer_type=="image" if not provided
7173
"""
72-
# handle both a string and a list of strings
7374
paths = [path] if isinstance(path, str) else path
74-
# load all files into array
75-
mlarrays = [MLArray.open(_path) for _path in paths]
76-
layer_data = [(mlarray, {"affine": mlarray.affine, "metadata": mlarray.meta.to_dict()}, "labels" if mlarray.meta.is_seg.is_seg == True else "image") for mlarray in mlarrays]
75+
layer_data = []
76+
for path in paths:
77+
name = Path(path).stem
78+
mlarray = MLArray.open(path)
79+
if mlarray.meta._has_array.has_array == True:
80+
data = mlarray
81+
metadata = {"name": f"{name}", "affine": mlarray.affine, "metadata": mlarray.meta.to_mapping()}
82+
layer_type = "labels" if mlarray.meta.is_seg.is_seg == True else "image"
83+
layer_data.append((data, metadata, layer_type))
84+
if mlarray.meta.bbox.bboxes is not None:
85+
data = bboxes_minmax_to_napari_rectangles_2d(mlarray.meta.bbox.bboxes)
86+
edge_color = _napari_bbox_edge_colors(
87+
data,
88+
labels=getattr(mlarray.meta.bbox, "labels", None),
89+
)
90+
text = _napari_bbox_score_text(
91+
scores=getattr(mlarray.meta.bbox, "scores", None),
92+
labels=getattr(mlarray.meta.bbox, "labels", None),
93+
count=len(data),
94+
edge_color=edge_color,
95+
rectangles=data,
96+
)
97+
metadata = {
98+
"name": f"{name} (BBoxes)",
99+
"shape_type": "rectangle",
100+
"affine": mlarray.affine,
101+
"metadata": mlarray.meta.to_mapping(),
102+
"face_color": "transparent",
103+
"edge_color": edge_color,
104+
}
105+
if text is not None:
106+
metadata["text"] = text
107+
layer_type = "shapes"
108+
layer_data.append((data, metadata, layer_type))
77109
return layer_data
110+
111+
112+
def bboxes_minmax_to_napari_rectangles_2d(
113+
bboxes,
114+
*,
115+
dtype=np.float32,
116+
validate: bool = True,
117+
) -> np.ndarray:
118+
"""
119+
Convert 2D axis-aligned bounding boxes from min/max format to napari Shapes rectangles.
120+
121+
Accepted input formats (both mean the same thing):
122+
1) (N, 2, 2): [[min_dim0, max_dim0], [min_dim1, max_dim1]]
123+
Example (dim order is whatever you use, e.g. (y, x)):
124+
[[[ymin, ymax], [xmin, xmax]], ...]
125+
126+
2) (N, 4): [min_dim0, min_dim1, max_dim0, max_dim1]
127+
Example:
128+
[[ymin, xmin, ymax, xmax], ...]
129+
130+
Output format (napari Shapes rectangle vertices):
131+
(N, 4, 2) with vertices in non-twisting cyclic order:
132+
(min0, min1) -> (min0, max1) -> (max0, max1) -> (max0, min1)
133+
134+
Raises:
135+
ValueError if bboxes are not 2D (i.e., D != 2) or shapes are invalid.
136+
"""
137+
arr = np.asarray(bboxes)
138+
139+
# Normalize input to shape (N, 2, 2)
140+
if arr.ndim == 2 and arr.shape[1] == 4:
141+
# (N, 4) -> (N, 2, 2)
142+
arr = np.stack(
143+
[
144+
arr[:, [0, 2]], # dim0: [min0, max0]
145+
arr[:, [1, 3]], # dim1: [min1, max1]
146+
],
147+
axis=1,
148+
)
149+
elif arr.ndim == 3 and arr.shape[1:] == (2, 2):
150+
pass
151+
else:
152+
raise ValueError(
153+
f"Expected bboxes of shape (N, 2, 2) or (N, 4). Got {arr.shape}."
154+
)
155+
156+
N, D, two = arr.shape
157+
if D != 2 or two != 2:
158+
# Defensive; should never hit because of checks above.
159+
raise ValueError(f"Only 2D bboxes are supported. Got (N, {D}, {two}).")
160+
161+
mins = arr[:, :, 0]
162+
maxs = arr[:, :, 1]
163+
164+
if validate and np.any(maxs < mins):
165+
bad = np.argwhere(maxs < mins)
166+
raise ValueError(
167+
"Found bbox with max < min at indices (bbox_index, dim): "
168+
f"{bad[:10].tolist()}" + (" ..." if len(bad) > 10 else "")
169+
)
170+
171+
min0, min1 = mins[:, 0], mins[:, 1]
172+
max0, max1 = maxs[:, 0], maxs[:, 1]
173+
174+
# Cyclic order (no twisting):
175+
rects = np.stack(
176+
[
177+
np.stack([min0, min1], axis=1),
178+
np.stack([min0, max1], axis=1),
179+
np.stack([max0, max1], axis=1),
180+
np.stack([max0, min1], axis=1),
181+
],
182+
axis=1,
183+
).astype(dtype, copy=False)
184+
185+
return rects
186+
187+
188+
def _napari_bbox_edge_colors(rectangles, labels):
189+
"""Return RGBA edge colors for each bbox."""
190+
count = len(rectangles)
191+
if count == 0:
192+
return np.empty((0, 4), dtype=np.float32)
193+
194+
if labels is not None and len(labels) == count:
195+
unique_labels = list(dict.fromkeys(labels))
196+
label_to_color = {
197+
label: _palette_rgba(idx) for idx, label in enumerate(unique_labels)
198+
}
199+
colors = np.array([label_to_color[label] for label in labels], dtype=np.float32)
200+
else:
201+
colors = np.array([_palette_rgba(idx) for idx in range(count)], dtype=np.float32)
202+
203+
return colors
204+
205+
206+
def _napari_bbox_score_text(scores, labels, count, edge_color, rectangles):
207+
"""Return napari Shapes text metadata if scores are provided."""
208+
have_scores = scores is not None and len(scores) == count
209+
have_labels = labels is not None and len(labels) == count
210+
if not have_scores and not have_labels:
211+
return None
212+
213+
# Place text at the top-left corner of each rectangle.
214+
top_left = rectangles[:, 0, :]
215+
top_left = np.maximum(top_left - np.array([4.0, 0.0], dtype=top_left.dtype), 0)
216+
217+
strings = []
218+
for idx in range(count):
219+
parts = []
220+
if have_labels:
221+
parts.append(f"Label: {labels[idx]}")
222+
if have_scores:
223+
parts.append(f"Score: {scores[idx]:.3f}")
224+
# Add a trailing empty line to create spacing below the score.
225+
parts.append("\n")
226+
strings.append("\n".join(parts))
227+
228+
return {
229+
"string": strings,
230+
"color": edge_color,
231+
"size": 12,
232+
"anchor": "upper_left",
233+
"position": top_left,
234+
}
235+
236+
237+
def _palette_rgba(index):
238+
"""Simple, distinct-ish palette; returns RGBA in 0..1."""
239+
palette = [
240+
(0.90, 0.10, 0.12, 1.0),
241+
(0.00, 0.48, 1.00, 1.0),
242+
(0.20, 0.80, 0.20, 1.0),
243+
(0.98, 0.60, 0.00, 1.0),
244+
(0.60, 0.20, 0.80, 1.0),
245+
(0.10, 0.75, 0.80, 1.0),
246+
(0.80, 0.80, 0.00, 1.0),
247+
(0.95, 0.40, 0.60, 1.0),
248+
(0.90, 0.30, 0.00, 1.0),
249+
(0.00, 0.70, 0.40, 1.0),
250+
(0.40, 0.80, 1.00, 1.0),
251+
(1.00, 0.20, 0.70, 1.0),
252+
(0.50, 0.90, 0.20, 1.0),
253+
(0.20, 0.90, 0.70, 1.0),
254+
(0.70, 0.50, 1.00, 1.0),
255+
(1.00, 0.50, 0.20, 1.0),
256+
(0.20, 0.60, 1.00, 1.0),
257+
(1.00, 0.70, 0.20, 1.0),
258+
(0.60, 1.00, 0.20, 1.0),
259+
(0.20, 1.00, 0.40, 1.0),
260+
(0.20, 1.00, 0.90, 1.0),
261+
(0.20, 0.90, 1.00, 1.0),
262+
(0.40, 0.60, 1.00, 1.0),
263+
(0.80, 0.20, 1.00, 1.0),
264+
(1.00, 0.20, 0.30, 1.0),
265+
(1.00, 0.30, 0.50, 1.0),
266+
(1.00, 0.60, 0.60, 1.0),
267+
(1.00, 0.90, 0.30, 1.0),
268+
(0.60, 1.00, 0.60, 1.0),
269+
(0.60, 0.90, 1.00, 1.0),
270+
]
271+
return palette[index % len(palette)]

src/napari_mlarray/_writer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def write_single_image(path: str, data: Any, meta: dict) -> list[str]:
3434
-------
3535
[path] : A list containing the string path to the saved file.
3636
"""
37-
mlarray = MLArray(data, meta=Meta.from_dict(meta["metadata"]))
37+
mlarray = MLArray(data, meta=Meta.from_mapping(meta["metadata"]))
3838
mlarray.save(path)
3939

4040
# return path to any file(s) that were successfully written

0 commit comments

Comments
 (0)