Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 4 additions & 15 deletions kauldron/summaries/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ def get_state(
return self.State(images=images)


# TODO(klausg): The use of rearrange is weird here. maybe move to contrib?
@dataclasses.dataclass(kw_only=True, frozen=True)
class ShowBoxes(metrics.Metric):
"""Show a set of boxes with optional image reshaping.
Expand All @@ -108,9 +107,6 @@ class ShowBoxes(metrics.Metric):
num_colors: Number of different colors to use for the boxes. Default 16.
in_vrange: Optional value range of the input images. Used to clip and then
rescale the images to [0, 1].
rearrange: Optional einops string to reshape the images AFTER the boxes have
been drawn.
rearrange_kwargs: Optional keyword arguments for the einops reshape.
"""

images: kontext.Key = kontext.REQUIRED
Expand All @@ -121,22 +117,19 @@ class ShowBoxes(metrics.Metric):
num_colors: int = 16
in_vrange: Optional[tuple[float, float]] = None

rearrange: Optional[str] = None
rearrange_kwargs: Mapping[str, Any] | None = None

@struct.dataclass
class State(metrics.AutoState["ShowBoxes"]):
"""Collects the first num_images images and boxes."""

images: Float["n h w #3"] = metrics.truncate_field(
images: Float["*b h w #3"] = metrics.truncate_field(
num_field="parent.num_images"
)
boxes: Float["n k 4"] = metrics.truncate_field(
boxes: Float["*b k 4"] = metrics.truncate_field(
num_field="parent.num_images"
)

@typechecked
def compute(self) -> Float["n h w #3"]:
def compute(self) -> Float["*b h w #3"]:
data = super().compute()
images, boxes = data.images, data.boxes

Expand All @@ -152,11 +145,7 @@ def compute(self) -> Float["n h w #3"]:
colors = _get_uniform_colors(self.parent.num_colors)
images = tf.image.draw_bounding_boxes(images, boxes, colors)

# Note: rearrange is applied AFTER the boxes are drawn.
images = np.reshape(images, images_shape)
images = _maybe_rearrange(
images, self.parent.rearrange, self.parent.rearrange_kwargs
)

# always clip to avoid display problems in TB and Datatables
return np.clip(images, 0.0, 1.0)
Expand Down Expand Up @@ -338,4 +327,4 @@ def _get_uniform_colors(n_colors: int) -> Array:
(np.expand_dims(hues, axis=1), np.ones((n_colors, 2))), axis=1
)
rgb_colors = matplotlib.colors.hsv_to_rgb(hsv_colors)
return rgb_colors # rgb_colors.shape = (n_colors, 3)
return rgb_colors # rgb_colors.shape = (n_colors, 3)