
Commit 0193267

Split GroundedSamInvocation into GroundingDinoInvocation and SegmentAnythingModelInvocation.
1 parent 7338682 commit 0193267

6 files changed: +180 -93 lines changed


invokeai/app/invocations/fields.py

Lines changed: 17 additions & 0 deletions
@@ -242,6 +242,23 @@ class ConditioningField(BaseModel):
     )
 
 
+class BoundingBoxField(BaseModel):
+    """A bounding box primitive value."""
+
+    x_min: int = Field(ge=0, description="The minimum x-coordinate of the bounding box (inclusive).")
+    x_max: int = Field(ge=0, description="The maximum x-coordinate of the bounding box (exclusive).")
+    y_min: int = Field(ge=0, description="The minimum y-coordinate of the bounding box (inclusive).")
+    y_max: int = Field(ge=0, description="The maximum y-coordinate of the bounding box (exclusive).")
+
+    score: Optional[float] = Field(
+        default=None,
+        ge=0.0,
+        le=1.0,
+        description="The score associated with the bounding box. In the range [0, 1]. This value is typically set "
+        "when the bounding box was produced by a detector and has an associated confidence score.",
+    )
+
+
 class MetadataField(RootModel[dict[str, Any]]):
     """
     Pydantic model for metadata with custom root of type dict[str, Any].
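
The snippet below is a minimal usage sketch of the new BoundingBoxField, assuming only what the diff shows (the pydantic field constraints and the inclusive-min / exclusive-max coordinate convention); the concrete values are made up for illustration.

from invokeai.app.invocations.fields import BoundingBoxField

# A detector-produced box. x_max/y_max are exclusive, so width and height are plain differences.
box = BoundingBoxField(x_min=10, y_min=20, x_max=110, y_max=220, score=0.87)
width = box.x_max - box.x_min   # 100
height = box.y_max - box.y_min  # 200

# score is optional and constrained to [0, 1]; out-of-range values are rejected by pydantic.
try:
    BoundingBoxField(x_min=0, y_min=0, x_max=5, y_max=5, score=1.5)
except Exception as err:  # pydantic.ValidationError
    print(f"rejected: {err}")
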

invokeai/app/invocations/grounding_dino.py (new file)

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
+from pathlib import Path
+
+import torch
+from PIL import Image
+from transformers import pipeline
+from transformers.pipelines import ZeroShotObjectDetectionPipeline
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
+from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField
+from invokeai.app.invocations.primitives import BoundingBoxCollectionOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.image_util.grounded_sam.detection_result import DetectionResult
+from invokeai.backend.image_util.grounded_sam.grounding_dino_pipeline import GroundingDinoPipeline
+
+GROUNDING_DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny"
+
+
+@invocation(
+    "grounding_dino",
+    title="Grounding DINO (Text Prompt Object Detection)",
+    tags=["prompt", "object detection"],
+    category="image",
+    version="1.0.0",
+)
+class GroundingDinoInvocation(BaseInvocation):
+    """Runs a Grounding DINO model (https://arxiv.org/pdf/2303.05499). Performs zero-shot bounding-box object detection
+    from a text prompt.
+
+    Reference:
+    - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
+    - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
+    """
+
+    prompt: str = InputField(description="The prompt describing the object to segment.")
+    image: ImageField = InputField(description="The image to segment.")
+    detection_threshold: float = InputField(
+        description="The detection threshold for the Grounding DINO model. All detected bounding boxes with scores above this threshold will be returned.",
+        ge=0.0,
+        le=1.0,
+        default=0.3,
+    )
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> BoundingBoxCollectionOutput:
+        # The model expects a 3-channel RGB image.
+        image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
+
+        detections = self._detect(
+            context=context, image=image_pil, labels=[self.prompt], threshold=self.detection_threshold
+        )
+
+        # Convert detections to BoundingBoxCollectionOutput.
+        bounding_boxes: list[BoundingBoxField] = []
+        for detection in detections:
+            bounding_boxes.append(
+                BoundingBoxField(
+                    x_min=detection.box.xmin,
+                    x_max=detection.box.xmax,
+                    y_min=detection.box.ymin,
+                    y_max=detection.box.ymax,
+                    score=detection.score,
+                )
+            )
+        return BoundingBoxCollectionOutput(collection=bounding_boxes)
+
+    @staticmethod
+    def _load_grounding_dino(model_path: Path):
+        grounding_dino_pipeline = pipeline(
+            model=str(model_path),
+            task="zero-shot-object-detection",
+            local_files_only=True,
+            # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
+            # model, and figure out how to make it work in the pipeline.
+            # torch_dtype=TorchDevice.choose_torch_dtype(),
+        )
+        assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline)
+        return GroundingDinoPipeline(grounding_dino_pipeline)
+
+    def _detect(
+        self,
+        context: InvocationContext,
+        image: Image.Image,
+        labels: list[str],
+        threshold: float = 0.3,
+    ) -> list[DetectionResult]:
+        """Use Grounding DINO to detect bounding boxes for a set of labels in an image."""
+        # TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it
+        # actually makes a difference.
+        labels = [label if label.endswith(".") else label + "." for label in labels]
+
+        with context.models.load_remote_model(
+            source=GROUNDING_DINO_MODEL_ID, loader=GroundingDinoInvocation._load_grounding_dino
+        ) as detector:
+            assert isinstance(detector, GroundingDinoPipeline)
+            return detector.detect(image=image, candidate_labels=labels, threshold=threshold)
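
As a rough illustration of what _load_grounding_dino and _detect wrap, here is a hedged sketch that calls the Hugging Face zero-shot-object-detection pipeline directly, outside InvokeAI's model-loading machinery (context.models.load_remote_model); the image path and prompt are placeholders, and the model is fetched from the hub rather than loaded with local_files_only as in the node.

from PIL import Image
from transformers import pipeline

# GROUNDING_DINO_MODEL_ID from the invocation above.
detector = pipeline(task="zero-shot-object-detection", model="IDEA-Research/grounding-dino-tiny")

image = Image.open("example.jpg").convert("RGB")  # placeholder image path
# Grounding DINO expects period-terminated labels, mirroring the "."-handling in _detect().
results = detector(image, candidate_labels=["a red car."], threshold=0.3)

# Each result is a dict of the form {"score": float, "label": str, "box": {"xmin", "ymin", "xmax", "ymax"}},
# which is the shape the invocation converts into BoundingBoxField values.
for result in results:
    print(result["score"], result["label"], result["box"])
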

invokeai/app/invocations/primitives.py

Lines changed: 22 additions & 0 deletions
@@ -7,6 +7,7 @@
 from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
 from invokeai.app.invocations.fields import (
+    BoundingBoxField,
     ColorField,
     ConditioningField,
     DenoiseMaskField,
@@ -469,3 +470,24 @@ def invoke(self, context: InvocationContext) -> ConditioningCollectionOutput:
 
 
 # endregion
+
+# region BoundingBox
+
+
+@invocation_output("bounding_box_output")
+class BoundingBoxOutput(BaseInvocationOutput):
+    """Base class for nodes that output a single bounding box"""
+
+    bounding_box: BoundingBoxField = OutputField(description="The output bounding box.")
+
+
+@invocation_output("bounding_box_collection_output")
+class BoundingBoxCollectionOutput(BaseInvocationOutput):
+    """Base class for nodes that output a collection of bounding boxes"""
+
+    collection: list[BoundingBoxField] = OutputField(
+        description="The output bounding boxes.",
+    )
+
+
+# endregion

invokeai/app/invocations/grounded_sam.py renamed to invokeai/app/invocations/segment_anything_model.py

Lines changed: 31 additions & 74 deletions
@@ -5,75 +5,56 @@
 import numpy.typing as npt
 import torch
 from PIL import Image
-from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline
+from transformers import AutoModelForMaskGeneration, AutoProcessor
 from transformers.models.sam import SamModel
 from transformers.models.sam.processing_sam import SamProcessor
-from transformers.pipelines import ZeroShotObjectDetectionPipeline
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.fields import ImageField, InputField
+from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.image_util.grounded_sam.detection_result import DetectionResult
-from invokeai.backend.image_util.grounded_sam.grounding_dino_pipeline import GroundingDinoPipeline
 from invokeai.backend.image_util.grounded_sam.mask_refinement import mask_to_polygon, polygon_to_mask
 from invokeai.backend.image_util.grounded_sam.segment_anything_model import SegmentAnythingModel
 
-GROUNDING_DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny"
 SEGMENT_ANYTHING_MODEL_ID = "facebook/sam-vit-base"
 
 
 @invocation(
-    "grounded_segment_anything",
-    title="Segment Anything (Text Prompt)",
+    "segment_anything_model",
+    title="Segment Anything Model",
     tags=["prompt", "segmentation"],
     category="segmentation",
     version="1.0.0",
 )
-class GroundedSAMInvocation(BaseInvocation):
-    """Runs Grounded-SAM, as proposed in https://arxiv.org/pdf/2401.14159.
-
-    More specifically, a Grounding DINO model is run to obtain bounding boxes for a text prompt, then the bounding boxes
-    are passed as a prompt to a Segment Anything model to obtain a segmentation mask.
+class SegmentAnythingModelInvocation(BaseInvocation):
+    """Runs a Segment Anything Model (https://arxiv.org/pdf/2304.02643).
 
     Reference:
     - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
     - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
     """
 
-    prompt: str = InputField(description="The prompt describing the object to segment.")
     image: ImageField = InputField(description="The image to segment.")
+    bounding_boxes: list[BoundingBoxField] = InputField(description="The bounding boxes to prompt the SAM model with.")
     apply_polygon_refinement: bool = InputField(
-        description="Whether to apply polygon refinement to the masks. This will smooth the edges of the mask slightly and ensure that each mask consists of a single closed polygon (before merging).",
+        description="Whether to apply polygon refinement to the masks. This will smooth the edges of the masks slightly and ensure that each mask consists of a single closed polygon (before merging).",
         default=True,
     )
     mask_filter: Literal["all", "largest", "highest_box_score"] = InputField(
         description="The filtering to apply to the detected masks before merging them into a final output.",
         default="all",
     )
-    detection_threshold: float = InputField(
-        description="The detection threshold for the Grounding DINO model. All detected bounding boxes with scores above this threshold will be used.",
-        ge=0.0,
-        le=1.0,
-        default=0.3,
-    )
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
         # The models expect a 3-channel RGB image.
         image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
 
-        detections = self._detect(
-            context=context, image=image_pil, labels=[self.prompt], threshold=self.detection_threshold
-        )
-
-        if len(detections) == 0:
+        if len(self.bounding_boxes) == 0:
             combined_mask = np.zeros(image_pil.size[::-1], dtype=np.uint8)
         else:
-            detections = self._segment(context=context, image=image_pil, detection_results=detections)
-
-            detections = self._filter_detections(detections)
-            masks = [detection.mask for detection in detections]
+            masks = self._segment(context=context, image=image_pil)
+            masks = self._filter_masks(masks=masks, bounding_boxes=self.bounding_boxes)
             # masks contains binary values of 0 or 1, so we merge them via max-reduce.
             combined_mask = np.maximum.reduce(masks)
 
@@ -84,19 +65,6 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
         image_dto = context.images.save(image=mask_pil)
         return ImageOutput.build(image_dto)
 
-    @staticmethod
-    def _load_grounding_dino(model_path: Path):
-        grounding_dino_pipeline = pipeline(
-            model=str(model_path),
-            task="zero-shot-object-detection",
-            local_files_only=True,
-            # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
-            # model, and figure out how to make it work in the pipeline.
-            # torch_dtype=TorchDevice.choose_torch_dtype(),
-        )
-        assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline)
-        return GroundingDinoPipeline(grounding_dino_pipeline)
-
     @staticmethod
     def _load_sam_model(model_path: Path):
         sam_model = AutoModelForMaskGeneration.from_pretrained(
@@ -112,47 +80,28 @@ def _load_sam_model(model_path: Path):
         assert isinstance(sam_processor, SamProcessor)
         return SegmentAnythingModel(sam_model=sam_model, sam_processor=sam_processor)
 
-    def _detect(
-        self,
-        context: InvocationContext,
-        image: Image.Image,
-        labels: list[str],
-        threshold: float = 0.3,
-    ) -> list[DetectionResult]:
-        """Use Grounding DINO to detect bounding boxes for a set of labels in an image."""
-        # TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it
-        # actually makes a difference.
-        labels = [label if label.endswith(".") else label + "." for label in labels]
-
-        with context.models.load_remote_model(
-            source=GROUNDING_DINO_MODEL_ID, loader=GroundedSAMInvocation._load_grounding_dino
-        ) as detector:
-            assert isinstance(detector, GroundingDinoPipeline)
-            return detector.detect(image=image, candidate_labels=labels, threshold=threshold)
-
     def _segment(
         self,
         context: InvocationContext,
         image: Image.Image,
-        detection_results: list[DetectionResult],
-    ) -> list[DetectionResult]:
+    ) -> list[npt.NDArray[np.uint8]]:
         """Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
+        # Convert the bounding boxes to the SAM input format.
+        sam_bounding_boxes = [[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes]
+
         with (
             context.models.load_remote_model(
-                source=SEGMENT_ANYTHING_MODEL_ID, loader=GroundedSAMInvocation._load_sam_model
+                source=SEGMENT_ANYTHING_MODEL_ID, loader=SegmentAnythingModelInvocation._load_sam_model
            ) as sam_pipeline,
         ):
             assert isinstance(sam_pipeline, SegmentAnythingModel)
-            masks = sam_pipeline.segment(image=image, detection_results=detection_results)
+            masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes)
 
         masks = self._to_numpy_masks(masks)
         if self.apply_polygon_refinement:
             masks = self._apply_polygon_refinement(masks)
 
-        for detection_result, mask in zip(detection_results, masks, strict=True):
-            detection_result.mask = mask
-
-        return detection_results
+        return masks
 
     def _to_numpy_masks(self, masks: torch.Tensor) -> list[npt.NDArray[np.uint8]]:
         """Convert the tensor output from the Segment Anything model to a list of numpy masks."""
@@ -181,15 +130,23 @@ def _apply_polygon_refinement(self, masks: list[npt.NDArray[np.uint8]]) -> list[
 
         return masks
 
-    def _filter_detections(self, detections: list[DetectionResult]) -> list[DetectionResult]:
+    def _filter_masks(
+        self, masks: list[npt.NDArray[np.uint8]], bounding_boxes: list[BoundingBoxField]
+    ) -> list[npt.NDArray[np.uint8]]:
         """Filter the detected masks based on the specified mask filter."""
+        assert len(masks) == len(bounding_boxes)
+
         if self.mask_filter == "all":
-            return detections
+            return masks
         elif self.mask_filter == "largest":
             # Find the largest mask.
-            return [max(detections, key=lambda x: x.mask.sum())]
+            return [max(masks, key=lambda x: x.sum())]
         elif self.mask_filter == "highest_box_score":
-            # Find the detection with the highest box score.
-            return [max(detections, key=lambda x: x.score)]
+            # Find the index of the bounding box with the highest score.
+            # Note that we fallback to -1.0 if the score is None. This is mainly to satisfy the type checker. In most
+            # cases the scores should all be non-None when using this filtering mode. That being said, -1.0 is a
+            # reasonable fallback since the expected score range is [0.0, 1.0].
+            max_score_idx = max(range(len(bounding_boxes)), key=lambda i: bounding_boxes[i].score or -1.0)
+            return [masks[max_score_idx]]
         else:
             raise ValueError(f"Invalid mask filter: {self.mask_filter}")

invokeai/backend/image_util/grounded_sam/detection_result.py

Lines changed: 1 addition & 9 deletions
@@ -1,6 +1,3 @@
-from typing import Any, Optional
-
-import numpy.typing as npt
 from pydantic import BaseModel, ConfigDict
 
 
@@ -12,18 +9,13 @@ class BoundingBox(BaseModel):
     xmax: int
     ymax: int
 
-    def to_box(self) -> list[int]:
-        """Convert to the array notation expected by SAM."""
-        return [self.xmin, self.ymin, self.xmax, self.ymax]
-
 
 class DetectionResult(BaseModel):
-    """Detection result from Grounding DINO or Grounded SAM."""
+    """Detection result from Grounding DINO."""
 
     score: float
     label: str
     box: BoundingBox
-    mask: Optional[npt.NDArray[Any]] = None
     model_config = ConfigDict(
         # Allow arbitrary types for mask, since it will be a numpy array.
         arbitrary_types_allowed=True
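
For completeness, a small hedged sketch of how the slimmed-down DetectionResult pairs with raw detector output; model_validate is standard pydantic v2 (implied by the ConfigDict usage above), and the sample dict mirrors the zero-shot-object-detection result format rather than any specific InvokeAI call site.

from invokeai.backend.image_util.grounded_sam.detection_result import DetectionResult

# A raw result dict in the shape returned by the Hugging Face zero-shot-object-detection pipeline.
raw = {"score": 0.87, "label": "a red car.", "box": {"xmin": 10, "ymin": 20, "xmax": 110, "ymax": 220}}

detection = DetectionResult.model_validate(raw)
print(detection.label, detection.score)
print(detection.box.xmin, detection.box.ymin, detection.box.xmax, detection.box.ymax)
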
