Skip to content

Commit b583276

Browse files
committed
Return a MaskOutput from SegmentAnythingModelInvocation. And add a MaskTensorToImageInvocation.
1 parent fca1197 commit b583276

File tree

2 files changed

+59
-30
lines changed

2 files changed

+59
-30
lines changed

invokeai/app/invocations/mask.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import numpy as np
22
import torch
3+
from PIL import Image
34

45
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, InvocationContext, invocation
5-
from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithMetadata
6-
from invokeai.app.invocations.primitives import MaskOutput
6+
from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithBoard, WithMetadata
7+
from invokeai.app.invocations.primitives import ImageOutput, MaskOutput
78

89

910
@invocation(
@@ -118,3 +119,28 @@ def invoke(self, context: InvocationContext) -> MaskOutput:
118119
height=mask.shape[1],
119120
width=mask.shape[2],
120121
)
122+
123+
124+
@invocation(
125+
"tensor_mask_to_image",
126+
title="Tensor Mask to Image",
127+
tags=["mask"],
128+
category="mask",
129+
version="1.0.0",
130+
)
131+
class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
132+
"""Convert a mask tensor to an image."""
133+
134+
mask: TensorField = InputField(description="The mask tensor to convert.")
135+
136+
def invoke(self, context: InvocationContext) -> ImageOutput:
137+
mask = context.tensors.load(self.mask.tensor_name)
138+
# Ensure that the mask is binary.
139+
if mask.dtype != torch.bool:
140+
mask = mask > 0.5
141+
mask_np = mask.float().cpu().detach().numpy() * 255
142+
mask_np = mask_np.astype(np.uint8)
143+
144+
mask_pil = Image.fromarray(mask_np, mode="L")
145+
image_dto = context.images.save(image=mask_pil)
146+
return ImageOutput.build(image_dto)

invokeai/app/invocations/segment_anything_model.py

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,15 @@
22
from typing import Literal
33

44
import numpy as np
5-
import numpy.typing as npt
65
import torch
76
from PIL import Image
87
from transformers import AutoModelForMaskGeneration, AutoProcessor
98
from transformers.models.sam import SamModel
109
from transformers.models.sam.processing_sam import SamProcessor
1110

1211
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
13-
from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField
14-
from invokeai.app.invocations.primitives import ImageOutput
12+
from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField, TensorField
13+
from invokeai.app.invocations.primitives import MaskOutput
1514
from invokeai.app.services.shared.invocation_context import InvocationContext
1615
from invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask
1716
from invokeai.backend.image_util.segment_anything.segment_anything_model import SegmentAnythingModel
@@ -46,24 +45,22 @@ class SegmentAnythingModelInvocation(BaseInvocation):
4645
)
4746

4847
@torch.no_grad()
49-
def invoke(self, context: InvocationContext) -> ImageOutput:
48+
def invoke(self, context: InvocationContext) -> MaskOutput:
5049
# The models expect a 3-channel RGB image.
5150
image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
5251

5352
if len(self.bounding_boxes) == 0:
54-
combined_mask = np.zeros(image_pil.size[::-1], dtype=np.uint8)
53+
combined_mask = torch.zeros(image_pil.size[::-1], dtype=torch.bool)
5554
else:
5655
masks = self._segment(context=context, image=image_pil)
5756
masks = self._filter_masks(masks=masks, bounding_boxes=self.bounding_boxes)
58-
# masks contains binary values of 0 or 1, so we merge them via max-reduce.
59-
combined_mask = np.maximum.reduce(masks)
6057

61-
# Map [0, 1] to [0, 255].
62-
mask_np = combined_mask * 255
63-
mask_pil = Image.fromarray(mask_np)
58+
# masks contains bool values, so we merge them via max-reduce.
59+
combined_mask, _ = torch.stack(masks).max(dim=0)
6460

65-
image_dto = context.images.save(image=mask_pil)
66-
return ImageOutput.build(image_dto)
61+
mask_tensor_name = context.tensors.save(combined_mask)
62+
height, width = combined_mask.shape
63+
return MaskOutput(mask=TensorField(tensor_name=mask_tensor_name), width=width, height=height)
6764

