Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/python/model_api/visualizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
# SPDX-License-Identifier: Apache-2.0

from .layout import Flatten, HStack, Layout
from .primitive import BoundingBox, Overlay, Polygon
from .primitive import BoundingBox, Label, Overlay, Polygon
from .scene import Scene
from .visualizer import Visualizer

__all__ = ["BoundingBox", "Overlay", "Polygon", "Scene", "Visualizer", "Layout", "Flatten", "HStack"]
__all__ = ["BoundingBox", "Label", "Overlay", "Polygon", "Scene", "Visualizer", "Layout", "Flatten", "HStack"]
20 changes: 14 additions & 6 deletions src/python/model_api/visualizer/layout/flatten.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Type, Union
from typing import TYPE_CHECKING, Type, Union, cast

from model_api.visualizer.primitive import Label

from .layout import Layout

Expand All @@ -29,13 +31,19 @@ def __init__(self, *args: Union[Type[Primitive], Layout]) -> None:
def _compute_on_primitive(self, primitive: Type[Primitive], image: PIL.Image, scene: Scene) -> PIL.Image | None:
if scene.has_primitives(primitive):
primitives = scene.get_primitives(primitive)
for _primitive in primitives:
image = _primitive.compute(image)
if primitive == Label: # Labels need to be rendered next to each other
# cast is needed as mypy does not know that the primitives are of type Label.
primitives_ = cast("list[Label]", primitives)
image = Label.overlay_labels(image, primitives_)
else:
# Other primitives are rendered on top of each other
for _primitive in primitives:
image = _primitive.compute(image)
return image
return None

