 import numpy.typing as npt
 import torch
 from PIL import Image
-from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline
+from transformers import AutoModelForMaskGeneration, AutoProcessor
 from transformers.models.sam import SamModel
 from transformers.models.sam.processing_sam import SamProcessor
-from transformers.pipelines import ZeroShotObjectDetectionPipeline

 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.fields import ImageField, InputField
+from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.image_util.grounded_sam.detection_result import DetectionResult
-from invokeai.backend.image_util.grounded_sam.grounding_dino_pipeline import GroundingDinoPipeline
 from invokeai.backend.image_util.grounded_sam.mask_refinement import mask_to_polygon, polygon_to_mask
 from invokeai.backend.image_util.grounded_sam.segment_anything_model import SegmentAnythingModel

-GROUNDING_DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny"
 SEGMENT_ANYTHING_MODEL_ID = "facebook/sam-vit-base"


 @invocation(
-    "grounded_segment_anything",
-    title="Segment Anything (Text Prompt)",
+    "segment_anything_model",
+    title="Segment Anything Model",
     tags=["prompt", "segmentation"],
     category="segmentation",
     version="1.0.0",
 )
-class GroundedSAMInvocation(BaseInvocation):
-    """Runs Grounded-SAM, as proposed in https://arxiv.org/pdf/2401.14159.
-
-    More specifically, a Grounding DINO model is run to obtain bounding boxes for a text prompt, then the bounding boxes
-    are passed as a prompt to a Segment Anything model to obtain a segmentation mask.
+class SegmentAnythingModelInvocation(BaseInvocation):
+    """Runs a Segment Anything Model (https://arxiv.org/pdf/2304.02643).

     Reference:
     - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
     - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
     """

-    prompt: str = InputField(description="The prompt describing the object to segment.")
     image: ImageField = InputField(description="The image to segment.")
+    bounding_boxes: list[BoundingBoxField] = InputField(description="The bounding boxes to prompt the SAM model with.")
     apply_polygon_refinement: bool = InputField(
-        description="Whether to apply polygon refinement to the masks. This will smooth the edges of the mask slightly and ensure that each mask consists of a single closed polygon (before merging).",
+        description="Whether to apply polygon refinement to the masks. This will smooth the edges of the masks slightly and ensure that each mask consists of a single closed polygon (before merging).",
         default=True,
     )
     mask_filter: Literal["all", "largest", "highest_box_score"] = InputField(
         description="The filtering to apply to the detected masks before merging them into a final output.",
         default="all",
     )
-    detection_threshold: float = InputField(
-        description="The detection threshold for the Grounding DINO model. All detected bounding boxes with scores above this threshold will be used.",
-        ge=0.0,
-        le=1.0,
-        default=0.3,
-    )

     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
         # The models expect a 3-channel RGB image.
         image_pil = context.images.get_pil(self.image.image_name, mode="RGB")

-        detections = self._detect(
-            context=context, image=image_pil, labels=[self.prompt], threshold=self.detection_threshold
-        )
-
-        if len(detections) == 0:
+        if len(self.bounding_boxes) == 0:
             combined_mask = np.zeros(image_pil.size[::-1], dtype=np.uint8)
         else:
-            detections = self._segment(context=context, image=image_pil, detection_results=detections)
-
-            detections = self._filter_detections(detections)
-            masks = [detection.mask for detection in detections]
+            masks = self._segment(context=context, image=image_pil)
+            masks = self._filter_masks(masks=masks, bounding_boxes=self.bounding_boxes)
             # masks contains binary values of 0 or 1, so we merge them via max-reduce.
             combined_mask = np.maximum.reduce(masks)

@@ -84,19 +65,6 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
         image_dto = context.images.save(image=mask_pil)
         return ImageOutput.build(image_dto)

-    @staticmethod
-    def _load_grounding_dino(model_path: Path):
-        grounding_dino_pipeline = pipeline(
-            model=str(model_path),
-            task="zero-shot-object-detection",
-            local_files_only=True,
-            # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
-            # model, and figure out how to make it work in the pipeline.
-            # torch_dtype=TorchDevice.choose_torch_dtype(),
-        )
-        assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline)
-        return GroundingDinoPipeline(grounding_dino_pipeline)
-
     @staticmethod
     def _load_sam_model(model_path: Path):
         sam_model = AutoModelForMaskGeneration.from_pretrained(
@@ -112,47 +80,28 @@ def _load_sam_model(model_path: Path):
         assert isinstance(sam_processor, SamProcessor)
         return SegmentAnythingModel(sam_model=sam_model, sam_processor=sam_processor)

-    def _detect(
-        self,
-        context: InvocationContext,
-        image: Image.Image,
-        labels: list[str],
-        threshold: float = 0.3,
-    ) -> list[DetectionResult]:
-        """Use Grounding DINO to detect bounding boxes for a set of labels in an image."""
-        # TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it
-        # actually makes a difference.
-        labels = [label if label.endswith(".") else label + "." for label in labels]
-
-        with context.models.load_remote_model(
-            source=GROUNDING_DINO_MODEL_ID, loader=GroundedSAMInvocation._load_grounding_dino
-        ) as detector:
-            assert isinstance(detector, GroundingDinoPipeline)
-            return detector.detect(image=image, candidate_labels=labels, threshold=threshold)
-
     def _segment(
         self,
         context: InvocationContext,
         image: Image.Image,
-        detection_results: list[DetectionResult],
-    ) -> list[DetectionResult]:
+    ) -> list[npt.NDArray[np.uint8]]:
         """Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
+        # Convert the bounding boxes to the SAM input format.
+        sam_bounding_boxes = [[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes]
+
         with (
             context.models.load_remote_model(
-                source=SEGMENT_ANYTHING_MODEL_ID, loader=GroundedSAMInvocation._load_sam_model
+                source=SEGMENT_ANYTHING_MODEL_ID, loader=SegmentAnythingModelInvocation._load_sam_model
             ) as sam_pipeline,
         ):
             assert isinstance(sam_pipeline, SegmentAnythingModel)
-            masks = sam_pipeline.segment(image=image, detection_results=detection_results)
+            masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes)

         masks = self._to_numpy_masks(masks)
         if self.apply_polygon_refinement:
             masks = self._apply_polygon_refinement(masks)

-        for detection_result, mask in zip(detection_results, masks, strict=True):
-            detection_result.mask = mask
-
-        return detection_results
+        return masks

     def _to_numpy_masks(self, masks: torch.Tensor) -> list[npt.NDArray[np.uint8]]:
         """Convert the tensor output from the Segment Anything model to a list of numpy masks."""
@@ -181,15 +130,23 @@ def _apply_polygon_refinement(self, masks: list[npt.NDArray[np.uint8]]) -> list[

         return masks

-    def _filter_detections(self, detections: list[DetectionResult]) -> list[DetectionResult]:
+    def _filter_masks(
+        self, masks: list[npt.NDArray[np.uint8]], bounding_boxes: list[BoundingBoxField]
+    ) -> list[npt.NDArray[np.uint8]]:
         """Filter the detected masks based on the specified mask filter."""
+        assert len(masks) == len(bounding_boxes)
+
         if self.mask_filter == "all":
-            return detections
+            return masks
         elif self.mask_filter == "largest":
             # Find the largest mask.
-            return [max(detections, key=lambda x: x.mask.sum())]
+            return [max(masks, key=lambda x: x.sum())]
         elif self.mask_filter == "highest_box_score":
-            # Find the detection with the highest box score.
-            return [max(detections, key=lambda x: x.score)]
+            # Find the index of the bounding box with the highest score.
+            # Note that we fallback to -1.0 if the score is None. This is mainly to satisfy the type checker. In most
+            # cases the scores should all be non-None when using this filtering mode. That being said, -1.0 is a
+            # reasonable fallback since the expected score range is [0.0, 1.0].
+            max_score_idx = max(range(len(bounding_boxes)), key=lambda i: bounding_boxes[i].score or -1.0)
+            return [masks[max_score_idx]]
         else:
             raise ValueError(f"Invalid mask filter: {self.mask_filter}")
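
For anyone reviewing the `_filter_masks` / `invoke` changes, here is a minimal, self-contained sketch of the mask filtering and max-reduce merging behaviour introduced in this diff. The `Box` namedtuple and `filter_and_merge` helper are hypothetical stand-ins, not the real `BoundingBoxField` or node code; they only mirror the `x_min`/`y_min`/`x_max`/`y_max`/`score` attributes and the `score or -1.0` fallback used above.

```python
from typing import Literal, NamedTuple, Optional

import numpy as np
import numpy.typing as npt


class Box(NamedTuple):
    # Hypothetical stand-in for BoundingBoxField, using the same attribute names as the node.
    x_min: int
    y_min: int
    x_max: int
    y_max: int
    score: Optional[float] = None


def filter_and_merge(
    masks: list[npt.NDArray[np.uint8]],
    boxes: list[Box],
    mask_filter: Literal["all", "largest", "highest_box_score"] = "all",
) -> npt.NDArray[np.uint8]:
    assert len(masks) == len(boxes)
    if mask_filter == "largest":
        # Keep only the mask covering the most pixels.
        masks = [max(masks, key=lambda m: int(m.sum()))]
    elif mask_filter == "highest_box_score":
        # Keep the mask whose prompting box has the highest score; None scores fall back to -1.0.
        best = max(range(len(boxes)), key=lambda i: boxes[i].score or -1.0)
        masks = [masks[best]]
    # The masks are binary (0/1), so element-wise max merges them into a single union mask.
    return np.maximum.reduce(masks)


# Two overlapping 4x4 binary masks and the boxes that prompted them.
m1 = np.zeros((4, 4), dtype=np.uint8)
m1[0:2, 0:2] = 1
m2 = np.zeros((4, 4), dtype=np.uint8)
m2[1:4, 1:4] = 1
boxes = [Box(0, 0, 2, 2, score=0.9), Box(1, 1, 4, 4, score=0.4)]

print(filter_and_merge([m1, m2], boxes, "all").sum())                # 12: union of both masks
print(filter_and_merge([m1, m2], boxes, "highest_box_score").sum())  # 4: only the first mask
```

Since the masks are binary, `np.maximum.reduce` is simply the element-wise union of whichever masks survive the filter.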
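Similarly, a small sketch of the two input-handling details in `invoke` / `_segment`: with no bounding boxes the node builds a blank `(height, width)` mask directly from the PIL `(width, height)` size, and non-empty inputs are flattened into `[x_min, y_min, x_max, y_max]` lists passed to SAM as box prompts. Plain PIL/numpy only, with dicts standing in for `BoundingBoxField`.

```python
import numpy as np
from PIL import Image

image_pil = Image.new("RGB", (640, 480))  # PIL's size is (width, height)

# No bounding boxes: fall back to an all-zero mask of shape (height, width).
combined_mask = np.zeros(image_pil.size[::-1], dtype=np.uint8)
print(combined_mask.shape)  # (480, 640)

# With bounding boxes: each box is flattened to [x_min, y_min, x_max, y_max] before being
# handed to the SAM model as a box prompt (dicts here stand in for BoundingBoxField).
boxes = [{"x_min": 10, "y_min": 20, "x_max": 110, "y_max": 220}]
sam_bounding_boxes = [[b["x_min"], b["y_min"], b["x_max"], b["y_max"]] for b in boxes]
print(sam_bounding_boxes)  # [[10, 20, 110, 220]]
```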