Skip to content

Commit 5b6497d

Browse files
Merge pull request #78 from FocoosAI/feat/crop-masks-with-bbox
feat: crop masks with bbox and enhance postprocess efficiency
2 parents fc77ccd + cba7889 commit 5b6497d

File tree

10 files changed

+269
-77
lines changed

10 files changed

+269
-77
lines changed

focoos/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@
3030
base64mask_to_mask,
3131
binary_mask_to_base64,
3232
class_to_index,
33-
focoos_detections_to_supervision,
33+
fai_detections_to_sv,
3434
image_loader,
3535
image_preprocess,
3636
index_to_class,
37-
sv_to_focoos_detections,
37+
sv_to_fai_detections,
3838
)
3939

4040
__all__ = [
@@ -67,10 +67,10 @@
6767
"base64mask_to_mask",
6868
"binary_mask_to_base64",
6969
"class_to_index",
70-
"focoos_detections_to_supervision",
70+
"fai_detections_to_sv",
7171
"image_loader",
7272
"image_preprocess",
7373
"index_to_class",
74-
"sv_to_focoos_detections",
74+
"sv_to_fai_detections",
7575
"get_logger",
7676
]

focoos/local_model.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
from focoos.utils.vision import (
4242
image_preprocess,
4343
scale_detections,
44-
sv_to_focoos_detections,
44+
sv_to_fai_detections,
4545
)
4646

4747
logger = get_logger(__name__)
@@ -194,18 +194,18 @@ def infer(
194194
if resize:
195195
detections = scale_detections(detections, (resize, resize), (im0.shape[1], im0.shape[0]))
196196
logger.debug(f"Inference time: {t2 - t1:.3f} seconds")
197-
im = None
198-
if annotate:
199-
im = self._annotate(im0, detections)
200197

201-
out = sv_to_focoos_detections(detections, classes=self.metadata.classes)
198+
out = sv_to_fai_detections(detections, classes=self.metadata.classes)
202199
t3 = perf_counter()
203-
out.latency = {
200+
latency = {
204201
"inference": round(t2 - t1, 3),
205202
"preprocess": round(t1 - t0, 3),
206203
"postprocess": round(t3 - t2, 3),
207204
}
208-
return out, im
205+
im = None
206+
if annotate:
207+
im = self._annotate(im0, detections)
208+
return FocoosDetections(detections=out, latency=latency), im
209209

210210
def benchmark(self, iterations: int, size: int) -> LatencyMetrics:
211211
"""

focoos/ports.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -443,12 +443,26 @@ class FocoosDet(FocoosBaseModel):
443443
```
444444
"""
445445

446-
bbox: Optional[list[float]] = None
446+
bbox: Optional[list[int]] = None
447447
conf: Optional[float] = None
448448
cls_id: Optional[int] = None
449449
label: Optional[str] = None
450450
mask: Optional[str] = None
451451

452+
@classmethod
453+
def from_json(cls, data: Union[str, dict]):
454+
if isinstance(data, str):
455+
with open(data, encoding="utf-8") as f:
456+
data_dict = json.load(f)
457+
else:
458+
data_dict = data
459+
460+
bbox = data_dict.get("bbox")
461+
if bbox is not None: # Retrocompatibility fix for remote results with float bbox, !TODO remove asap
462+
data_dict["bbox"] = list(map(int, bbox))
463+
464+
return cls.model_validate(data_dict)
465+
452466

453467
class FocoosDetections(FocoosBaseModel):
454468
"""Collection of detection results from a model.

focoos/remote_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
from focoos.utils.logger import get_logger
4949
from focoos.utils.metrics import MetricsVisualizer
5050
from focoos.utils.system import HttpClient
51-
from focoos.utils.vision import focoos_detections_to_supervision, image_loader
51+
from focoos.utils.vision import fai_detections_to_sv, image_loader
5252

5353
logger = get_logger()
5454

@@ -299,7 +299,7 @@ def infer(
299299
preview = None
300300
if annotate:
301301
im0 = image_loader(image)
302-
sv_detections = focoos_detections_to_supervision(detections)
302+
sv_detections = fai_detections_to_sv(detections, im0.shape[:-1])
303303
preview = self._annotate(im0, sv_detections)
304304
return detections, preview
305305
else:

focoos/runtime.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424

2525
import numpy as np
2626

27+
from focoos.utils.vision import mask_to_xyxy
28+
2729
try:
2830
import torch
2931

@@ -40,6 +42,7 @@
4042

4143
import supervision as sv
4244

45+
# from supervision.detection.utils import mask_to_xyxy
4346
from focoos.ports import (
4447
FocoosTask,
4548
LatencyMetrics,
@@ -108,10 +111,10 @@ def semseg_postprocess(out: List[np.ndarray], im0_shape: Tuple[int, int], conf_t
108111
masks = masks[high_conf_indices].astype(bool)
109112
cls_ids = cls_ids[high_conf_indices].astype(int)
110113
confs = confs[high_conf_indices].astype(float)
114+
xyxy = mask_to_xyxy(masks)
111115
return sv.Detections(
112116
mask=masks,
113-
# xyxy is required from supervision
114-
xyxy=np.zeros(shape=(len(high_conf_indices), 4), dtype=np.uint8),
117+
xyxy=xyxy,
115118
class_id=cls_ids,
116119
confidence=confs,
117120
)
@@ -128,10 +131,10 @@ def instance_postprocess(out: List[np.ndarray], im0_shape: Tuple[int, int], conf
128131
masks = mask[high_conf_indices].astype(bool)
129132
cls_ids = cls_ids[high_conf_indices].astype(int)
130133
confs = confs[high_conf_indices].astype(float)
134+
xyxy = mask_to_xyxy(masks)
131135
return sv.Detections(
132136
mask=masks,
133-
# xyxy is required from supervision
134-
xyxy=np.zeros(shape=(len(high_conf_indices), 4), dtype=np.uint8),
137+
xyxy=xyxy,
135138
class_id=cls_ids,
136139
confidence=confs,
137140
)

focoos/utils/vision.py

Lines changed: 72 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import base64
2-
import io
32
from pathlib import Path
4-
from typing import Optional, Tuple, Union
3+
from typing import List, Optional, Tuple, Union
54

65
import cv2
76
import numpy as np
@@ -129,22 +128,40 @@ def scale_detections(detections: sv.Detections, in_shape: tuple, out_shape: tupl
129128

130129

131130
def base64mask_to_mask(base64mask: str) -> np.ndarray:
132-
return np.array(Image.open(io.BytesIO(base64.b64decode(base64mask))))
131+
"""
132+
Convert a base64-encoded mask to a binary mask using OpenCV.
133+
134+
Args:
135+
base64mask (str): Base64-encoded string representing the mask.
133136
137+
Returns:
138+
np.ndarray: Decoded binary mask as a NumPy array.
139+
"""
140+
# Decode the base64 string to bytes and convert to a NumPy array in one step
141+
np_arr = np.frombuffer(base64.b64decode(base64mask), np.uint8)
142+
# Decode the NumPy array to an image using OpenCV and convert to a binary mask in one step
143+
binary_mask = cv2.imdecode(np_arr, cv2.IMREAD_GRAYSCALE) > 0
144+
return binary_mask.astype(bool)
134145

135-
def focoos_detections_to_supervision(
136-
inference_output: FocoosDetections,
137-
) -> sv.Detections:
146+
147+
def fai_detections_to_sv(inference_output: FocoosDetections, im0_shape: tuple) -> sv.Detections:
138148
xyxy = np.array([d.bbox if d.bbox is not None else np.empty(4) for d in inference_output.detections])
139149
class_id = np.array([d.cls_id for d in inference_output.detections])
140150
confidence = np.array([d.conf for d in inference_output.detections])
141151
if xyxy.shape[0] == 0:
142152
xyxy = np.empty((0, 4))
143153
_masks = []
144-
for det in inference_output.detections:
145-
if det.mask:
146-
mask = base64mask_to_mask(det.mask)
147-
_masks.append(mask)
154+
if len(inference_output.detections) > 0 and inference_output.detections[0].mask:
155+
_masks = [np.zeros(im0_shape, dtype=bool) for _ in inference_output.detections]
156+
for i, det in enumerate(inference_output.detections):
157+
if det.mask:
158+
mask = base64mask_to_mask(det.mask)
159+
if det.bbox is not None and not np.array_equal(det.bbox, [0, 0, 0, 0]):
160+
x1, y1, x2, y2 = map(int, det.bbox)
161+
y2, x2 = min(y2, _masks[i].shape[0]), min(x2, _masks[i].shape[1])
162+
_masks[i][y1:y2, x1:x2] = mask[: y2 - y1, : x2 - x1]
163+
else:
164+
_masks[i] = mask
148165
masks = np.array(_masks).astype(bool) if len(_masks) > 0 else None
149166
return sv.Detections(
150167
xyxy=xyxy,
@@ -156,7 +173,7 @@ def focoos_detections_to_supervision(
156173

157174
def binary_mask_to_base64(binary_mask: np.ndarray) -> str:
158175
"""
159-
Converts a binary mask (NumPy array) to a base64-encoded PNG image.
176+
Converts a binary mask (NumPy array) to a base64-encoded PNG image using OpenCV.
160177
161178
This function takes a binary mask, where values of `True` represent the areas of interest (usually 1s)
162179
and `False` represents the background (usually 0s). The binary mask is then converted to an image,
@@ -168,23 +185,19 @@ def binary_mask_to_base64(binary_mask: np.ndarray) -> str:
168185
Returns:
169186
str: A base64-encoded string representing the PNG image of the binary mask.
170187
"""
171-
# Convert the binary mask to uint8 type, then multiply by 255 to set True values to 255 (white)
172-
# and False values to 0 (black).
173-
binary_mask = binary_mask.astype(np.uint8) * 255
174-
175-
# Create a PIL image from the NumPy array
176-
binary_mask_image = Image.fromarray(binary_mask)
188+
# Scale the mask by 255 and cast to uint8 (True -> 255, False -> 0) so it can be encoded as an 8-bit image
189+
binary_mask = (binary_mask * 255).astype(np.uint8)
177190

178-
# Save the image to an in-memory buffer as PNG
179-
with io.BytesIO() as buffer:
180-
binary_mask_image.save(buffer, bitmap_format="png", format="PNG")
181-
# Get the PNG image in binary form and encode it to base64
182-
encoded_png = base64.b64encode(buffer.getvalue()).decode("utf-8")
191+
# Use OpenCV to encode the image as PNG
192+
success, encoded_image = cv2.imencode(".png", binary_mask)
193+
if not success:
194+
raise ValueError("Failed to encode image")
183195

184-
return encoded_png
196+
# Encode the image to base64
197+
return base64.b64encode(encoded_image).decode("utf-8")
185198

186199

187-
def sv_to_focoos_detections(detections: sv.Detections, classes: Optional[list[str]] = None) -> FocoosDetections:
200+
def sv_to_fai_detections(detections: sv.Detections, classes: Optional[list[str]] = None) -> List[FocoosDet]:
188201
"""
189202
Convert a list of detections from the supervision format to Focoos detection format.
190203
@@ -213,12 +226,44 @@ def sv_to_focoos_detections(detections: sv.Detections, classes: Optional[list[st
213226
"""
214227
res = []
215228
for xyxy, mask, conf, cls_id, _, _ in detections:
229+
if mask is not None:
230+
cropped_mask = mask[int(xyxy[1]) : int(xyxy[3]), int(xyxy[0]) : int(xyxy[2])]
231+
mask = binary_mask_to_base64(cropped_mask)
216232
det = FocoosDet(
217233
cls_id=int(cls_id) if cls_id is not None else None,
218-
bbox=[round(float(x), 2) for x in xyxy],
219-
mask=binary_mask_to_base64(mask) if mask is not None else None,
234+
bbox=[int(x) for x in xyxy],
235+
mask=mask,
220236
conf=round(float(conf), 2) if conf is not None else None,
221237
label=(classes[cls_id] if classes is not None and cls_id is not None else None),
222238
)
223239
res.append(det)
224-
return FocoosDetections(detections=res)
240+
return res
241+
242+
243+
def mask_to_xyxy(masks: np.ndarray) -> np.ndarray:
244+
"""
245+
Converts a 3D `np.array` of 2D bool masks into a 2D `np.array` of bounding boxes.
246+
247+
Parameters:
248+
masks (np.ndarray): A 3D `np.array` of shape `(N, H, W)`
249+
containing 2D bool masks
250+
251+
Returns:
252+
np.ndarray: A 2D `np.array` of shape `(N, 4)` containing the bounding boxes
253+
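`(x_min, y_min, x_max, y_max)` for each mask; the max coordinates are inclusive pixel indices, and a mask with no True pixels yields `[0, 0, 0, 0]`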
254+
"""
255+
# Compute one bounding box per mask: loop over the masks, reducing each one with NumPy
256+
n = masks.shape[0]
257+
xyxy = np.zeros((n, 4), dtype=int)
258+
259+
# Use np.any to quickly find rows and columns with True values
260+
for i, mask in enumerate(masks):
261+
rows = np.any(mask, axis=1)
262+
cols = np.any(mask, axis=0)
263+
264+
if np.any(rows) and np.any(cols):
265+
y_min, y_max = np.where(rows)[0][[0, -1]]
266+
x_min, x_max = np.where(cols)[0][[0, -1]]
267+
xyxy[i, :] = [x_min, y_min, x_max, y_max]
268+
269+
return xyxy

tests/test_local_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -188,8 +188,8 @@ def mock_infer_setup(
188188
mock_scale_detections.return_value = mock_sv_detections
189189

190190
# Mock sv_to_fai_detections
191-
mock_sv_to_focoos_detections = mocker.patch("focoos.local_model.sv_to_focoos_detections")
192-
mock_sv_to_focoos_detections.return_value = mock_focoos_detections
191+
mock_sv_to_focoos_detections = mocker.patch("focoos.local_model.sv_to_fai_detections")
192+
mock_sv_to_focoos_detections.return_value = mock_focoos_detections.detections
193193

194194
# Mock _annotate
195195
mock_annotate = mocker.patch.object(mock_local_model, "_annotate", autospec=True)
@@ -216,7 +216,7 @@ def __call__(self, *args, **kwargs):
216216

217217

218218
@pytest.mark.parametrize("annotate", [(False, None)])
219-
def test_infer_(
219+
def test_infer_onnx(
220220
mocker,
221221
mock_local_model_onnx,
222222
image_ndarray,

0 commit comments

Comments
 (0)