Skip to content

Commit ddda5e6

Browse files
refactor label scores
Signed-off-by: Ashwin Vaidya <[email protected]>
1 parent d75be73 commit ddda5e6

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

src/python/model_api/visualizer/primitive/label.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ class Label(Primitive):
1919
2020
Args:
2121
label (str): Text of the label.
22+
score (float | None): Score of the label. This is optional.
2223
fg_color (str | tuple[int, int, int]): Foreground color of the label.
2324
bg_color (str | tuple[int, int, int]): Background color of the label.
2425
font_path (str | None | BytesIO): Path to the font file.
@@ -40,13 +41,14 @@ class Label(Primitive):
4041

4142
def __init__(
4243
self,
43-
label: Union[str, float],
44+
label: str,
45+
score: Union[float, None] = None,
4446
fg_color: Union[str, tuple[int, int, int]] = "black",
4547
bg_color: Union[str, tuple[int, int, int]] = "yellow",
4648
font_path: Union[str, BytesIO, None] = None,
4749
size: int = 16,
4850
) -> None:
49-
self.label = str(label)
51+
self.label = f"{label} ({score:.2f})" if score is not None else label
5052
self.fg_color = fg_color
5153
self.bg_color = bg_color
5254
self.font = ImageFont.load_default(size=size) if font_path is None else ImageFont.truetype(font_path, size)

src/python/model_api/visualizer/scene/anomaly.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,8 @@ def _get_bounding_boxes(self, result: AnomalyResult) -> list[BoundingBox]:
4242

4343
def _get_labels(self, result: AnomalyResult) -> list[Label]:
4444
labels = []
45-
if result.pred_label is not None:
46-
labels.append(Label(result.pred_label))
47-
if result.pred_score is not None:
48-
labels.append(Label(result.pred_score))
45+
if result.pred_label is not None and result.pred_score is not None:
46+
labels.append(Label(label=result.pred_label, score=result.pred_score))
4947
return labels
5048

5149
def _get_polygons(self, result: AnomalyResult) -> list[Polygon]:

0 commit comments

Comments
 (0)