Skip to content

Commit 33e8604

Browse files
committed
Make Grounding DINO DetectionResult a Pydantic model.
1 parent cec7399 commit 33e8604

File tree

2 files changed

+8
-19
lines changed

2 files changed

+8
-19
lines changed
Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
from dataclasses import dataclass
21
from typing import Any, Optional
32

43
import numpy.typing as npt
4+
from pydantic import BaseModel, ConfigDict
55

66

7-
@dataclass
8-
class BoundingBox:
7+
class BoundingBox(BaseModel):
98
"""Bounding box helper class."""
109

1110
xmin: int
@@ -18,24 +17,14 @@ def to_box(self) -> list[int]:
1817
return [self.xmin, self.ymin, self.xmax, self.ymax]
1918

2019

21-
@dataclass
22-
class DetectionResult:
20+
class DetectionResult(BaseModel):
2321
"""Detection result from Grounding DINO or Grounded SAM."""
2422

2523
score: float
2624
label: str
2725
box: BoundingBox
2826
mask: Optional[npt.NDArray[Any]] = None
29-
30-
@classmethod
31-
def from_dict(cls, detection_dict: dict[str, Any]):
32-
return cls(
33-
score=detection_dict["score"],
34-
label=detection_dict["label"],
35-
box=BoundingBox(
36-
xmin=detection_dict["box"]["xmin"],
37-
ymin=detection_dict["box"]["ymin"],
38-
xmax=detection_dict["box"]["xmax"],
39-
ymax=detection_dict["box"]["ymax"],
40-
),
41-
)
27+
model_config = ConfigDict(
28+
# Allow arbitrary types for mask, since it will be a numpy array.
29+
arbitrary_types_allowed=True
30+
)

invokeai/backend/grounded_sam/grounding_dino_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def __init__(self, pipeline: ZeroShotObjectDetectionPipeline):
1717

1818
def detect(self, image: Image.Image, candidate_labels: list[str], threshold: float = 0.1) -> list[DetectionResult]:
1919
results = self._pipeline(image=image, candidate_labels=candidate_labels, threshold=threshold)
20-
results = [DetectionResult.from_dict(result) for result in results]
20+
results = [DetectionResult.model_validate(result) for result in results]
2121
return results
2222

2323
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "GroundingDinoPipeline":

0 commit comments

Comments
 (0)