Skip to content

Commit 14b5c87

Browse files
feat(nodes): simplify MaskFromIDInvocation
1 parent 8d2b4e2 commit 14b5c87

File tree

1 file changed

+8
-24
lines changed

1 file changed

+8
-24
lines changed

invokeai/app/invocations/image.py

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,7 +1000,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
10001000
title="Mask from ID",
10011001
tags=["image", "mask", "id"],
10021002
category="image",
1003-
version="1.0.0",
1003+
version="1.0.1",
10041004
)
10051005
class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard):
10061006
"""Generate a mask for a particular color in an ID Map"""
@@ -1010,40 +1010,24 @@ class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard):
10101010
threshold: int = InputField(default=100, description="Threshold for color detection")
10111011
invert: bool = InputField(default=False, description="Whether or not to invert the mask")
10121012

1013-
def rgba_to_hex(self, rgba_color: tuple[int, int, int, int]):
1014-
r, g, b, a = rgba_color
1015-
hex_code = "#{:02X}{:02X}{:02X}{:02X}".format(r, g, b, int(a * 255))
1016-
return hex_code
1017-
1018-
def id_to_mask(self, id_mask: Image.Image, color: tuple[int, int, int, int], threshold: int = 100):
1019-
if id_mask.mode != "RGB":
1020-
id_mask = id_mask.convert("RGB")
1013+
def invoke(self, context: InvocationContext) -> ImageOutput:
1014+
image = context.images.get_pil(self.image.image_name, mode="RGBA")
10211015

1022-
# Can directly just use the tuple but I'll leave this rgba_to_hex here
1023-
# incase anyone prefers using hex codes directly instead of the color picker
1024-
hex_color_str = self.rgba_to_hex(color)
1025-
rgb_color = numpy.array([int(hex_color_str[i : i + 2], 16) for i in (1, 3, 5)])
1016+
np_color = numpy.array(self.color.tuple())
10261017

10271018
# Maybe there's a faster way to calculate this distance but I can't think of any right now.
1028-
color_distance = numpy.linalg.norm(id_mask - rgb_color, axis=-1)
1019+
color_distance = numpy.linalg.norm(image - np_color, axis=-1)
10291020

10301021
# Create a mask based on the threshold and the distance calculated above
1031-
binary_mask = (color_distance < threshold).astype(numpy.uint8) * 255
1022+
binary_mask = (color_distance < self.threshold).astype(numpy.uint8) * 255
10321023

10331024
# Convert the mask back to PIL
10341025
binary_mask_pil = Image.fromarray(binary_mask)
10351026

1036-
return binary_mask_pil
1037-
1038-
def invoke(self, context: InvocationContext) -> ImageOutput:
1039-
image = context.images.get_pil(self.image.image_name)
1040-
1041-
mask = self.id_to_mask(image, self.color.tuple(), self.threshold)
1042-
10431027
if self.invert:
1044-
mask = ImageOps.invert(mask)
1028+
binary_mask_pil = ImageOps.invert(binary_mask_pil)
10451029

1046-
image_dto = context.images.save(image=mask, image_category=ImageCategory.MASK)
1030+
image_dto = context.images.save(image=binary_mask_pil, image_category=ImageCategory.MASK)
10471031

10481032
return ImageOutput.build(image_dto)
10491033

0 commit comments

Comments
 (0)