Skip to content

Commit 8a0a371

Browse files
feat(nodes): add GetMaskBoundingBoxInvocation
1 parent 7dbd5f1 commit 8a0a371

File tree

1 file changed

+60
-3
lines changed

1 file changed

+60
-3
lines changed

invokeai/app/invocations/mask.py

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,22 @@
22
import torch
33
from PIL import Image
44

5-
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, InvocationContext, invocation
6-
from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithBoard, WithMetadata
7-
from invokeai.app.invocations.primitives import ImageOutput, MaskOutput
5+
from invokeai.app.invocations.baseinvocation import (
6+
BaseInvocation,
7+
Classification,
8+
InvocationContext,
9+
invocation,
10+
)
11+
from invokeai.app.invocations.fields import (
12+
BoundingBoxField,
13+
ColorField,
14+
ImageField,
15+
InputField,
16+
TensorField,
17+
WithBoard,
18+
WithMetadata,
19+
)
20+
from invokeai.app.invocations.primitives import BoundingBoxOutput, ImageOutput, MaskOutput
821
from invokeai.backend.image_util.util import pil_to_np
922

1023

@@ -201,3 +214,47 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
201214
image_dto = context.images.save(image=masked_image)
202215

203216
return ImageOutput.build(image_dto)
217+
218+
219+
WHITE = ColorField(r=255, g=255, b=255, a=255)
220+
221+
222+
@invocation(
223+
"get_image_mask_bounding_box",
224+
title="Get Image Mask Bounding Box",
225+
tags=["mask"],
226+
category="mask",
227+
version="1.0.0",
228+
)
229+
class GetMaskBoundingBoxInvocation(BaseInvocation, WithMetadata):
230+
"""Gets the bounding box of the given mask image."""
231+
232+
mask: ImageField = InputField(description="The mask to crop.")
233+
margin: int = InputField(default=0, description="Margin to add to the bounding box.")
234+
mask_color: ColorField = InputField(default=WHITE, description="Color of the mask in the image.")
235+
236+
def invoke(self, context: InvocationContext) -> BoundingBoxOutput:
237+
mask = context.images.get_pil(self.mask.image_name, mode="RGBA")
238+
mask_np = np.array(mask)
239+
240+
# Convert mask_color to RGBA tuple
241+
mask_color_rgb = self.mask_color.tuple()
242+
243+
# Find the bounding box of the mask color
244+
y, x = np.where(np.all(mask_np == mask_color_rgb, axis=-1))
245+
246+
if len(x) == 0 or len(y) == 0:
247+
# No pixels found with the given color
248+
return BoundingBoxOutput(bounding_box=BoundingBoxField(x_min=0, y_min=0, x_max=0, y_max=0))
249+
250+
left, upper, right, lower = x.min(), y.min(), x.max(), y.max()
251+
252+
# Add the margin
253+
left = max(0, left - self.margin)
254+
upper = max(0, upper - self.margin)
255+
right = min(mask_np.shape[1], right + self.margin)
256+
lower = min(mask_np.shape[0], lower + self.margin)
257+
258+
bounding_box = BoundingBoxField(x_min=left, y_min=upper, x_max=right, y_max=lower)
259+
260+
return BoundingBoxOutput(bounding_box=bounding_box)

0 commit comments

Comments
 (0)