Commit e206890

Use staticmethods rather than inner functions for the Grounding DINO and SAM model loaders.
1 parent 0a7048f commit e206890

1 file changed: +34 -30 lines

invokeai/app/invocations/grounded_sam.py

Lines changed: 34 additions & 30 deletions
@@ -84,6 +84,34 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
         image_dto = context.images.save(image=mask_pil)
         return ImageOutput.build(image_dto)
 
+    @staticmethod
+    def _load_grounding_dino(model_path: Path):
+        grounding_dino_pipeline = pipeline(
+            model=str(model_path),
+            task="zero-shot-object-detection",
+            local_files_only=True,
+            # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
+            # model, and figure out how to make it work in the pipeline.
+            # torch_dtype=TorchDevice.choose_torch_dtype(),
+        )
+        assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline)
+        return GroundingDinoPipeline(grounding_dino_pipeline)
+
+    @staticmethod
+    def _load_sam_model(model_path: Path):
+        sam_model = AutoModelForMaskGeneration.from_pretrained(
+            model_path,
+            local_files_only=True,
+            # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
+            # model, and figure out how to make it work in the pipeline.
+            # torch_dtype=TorchDevice.choose_torch_dtype(),
+        )
+        assert isinstance(sam_model, SamModel)
+
+        sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
+        assert isinstance(sam_processor, SamProcessor)
+        return SegmentAnythingModel(sam_model=sam_model, sam_processor=sam_processor)
+
     def _detect(
         self,
         context: InvocationContext,
@@ -96,19 +124,9 @@ def _detect(
         # actually makes a difference.
         labels = [label if label.endswith(".") else label + "." for label in labels]
 
-        def load_grounding_dino(model_path: Path):
-            grounding_dino_pipeline = pipeline(
-                model=str(model_path),
-                task="zero-shot-object-detection",
-                local_files_only=True,
-                # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
-                # model, and figure out how to make it work in the pipeline.
-                # torch_dtype=TorchDevice.choose_torch_dtype(),
-            )
-            assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline)
-            return GroundingDinoPipeline(grounding_dino_pipeline)
-
-        with context.models.load_remote_model(source=GROUNDING_DINO_MODEL_ID, loader=load_grounding_dino) as detector:
+        with context.models.load_remote_model(
+            source=GROUNDING_DINO_MODEL_ID, loader=GroundedSAMInvocation._load_grounding_dino
+        ) as detector:
             assert isinstance(detector, GroundingDinoPipeline)
             return detector.detect(image=image, candidate_labels=labels, threshold=threshold)
 
@@ -119,26 +137,12 @@ def _segment(
         detection_results: list[DetectionResult],
     ) -> list[DetectionResult]:
         """Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
-
-        def load_sam_model(model_path: Path):
-            sam_model = AutoModelForMaskGeneration.from_pretrained(
-                model_path,
-                local_files_only=True,
-                # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
-                # model, and figure out how to make it work in the pipeline.
-                # torch_dtype=TorchDevice.choose_torch_dtype(),
-            )
-            assert isinstance(sam_model, SamModel)
-
-            sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
-            assert isinstance(sam_processor, SamProcessor)
-            return SegmentAnythingModel(sam_model=sam_model, sam_processor=sam_processor)
-
         with (
-            context.models.load_remote_model(source=SEGMENT_ANYTHING_MODEL_ID, loader=load_sam_model) as sam_pipeline,
+            context.models.load_remote_model(
+                source=SEGMENT_ANYTHING_MODEL_ID, loader=GroundedSAMInvocation._load_sam_model
+            ) as sam_pipeline,
         ):
             assert isinstance(sam_pipeline, SegmentAnythingModel)
-
             masks = sam_pipeline.segment(image=image, detection_results=detection_results)
 
             masks = self._to_numpy_masks(masks)
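For context on the refactor, below is a minimal, self-contained sketch of the resulting pattern: the model loader is a stable callable defined as a @staticmethod on the invocation class and handed to the model-loading context manager, rather than a closure re-defined inside every _detect()/_segment() call. The load_remote_model stand-in, the toy GroundedSAMInvocation body, and the model id are simplified placeholders for illustration only, not the real InvokeAI API or constants.

from contextlib import contextmanager
from pathlib import Path
from typing import Callable, Iterator


@contextmanager
def load_remote_model(source: str, loader: Callable[[Path], object]) -> Iterator[object]:
    # Simplified stand-in for context.models.load_remote_model(): a real
    # implementation would fetch `source` to a local path, invoke `loader`
    # on that path, and manage the loaded model's lifetime.
    model_path = Path("/tmp/models") / source.replace("/", "_")
    yield loader(model_path)


GROUNDING_DINO_MODEL_ID = "example-org/grounding-dino"  # placeholder id, not the real constant


class GroundedSAMInvocation:
    @staticmethod
    def _load_grounding_dino(model_path: Path) -> object:
        # Stand-in loader; the real method builds a transformers
        # zero-shot-object-detection pipeline from `model_path`.
        return f"grounding-dino loaded from {model_path}"

    def _detect(self) -> object:
        # The loader is referenced as a class attribute, so the same callable
        # is passed on every call instead of a freshly created inner function.
        with load_remote_model(
            source=GROUNDING_DINO_MODEL_ID,
            loader=GroundedSAMInvocation._load_grounding_dino,
        ) as detector:
            return detector


print(GroundedSAMInvocation()._detect())

Beyond dropping the duplicated inner-function definitions, a staticmethod gives each loader a single stable identity; whether the surrounding loading machinery relies on that (for example, caching keyed on the loader callable) is not shown in this diff.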

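As a side note on the repeated TODO(ryand) comments: torch_dtype is the standard transformers keyword for loading weights in reduced precision via from_pretrained() or pipeline(). Purely as a reference for what the commented-out line is aiming at (the commit leaves it disabled, and fp16 support for these checkpoints is exactly what the TODO leaves open), a hedged sketch using an example checkpoint id:

import torch
from transformers import AutoModelForMaskGeneration

# Illustration only: the checkpoint id stands in for the elided SEGMENT_ANYTHING_MODEL_ID.
sam_model = AutoModelForMaskGeneration.from_pretrained(
    "facebook/sam-vit-base",
    torch_dtype=torch.float16,  # request half-precision weights at load time (assumes fp16 support)
)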