 from typing import Literal

 import numpy as np
-import numpy.typing as npt
 import torch
 from PIL import Image
 from transformers import AutoModelForMaskGeneration, AutoProcessor
 from transformers.models.sam import SamModel
 from transformers.models.sam.processing_sam import SamProcessor

 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField
-from invokeai.app.invocations.primitives import ImageOutput
+from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField, TensorField
+from invokeai.app.invocations.primitives import MaskOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask
 from invokeai.backend.image_util.segment_anything.segment_anything_model import SegmentAnythingModel
@@ -46,24 +45,22 @@ class SegmentAnythingModelInvocation(BaseInvocation):
     )

     @torch.no_grad()
-    def invoke(self, context: InvocationContext) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> MaskOutput:
         # The models expect a 3-channel RGB image.
         image_pil = context.images.get_pil(self.image.image_name, mode="RGB")

         if len(self.bounding_boxes) == 0:
-            combined_mask = np.zeros(image_pil.size[::-1], dtype=np.uint8)
+            combined_mask = torch.zeros(image_pil.size[::-1], dtype=torch.bool)
         else:
             masks = self._segment(context=context, image=image_pil)
             masks = self._filter_masks(masks=masks, bounding_boxes=self.bounding_boxes)
-            # masks contains binary values of 0 or 1, so we merge them via max-reduce.
-            combined_mask = np.maximum.reduce(masks)

-        # Map [0, 1] to [0, 255].
-        mask_np = combined_mask * 255
-        mask_pil = Image.fromarray(mask_np)
+            # masks contains bool values, so we merge them via max-reduce.
+            combined_mask, _ = torch.stack(masks).max(dim=0)

-        image_dto = context.images.save(image=mask_pil)
-        return ImageOutput.build(image_dto)
+        mask_tensor_name = context.tensors.save(combined_mask)
+        height, width = combined_mask.shape
+        return MaskOutput(mask=TensorField(tensor_name=mask_tensor_name), width=width, height=height)

     @staticmethod
     def _load_sam_model(model_path: Path):
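Aside: a minimal, standalone sketch (toy tensors only, not part of the diff) of how the new bool-mask merge behaves; torch.stack(masks).max(dim=0) is the boolean analogue of the old np.maximum.reduce(masks).

import torch

masks = [
    torch.tensor([[True, False], [False, False]]),
    torch.tensor([[False, False], [False, True]]),
]
# A pixel is set in the combined mask if it is set in any individual mask.
combined, _ = torch.stack(masks).max(dim=0)
assert combined.tolist() == [[True, False], [False, True]]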
@@ -84,7 +81,7 @@ def _segment(
         self,
         context: InvocationContext,
         image: Image.Image,
-    ) -> list[npt.NDArray[np.uint8]]:
+    ) -> list[torch.Tensor]:
         """Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
         # Convert the bounding boxes to the SAM input format.
         sam_bounding_boxes = [[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes]
@@ -97,22 +94,23 @@ def _segment(
             assert isinstance(sam_pipeline, SegmentAnythingModel)
             masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes)

-        masks = self._to_numpy_masks(masks)
+        masks = self._process_masks(masks)
         if self.apply_polygon_refinement:
             masks = self._apply_polygon_refinement(masks)

         return masks

-    def _to_numpy_masks(self, masks: torch.Tensor) -> list[npt.NDArray[np.uint8]]:
-        """Convert the tensor output from the Segment Anything model to a list of numpy masks."""
-        eps = 0.0001
+    def _process_masks(self, masks: torch.Tensor) -> list[torch.Tensor]:
+        """Convert the tensor output from the Segment Anything model from a tensor of shape
+        [num_masks, channels, height, width] to a list of tensors of shape [height, width].
+        """
+        assert masks.dtype == torch.bool
         # [num_masks, channels, height, width] -> [num_masks, height, width]
-        masks = masks.permute(0, 2, 3, 1).float().mean(dim=-1)
-        masks = masks > eps
-        np_masks = masks.cpu().numpy().astype(np.uint8)
-        return list(np_masks)
+        masks, _ = masks.max(dim=1)
+        # Split the first dimension into a list of masks.
+        return list(masks.cpu().unbind(dim=0))

-    def _apply_polygon_refinement(self, masks: list[npt.NDArray[np.uint8]]) -> list[npt.NDArray[np.uint8]]:
+    def _apply_polygon_refinement(self, masks: list[torch.Tensor]) -> list[torch.Tensor]:
         """Apply polygon refinement to the masks.

         Convert each mask to a polygon, then back to a mask. This has the following effect:
@@ -121,26 +119,31 @@ def _apply_polygon_refinement(self, masks: list[npt.NDArray[np.uint8]]) -> list[
         - Removes small mask pieces.
         - Removes holes from the mask.
         """
-        for idx, mask in enumerate(masks):
+        # Convert tensor masks to np masks.
+        np_masks = [mask.cpu().numpy().astype(np.uint8) for mask in masks]
+
+        # Apply polygon refinement.
+        for idx, mask in enumerate(np_masks):
             shape = mask.shape
             assert len(shape) == 2  # Assert length to satisfy type checker.
             polygon = mask_to_polygon(mask)
             mask = polygon_to_mask(polygon, shape)
-            masks[idx] = mask
+            np_masks[idx] = mask
+
+        # Convert np masks back to tensor masks.
+        masks = [torch.tensor(mask, dtype=torch.bool) for mask in np_masks]

         return masks

-    def _filter_masks(
-        self, masks: list[npt.NDArray[np.uint8]], bounding_boxes: list[BoundingBoxField]
-    ) -> list[npt.NDArray[np.uint8]]:
+    def _filter_masks(self, masks: list[torch.Tensor], bounding_boxes: list[BoundingBoxField]) -> list[torch.Tensor]:
         """Filter the detected masks based on the specified mask filter."""
         assert len(masks) == len(bounding_boxes)

         if self.mask_filter == "all":
             return masks
         elif self.mask_filter == "largest":
             # Find the largest mask.
-            return [max(masks, key=lambda x: x.sum())]
+            return [max(masks, key=lambda x: float(x.sum()))]
         elif self.mask_filter == "highest_box_score":
             # Find the index of the bounding box with the highest score.
             # Note that we fallback to -1.0 if the score is None. This is mainly to satisfy the type checker. In most