|
56 | 56 | logger = get_logger() |
57 | 57 |
|
58 | 58 |
|
| 59 | +def get_postprocess_fn(task: FocoosTask): |
| 60 | + if task == FocoosTask.INSTANCE_SEGMENTATION: |
| 61 | + return instance_postprocess |
| 62 | + elif task == FocoosTask.SEMSEG: |
| 63 | + return semseg_postprocess |
| 64 | + else: |
| 65 | + return det_postprocess |
| 66 | + |
| 67 | + |
59 | 68 | def det_postprocess(out: List[np.ndarray], im0_shape: Tuple[int, int], conf_threshold: float) -> sv.Detections: |
60 | 69 | """ |
61 | 70 | Postprocesses the output of an object detection model and filters detections |
@@ -108,6 +117,26 @@ def semseg_postprocess(out: List[np.ndarray], im0_shape: Tuple[int, int], conf_t |
108 | 117 | ) |
109 | 118 |
|
110 | 119 |
|
| 120 | +def instance_postprocess(out: List[np.ndarray], im0_shape: Tuple[int, int], conf_threshold: float) -> sv.Detections: |
| 121 | + """ |
| 122 | + Postprocesses the output of an instance segmentation model and filters detections |
| 123 | + based on a confidence threshold. |
| 124 | + """ |
| 125 | + cls_ids, mask, confs = out[0][0], out[1][0], out[2][0] |
| 126 | + high_conf_indices = np.where(confs > conf_threshold)[0] |
| 127 | + |
| 128 | + masks = mask[high_conf_indices].astype(bool) |
| 129 | + cls_ids = cls_ids[high_conf_indices].astype(int) |
| 130 | + confs = confs[high_conf_indices].astype(float) |
| 131 | + return sv.Detections( |
| 132 | + mask=masks, |
| 133 | + # xyxy is required from supervision |
| 134 | + xyxy=np.zeros(shape=(len(high_conf_indices), 4), dtype=np.uint8), |
| 135 | + class_id=cls_ids, |
| 136 | + confidence=confs, |
| 137 | + ) |
| 138 | + |
| 139 | + |
111 | 140 | class BaseRuntime: |
112 | 141 | def __init__(self, model_path: str, opts: Any, model_metadata: ModelMetadata): |
113 | 142 | pass |
@@ -136,7 +165,8 @@ def __init__(self, model_path: str, opts: OnnxRuntimeOpts, model_metadata: Model |
136 | 165 | self.name = Path(model_path).stem |
137 | 166 | self.opts = opts |
138 | 167 | self.model_metadata = model_metadata |
139 | | - self.postprocess_fn = det_postprocess if model_metadata.task == FocoosTask.DETECTION else semseg_postprocess |
| 168 | + |
| 169 | + self.postprocess_fn = get_postprocess_fn(model_metadata.task) |
140 | 170 |
|
141 | 171 | # Setup session options |
142 | 172 | options = ort.SessionOptions() |
@@ -264,7 +294,7 @@ def __init__( |
264 | 294 | self.logger = get_logger(name="TorchscriptEngine") |
265 | 295 | self.logger.info(f"🔧 [torchscript] Device: {self.device}") |
266 | 296 | self.opts = opts |
267 | | - self.postprocess_fn = det_postprocess if model_metadata.task == FocoosTask.DETECTION else semseg_postprocess |
| 297 | + self.postprocess_fn = get_postprocess_fn(model_metadata.task) |
268 | 298 |
|
269 | 299 | map_location = None if torch.cuda.is_available() else "cpu" |
270 | 300 |
|
|
0 commit comments