def __call__(self, scene: Scene) -> PIL.Image:
_image: PIL.Image = scene.base.copy()
image_: PIL.Image = scene.base.copy()
for child in self.children:
_image = child(scene) if isinstance(child, Layout) else self._compute_on_primitive(child, _image, scene)
return _image
image_ = child(scene) if isinstance(child, Layout) else self._compute_on_primitive(child, image_, scene)
return image_
12 changes: 6 additions & 6 deletions src/python/model_api/visualizer/layout/hstack.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def _compute_on_primitive(self, primitive: Type[Primitive], image: PIL.Image, sc
if scene.has_primitives(primitive):
images = []
for _primitive in scene.get_primitives(primitive):
_image = _primitive.compute(image.copy())
images.append(_image)
image_ = _primitive.compute(image.copy())
images.append(image_)
return self._stitch(*images)
return None

Expand Down Expand Up @@ -70,9 +70,9 @@ def __call__(self, scene: Scene) -> PIL.Image:
images: list[PIL.Image] = []
for child in self.children:
if isinstance(child, Layout):
_image = child(scene)
image_ = child(scene)
else:
_image = self._compute_on_primitive(child, scene.base.copy(), scene)
if _image is not None:
images.append(_image)
image_ = self._compute_on_primitive(child, scene.base.copy(), scene)
if image_ is not None:
images.append(image_)
return self._stitch(*images)
3 changes: 2 additions & 1 deletion src/python/model_api/visualizer/primitive/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
# SPDX-License-Identifier: Apache-2.0

from .bounding_box import BoundingBox
from .label import Label
from .overlay import Overlay
from .polygon import Polygon
from .primitive import Primitive

__all__ = ["Primitive", "BoundingBox", "Overlay", "Polygon"]
__all__ = ["Primitive", "BoundingBox", "Label", "Overlay", "Polygon"]
106 changes: 106 additions & 0 deletions src/python/model_api/visualizer/primitive/label.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
"""Label primitive."""

# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from io import BytesIO
from typing import Union

from PIL import Image, ImageDraw, ImageFont

from .primitive import Primitive


class Label(Primitive):
"""Label primitive.

Labels require a different processing than other primitives as the class also handles the instance when the layout
requests all the labels to be drawn on a single image.

Args:
label (str): Text of the label.
fg_color (str | tuple[int, int, int]): Foreground color of the label.
bg_color (str | tuple[int, int, int]): Background color of the label.
font_path (str | None | BytesIO): Path to the font file.
size (int): Size of the font.

Examples:
>>> label = Label(label="Label 1")
>>> label.compute(image).save("label.jpg")

>>> label = Label(text="Label 1", fg_color="red", bg_color="blue", font_path="arial.ttf", size=20)
>>> label.compute(image).save("label.jpg")

or multiple labels on a single image:
>>> label1 = Label(text="Label 1")
>>> label2 = Label(text="Label 2")
>>> label3 = Label(text="Label 3")
>>> Label.overlay_labels(image, [label1, label2, label3]).save("labels.jpg")
"""

def __init__(
self,
label: str,
fg_color: Union[str, tuple[int, int, int]] = "black",
bg_color: Union[str, tuple[int, int, int]] = "yellow",
font_path: Union[str, BytesIO, None] = None,
size: int = 16,
) -> None:
self.label = label
self.fg_color = fg_color
self.bg_color = bg_color
self.font = ImageFont.load_default(size=size) if font_path is None else ImageFont.truetype(font_path, size)

def compute(self, image: Image, buffer_y: int = 5) -> Image:
"""Generate label on top of the image.

Args:
image (PIL.Image): Image to paste the label on.
buffer_y (int): Buffer to add to the y-axis of the label.
"""
label_image = self.generate_label_image(buffer_y)
image.paste(label_image, (0, 0))
return image

def generate_label_image(self, buffer_y: int = 5) -> Image:
"""Generate label image.

Args:
buffer_y (int): Buffer to add to the y-axis of the label. This is needed as the text is clipped from the
bottom otherwise.

Returns:
PIL.Image: Image that consists only of the label.
"""
dummy_image = Image.new("RGB", (1, 1))
draw = ImageDraw.Draw(dummy_image)
textbox = draw.textbbox((0, 0), self.label, font=self.font)
label_image = Image.new("RGB", (textbox[2] - textbox[0], textbox[3] + buffer_y - textbox[1]), self.bg_color)
draw = ImageDraw.Draw(label_image)
draw.text((0, 0), self.label, font=self.font, fill=self.fg_color)
return label_image

@classmethod
def overlay_labels(cls, image: Image, labels: list["Label"], buffer_y: int = 5, buffer_x: int = 5) -> Image:
"""Overlay multiple label images on top of the image.
Paste the labels in a row but wrap the labels if they exceed the image width.

Args:
image (PIL.Image): Image to paste the labels on.
labels (list[Label]): Labels to be pasted on the image.
buffer_y (int): Buffer to add to the y-axis of the labels.
buffer_x (int): Space between the labels.

Returns:
PIL.Image: Image with the labels pasted on it.
"""
offset_x = 0
offset_y = 0
for label in labels:
label_image = label.generate_label_image(buffer_y)
image.paste(label_image, (offset_x, offset_y))
offset_x += label_image.width + buffer_x
if offset_x + label_image.width > image.width:
offset_x = 0
offset_y += label_image.height
return image
60 changes: 54 additions & 6 deletions src/python/model_api/visualizer/scene/scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@

from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast

import numpy as np
from PIL import Image

from model_api.visualizer.primitive import Overlay, Primitive
from model_api.visualizer.primitive import BoundingBox, Label, Overlay, Polygon, Primitive

if TYPE_CHECKING:
from pathlib import Path
Expand All @@ -27,16 +27,23 @@ class Scene:
def __init__(
self,
base: Image,
bounding_box: BoundingBox | list[BoundingBox] | None = None,
label: Label | list[Label] | None = None,
overlay: Overlay | list[Overlay] | np.ndarray | None = None,
polygon: Polygon | list[Polygon] | None = None,
layout: Layout | None = None,
) -> None:
self.base = base
self.overlay = self._to_overlay(overlay)
self.bounding_box = self._to_bounding_box(bounding_box)
self.label = self._to_label(label)
self.polygon = self._to_polygon(polygon)
self.layout = layout

def show(self) -> Image: ...

def save(self, path: Path) -> None: ...
def save(self, path: Path) -> None:
self.render().save(path)

def render(self) -> Image:
if self.layout is None:
Expand All @@ -46,16 +53,42 @@ def render(self) -> Image:
def has_primitives(self, primitive: type[Primitive]) -> bool:
if primitive == Overlay:
return bool(self.overlay)
if primitive == BoundingBox:
return bool(self.bounding_box)
if primitive == Label:
return bool(self.label)
if primitive == Polygon:
return bool(self.polygon)
return False

def get_primitives(self, primitive: type[Primitive]) -> list[Primitive]:
"""Get primitives of the given type.

Args:
primitive (type[Primitive]): The type of primitive to get.

Example:
>>> scene = Scene(base=Image.new("RGB", (100, 100)), overlay=[Overlay(Image.new("RGB", (100, 100)))])
>>> scene.get_primitives(Overlay)
[Overlay(image=Image.new("RGB", (100, 100)))]

Returns:
list[Primitive]: The primitives of the given type or an empty list if no primitives are found.
"""
primitives: list[Primitive] | None = None
# cast is needed as mypy does not know that the primitives are a subclass of Primitive.
if primitive == Overlay:
primitives = self.overlay # type: ignore[assignment] # TODO(ashwinvaidya17): Address this in the next PR
if primitives is None:
primitives = cast("list[Primitive]", self.overlay)
elif primitive == BoundingBox:
primitives = cast("list[Primitive]", self.bounding_box)
elif primitive == Label:
primitives = cast("list[Primitive]", self.label)
elif primitive == Polygon:
primitives = cast("list[Primitive]", self.polygon)
else:
msg = f"Primitive {primitive} not found"
raise ValueError(msg)
return primitives
return primitives or []

@property
def default_layout(self) -> Layout:
Expand All @@ -70,3 +103,18 @@ def _to_overlay(self, overlay: Overlay | list[Overlay] | np.ndarray | None) -> l
if isinstance(overlay, Overlay):
return [overlay]
return overlay

def _to_bounding_box(self, bounding_box: BoundingBox | list[BoundingBox] | None) -> list[BoundingBox] | None:
if isinstance(bounding_box, BoundingBox):
return [bounding_box]
return bounding_box

def _to_label(self, label: Label | list[Label] | None) -> list[Label] | None:
if isinstance(label, Label):
return [label]
return label

def _to_polygon(self, polygon: Polygon | list[Polygon] | None) -> list[Polygon] | None:
if isinstance(polygon, Polygon):
return [polygon]
return polygon
8 changes: 7 additions & 1 deletion tests/python/unit/visualizer/test_primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import PIL
from PIL import ImageDraw

from model_api.visualizer import BoundingBox, Overlay, Polygon
from model_api.visualizer import BoundingBox, Label, Overlay, Polygon


def test_overlay(mock_image: PIL.Image):
Expand Down Expand Up @@ -51,3 +51,9 @@ def test_polygon(mock_image: PIL.Image):
draw.polygon([(10, 10), (100, 10), (100, 100), (10, 100)], fill="red")
polygon = Polygon(mask=mask, color="red")
assert polygon.compute(mock_image) == expected_image


def test_label(mock_image: PIL.Image):
label = Label(label="Label")
# When using a single label, compute and overlay_labels should return the same image
assert label.compute(mock_image) == Label.overlay_labels(mock_image, [label])
Loading