Skip to content

Commit 4123f5a

Browse files
Merge pull request #75 from FocoosAI/feat/add-instseg-postprocess
feat: add instance segmentation postprocessing support
2 parents b3316a9 + e2d09a5 commit 4123f5a

File tree

2 files changed

+35
-3
lines changed

2 files changed

+35
-3
lines changed

focoos/runtime.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,15 @@
5656
logger = get_logger()
5757

5858

59+
def get_postprocess_fn(task: FocoosTask):
    """
    Select the postprocessing function that matches a model task.

    Args:
        task: The task the model was trained for.

    Returns:
        ``instance_postprocess`` for ``FocoosTask.INSTANCE_SEGMENTATION``,
        ``semseg_postprocess`` for ``FocoosTask.SEMSEG``, and
        ``det_postprocess`` for every other task value.
    """
    if task == FocoosTask.INSTANCE_SEGMENTATION:
        return instance_postprocess
    if task == FocoosTask.SEMSEG:
        return semseg_postprocess
    # Any task that is not a segmentation variant falls back to detection.
    return det_postprocess
66+
67+
5968
def det_postprocess(out: List[np.ndarray], im0_shape: Tuple[int, int], conf_threshold: float) -> sv.Detections:
6069
"""
6170
Postprocesses the output of an object detection model and filters detections
@@ -108,6 +117,26 @@ def semseg_postprocess(out: List[np.ndarray], im0_shape: Tuple[int, int], conf_t
108117
)
109118

110119

120+
def instance_postprocess(out: List[np.ndarray], im0_shape: Tuple[int, int], conf_threshold: float) -> sv.Detections:
    """
    Postprocesses the output of an instance segmentation model, keeping only the
    instances whose confidence is strictly greater than ``conf_threshold``.

    Args:
        out: Raw model outputs. The first entry of ``out[0]``, ``out[1]`` and
            ``out[2]`` is read as class ids, masks and confidences respectively
            (presumably index 0 selects the single image of a batch — confirm
            against the exported model).
        im0_shape: Not used by this function; kept so that every postprocess
            function shares the same signature.
        conf_threshold: Confidence value an instance must exceed to be kept.

    Returns:
        An ``sv.Detections`` holding boolean masks, integer class ids, float
        confidences and an all-zero ``xyxy`` placeholder with one row per kept
        instance.
    """
    class_ids = out[0][0]
    raw_masks = out[1][0]
    scores = out[2][0]

    # Indices (along the first axis) of the instances that pass the threshold.
    keep = np.nonzero(scores > conf_threshold)[0]

    kept_masks = raw_masks[keep].astype(bool)
    kept_class_ids = class_ids[keep].astype(int)
    kept_scores = scores[keep].astype(float)

    # supervision requires xyxy, so supply a zero-filled placeholder because
    # no boxes are computed here.
    placeholder_boxes = np.zeros(shape=(len(keep), 4), dtype=np.uint8)

    return sv.Detections(
        mask=kept_masks,
        xyxy=placeholder_boxes,
        class_id=kept_class_ids,
        confidence=kept_scores,
    )
138+
139+
111140
class BaseRuntime:
112141
def __init__(self, model_path: str, opts: Any, model_metadata: ModelMetadata):
113142
pass
@@ -136,7 +165,8 @@ def __init__(self, model_path: str, opts: OnnxRuntimeOpts, model_metadata: Model
136165
self.name = Path(model_path).stem
137166
self.opts = opts
138167
self.model_metadata = model_metadata
139-
self.postprocess_fn = det_postprocess if model_metadata.task == FocoosTask.DETECTION else semseg_postprocess
168+
169+
self.postprocess_fn = get_postprocess_fn(model_metadata.task)
140170

141171
# Setup session options
142172
options = ort.SessionOptions()
@@ -264,7 +294,7 @@ def __init__(
264294
self.logger = get_logger(name="TorchscriptEngine")
265295
self.logger.info(f"🔧 [torchscript] Device: {self.device}")
266296
self.opts = opts
267-
self.postprocess_fn = det_postprocess if model_metadata.task == FocoosTask.DETECTION else semseg_postprocess
297+
self.postprocess_fn = get_postprocess_fn(model_metadata.task)
268298

269299
map_location = None if torch.cuda.is_available() else "cpu"
270300

focoos/utils/vision.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,6 @@ def scale_mask(mask: np.ndarray, target_shape: tuple) -> np.ndarray:
109109
"""
110110
# Calculate scale factors for height and width
111111
scale_factors = (target_shape[0] / mask.shape[0], target_shape[1] / mask.shape[1])
112-
113112
# Resize the mask using zoom with nearest-neighbor interpolation (order=0)
114113
scaled_mask = zoom(mask, scale_factors, order=0) > 0.5
115114

@@ -123,6 +122,9 @@ def scale_detections(detections: sv.Detections, in_shape: tuple, out_shape: tupl
123122
x_ratio = out_shape[0] / in_shape[0]
124123
y_ratio = out_shape[1] / in_shape[1]
125124
detections.xyxy = detections.xyxy * np.array([x_ratio, y_ratio, x_ratio, y_ratio])
125+
126+
if detections.mask is not None:
127+
detections.mask = np.array([scale_mask(m, out_shape) for m in detections.mask])
126128
return detections
127129

128130

0 commit comments

Comments
 (0)