|
import numpy as np
import torch
from PIL import Image

from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    Classification,
    InvocationContext,
    invocation,
)
from invokeai.app.invocations.fields import (
    BoundingBoxField,
    ColorField,
    ImageField,
    InputField,
    TensorField,
    WithBoard,
    WithMetadata,
)
from invokeai.app.invocations.primitives import BoundingBoxOutput, ImageOutput, MaskOutput
from invokeai.backend.image_util.util import pil_to_np
9 | 22 |
|
10 | 23 |
|
@@ -201,3 +214,47 @@ def invoke(self, context: InvocationContext) -> ImageOutput: |
201 | 214 | image_dto = context.images.save(image=masked_image) |
202 | 215 |
|
203 | 216 | return ImageOutput.build(image_dto) |
| 217 | + |
| 218 | + |
# Fully opaque white (RGBA 255, 255, 255, 255) — the default color that
# GetMaskBoundingBoxInvocation treats as "masked" when scanning pixels.
WHITE = ColorField(r=255, g=255, b=255, a=255)
| 221 | + |
@invocation(
    "get_image_mask_bounding_box",
    title="Get Image Mask Bounding Box",
    tags=["mask"],
    category="mask",
    version="1.0.0",
)
class GetMaskBoundingBoxInvocation(BaseInvocation, WithMetadata):
    """Gets the bounding box of the given mask image.

    Loads the mask as RGBA, finds every pixel exactly matching ``mask_color``,
    and returns the enclosing bounding box expanded by ``margin`` pixels
    (clamped to the image bounds). If no pixel matches, a degenerate
    (0, 0, 0, 0) box is returned.
    """

    # Mask image whose colored region defines the bounding box.
    mask: ImageField = InputField(description="The mask to crop.")
    # Extra pixels added on every side of the detected box (clamped to image bounds).
    margin: int = InputField(default=0, description="Margin to add to the bounding box.")
    # RGBA color that identifies "masked" pixels; must match exactly.
    mask_color: ColorField = InputField(default=WHITE, description="Color of the mask in the image.")

    def invoke(self, context: InvocationContext) -> BoundingBoxOutput:
        mask = context.images.get_pil(self.mask.image_name, mode="RGBA")
        mask_np = np.array(mask)

        # Coordinates of every pixel whose RGBA value equals mask_color exactly.
        mask_color_rgba = self.mask_color.tuple()
        ys, xs = np.where(np.all(mask_np == mask_color_rgba, axis=-1))

        if xs.size == 0 or ys.size == 0:
            # No pixels found with the given color — return an empty box.
            return BoundingBoxOutput(bounding_box=BoundingBoxField(x_min=0, y_min=0, x_max=0, y_max=0))

        # Cast numpy integer scalars to plain ints so pydantic validation and
        # JSON serialization of BoundingBoxField behave consistently.
        left = max(0, int(xs.min()) - self.margin)
        upper = max(0, int(ys.min()) - self.margin)
        # NOTE(review): xs.max()/ys.max() are inclusive pixel indices; confirm
        # whether BoundingBoxField expects exclusive maxima before relying on
        # pixel-perfect crops.
        right = min(mask_np.shape[1], int(xs.max()) + self.margin)
        lower = min(mask_np.shape[0], int(ys.max()) + self.margin)

        bounding_box = BoundingBoxField(x_min=left, y_min=upper, x_max=right, y_max=lower)

        return BoundingBoxOutput(bounding_box=bounding_box)
0 commit comments