Skip to content

Commit 7338682

Browse files
committed
Make GroundingDinoPipeline and SegmentAnythingModel subclasses of RawModel for type checking purposes.
1 parent 9f448fe commit 7338682

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

invokeai/backend/image_util/grounded_sam/grounding_dino_pipeline.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
from transformers.pipelines import ZeroShotObjectDetectionPipeline
66

77
from invokeai.backend.image_util.grounded_sam.detection_result import DetectionResult
8+
from invokeai.backend.raw_model import RawModel
89

910

10-
class GroundingDinoPipeline:
11+
class GroundingDinoPipeline(RawModel):
1112
"""A wrapper class for a ZeroShotObjectDetectionPipeline that makes it compatible with the model manager's memory
1213
management system.
1314
"""
@@ -20,14 +21,13 @@ def detect(self, image: Image.Image, candidate_labels: list[str], threshold: flo
2021
results = [DetectionResult.model_validate(result) for result in results]
2122
return results
2223

23-
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "GroundingDinoPipeline":
24+
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
2425
# HACK(ryand): The GroundingDinoPipeline does not work on MPS devices. We only allow it to be moved to CPU or
2526
# CUDA.
2627
if device is not None and device.type not in {"cpu", "cuda"}:
2728
device = None
2829
self._pipeline.model.to(device=device, dtype=dtype)
2930
self._pipeline.device = self._pipeline.model.device
30-
return self
3131

3232
def calc_size(self) -> int:
3333
# HACK(ryand): Fix the circular import issue.

invokeai/backend/image_util/grounded_sam/segment_anything_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,21 @@
66
from transformers.models.sam.processing_sam import SamProcessor
77

88
from invokeai.backend.image_util.grounded_sam.detection_result import DetectionResult
9+
from invokeai.backend.raw_model import RawModel
910

1011

11-
class SegmentAnythingModel:
12+
class SegmentAnythingModel(RawModel):
1213
"""A wrapper class for the transformers SAM model and processor that makes it compatible with the model manager."""
1314

1415
def __init__(self, sam_model: SamModel, sam_processor: SamProcessor):
1516
self._sam_model = sam_model
1617
self._sam_processor = sam_processor
1718

18-
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "SegmentAnythingModel":
19+
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
1920
# HACK(ryand): The SAM pipeline does not work on MPS devices. We only allow it to be moved to CPU or CUDA.
2021
if device is not None and device.type not in {"cpu", "cuda"}:
2122
device = None
2223
self._sam_model.to(device=device, dtype=dtype)
23-
return self
2424

2525
def calc_size(self) -> int:
2626
# HACK(ryand): Fix the circular import issue.

0 commit comments

Comments
 (0)