6865
@staticmethod
6966
def _load_sam_model(model_path: Path):
@@ -84,7 +81,7 @@ def _segment(
8481
self,
8582
context: InvocationContext,
8683
image: Image.Image,
87-
) -> list[npt.NDArray[np.uint8]]:
84+
) -> list[torch.Tensor]:
8885
"""Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
8986
# Convert the bounding boxes to the SAM input format.
9087
sam_bounding_boxes = [[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes]
@@ -97,22 +94,23 @@ def _segment(
9794
assert isinstance(sam_pipeline, SegmentAnythingModel)
9895
masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes)
9996

100-
masks = self._to_numpy_masks(masks)
97+
masks = self._process_masks(masks)
10198
if self.apply_polygon_refinement:
10299
masks = self._apply_polygon_refinement(masks)
103100

104101
return masks
105102

106-
def _to_numpy_masks(self, masks: torch.Tensor) -> list[npt.NDArray[np.uint8]]:
107-
"""Convert the tensor output from the Segment Anything model to a list of numpy masks."""
108-
eps = 0.0001
103+
def _process_masks(self, masks: torch.Tensor) -> list[torch.Tensor]:
104+
"""Convert the tensor output from the Segment Anything model from a tensor of shape
105+
[num_masks, channels, height, width] to a list of tensors of shape [height, width].
106+
"""
107+
assert masks.dtype == torch.bool
109108
# [num_masks, channels, height, width] -> [num_masks, height, width]
110-
masks = masks.permute(0, 2, 3, 1).float().mean(dim=-1)
111-
masks = masks > eps
112-
np_masks = masks.cpu().numpy().astype(np.uint8)
113-
return list(np_masks)
109+
masks, _ = masks.max(dim=1)
110+
# Split the first dimension into a list of masks.
111+
return list(masks.cpu().unbind(dim=0))
114112

115-
def _apply_polygon_refinement(self, masks: list[npt.NDArray[np.uint8]]) -> list[npt.NDArray[np.uint8]]:
113+
def _apply_polygon_refinement(self, masks: list[torch.Tensor]) -> list[torch.Tensor]:
116114
"""Apply polygon refinement to the masks.
117115
118116
Convert each mask to a polygon, then back to a mask. This has the following effect:
@@ -121,26 +119,31 @@ def _apply_polygon_refinement(self, masks: list[npt.NDArray[np.uint8]]) -> list[
121119
- Removes small mask pieces.
122120
- Removes holes from the mask.
123121
"""
124-
for idx, mask in enumerate(masks):
122+
# Convert tensor masks to np masks.
123+
np_masks = [mask.cpu().numpy().astype(np.uint8) for mask in masks]
124+
125+
# Apply polygon refinement.
126+
for idx, mask in enumerate(np_masks):
125127
shape = mask.shape
126128
assert len(shape) == 2 # Assert length to satisfy type checker.
127129
polygon = mask_to_polygon(mask)
128130
mask = polygon_to_mask(polygon, shape)
129-
masks[idx] = mask
131+
np_masks[idx] = mask
132+
133+
# Convert np masks back to tensor masks.
134+
masks = [torch.tensor(mask, dtype=torch.bool) for mask in np_masks]
130135

131136
return masks
132137

133-
def _filter_masks(
134-
self, masks: list[npt.NDArray[np.uint8]], bounding_boxes: list[BoundingBoxField]
135-
) -> list[npt.NDArray[np.uint8]]:
138+
def _filter_masks(self, masks: list[torch.Tensor], bounding_boxes: list[BoundingBoxField]) -> list[torch.Tensor]:
136139
"""Filter the detected masks based on the specified mask filter."""
137140
assert len(masks) == len(bounding_boxes)
138141

139142
if self.mask_filter == "all":
140143
return masks
141144
elif self.mask_filter == "largest":
142145
# Find the largest mask.
143-
return [max(masks, key=lambda x: x.sum())]
146+
return [max(masks, key=lambda x: float(x.sum()))]
144147
elif self.mask_filter == "highest_box_score":
145148
# Find the index of the bounding box with the highest score.
146149
# Note that we fallback to -1.0 if the score is None. This is mainly to satisfy the type checker. In most

0 commit comments

Comments
 (0)