@@ -84,6 +84,34 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
         image_dto = context.images.save(image=mask_pil)
         return ImageOutput.build(image_dto)

+    @staticmethod
+    def _load_grounding_dino(model_path: Path):
+        grounding_dino_pipeline = pipeline(
+            model=str(model_path),
+            task="zero-shot-object-detection",
+            local_files_only=True,
+            # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
+            # model, and figure out how to make it work in the pipeline.
+            # torch_dtype=TorchDevice.choose_torch_dtype(),
+        )
+        assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline)
+        return GroundingDinoPipeline(grounding_dino_pipeline)
+
+    @staticmethod
+    def _load_sam_model(model_path: Path):
+        sam_model = AutoModelForMaskGeneration.from_pretrained(
+            model_path,
+            local_files_only=True,
+            # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
+            # model, and figure out how to make it work in the pipeline.
+            # torch_dtype=TorchDevice.choose_torch_dtype(),
+        )
+        assert isinstance(sam_model, SamModel)
+
+        sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
+        assert isinstance(sam_processor, SamProcessor)
+        return SegmentAnythingModel(sam_model=sam_model, sam_processor=sam_processor)
+
     def _detect(
         self,
         context: InvocationContext,
@@ -96,19 +124,9 @@ def _detect(
         # actually makes a difference.
         labels = [label if label.endswith(".") else label + "." for label in labels]

-        def load_grounding_dino(model_path: Path):
-            grounding_dino_pipeline = pipeline(
-                model=str(model_path),
-                task="zero-shot-object-detection",
-                local_files_only=True,
-                # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
-                # model, and figure out how to make it work in the pipeline.
-                # torch_dtype=TorchDevice.choose_torch_dtype(),
-            )
-            assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline)
-            return GroundingDinoPipeline(grounding_dino_pipeline)
-
-        with context.models.load_remote_model(source=GROUNDING_DINO_MODEL_ID, loader=load_grounding_dino) as detector:
+        with context.models.load_remote_model(
+            source=GROUNDING_DINO_MODEL_ID, loader=GroundedSAMInvocation._load_grounding_dino
+        ) as detector:
             assert isinstance(detector, GroundingDinoPipeline)
             return detector.detect(image=image, candidate_labels=labels, threshold=threshold)

@@ -119,26 +137,12 @@ def _segment(
         detection_results: list[DetectionResult],
     ) -> list[DetectionResult]:
         """Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
-
-        def load_sam_model(model_path: Path):
-            sam_model = AutoModelForMaskGeneration.from_pretrained(
-                model_path,
-                local_files_only=True,
-                # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
-                # model, and figure out how to make it work in the pipeline.
-                # torch_dtype=TorchDevice.choose_torch_dtype(),
-            )
-            assert isinstance(sam_model, SamModel)
-
-            sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
-            assert isinstance(sam_processor, SamProcessor)
-            return SegmentAnythingModel(sam_model=sam_model, sam_processor=sam_processor)
-
         with (
-            context.models.load_remote_model(source=SEGMENT_ANYTHING_MODEL_ID, loader=load_sam_model) as sam_pipeline,
+            context.models.load_remote_model(
+                source=SEGMENT_ANYTHING_MODEL_ID, loader=GroundedSAMInvocation._load_sam_model
+            ) as sam_pipeline,
         ):
             assert isinstance(sam_pipeline, SegmentAnythingModel)
-
             masks = sam_pipeline.segment(image=image, detection_results=detection_results)

         masks = self._to_numpy_masks(masks)
0 commit comments