@@ -13,8 +13,8 @@
 
 from detectron2.structures import ImageList, Instances
 from segment_anything import (  # , Instances
-    SamAutomaticMaskGenerator,
-    SamPredictor,
+    # SamAutomaticMaskGenerator,
+    # SamPredictor,
     sam_model_registry,
 )
 from torch.nn import functional as F
@@ -50,6 +50,7 @@ def __repr__(self):
 
 
 def mask_to_bbx(mask):
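+    # np.array() cannot convert a CUDA tensor, so move the mask to CPU first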
+    mask = mask.cpu()
     mask = np.array(mask)
     mask = np.squeeze(mask)
     h, w = mask.shape[-2:]
@@ -81,17 +82,26 @@ class SAM(BaseWrapper):
     def __init__(self, device: str, **kwargs):
         super().__init__(device)
 
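+        # resolve cfg/weights paths; "default" falls back to the module-level root_path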
+        _path_prefix = (
+            f"{root_path}"
+            if kwargs["model_path_prefix"] == "default"
+            else kwargs["model_path_prefix"]
+        )
+        self.model_info = {
+            "cfg": f"{_path_prefix}/{kwargs['cfg']}",
+            "weights": f"{_path_prefix}/{kwargs['weights']}",
+        }
+
         self.model = (
-            sam_model_registry["vit_h"](checkpoint=kwargs["weights"]).to(device).eval()
+            sam_model_registry["vit_h"](checkpoint=self.model_info["weights"])
+            .to(device)
+            .eval()
         )
-        self.model.load_state_dict(torch.load(kwargs["weights"]))
 
-        self.backbone = self.model.image_encoder
+        self.image_encoder = self.model.image_encoder
         self.prompt_encoder = self.model.prompt_encoder
         self.head = self.model.mask_decoder
 
-        # SamPredictor(self.model)
-        # print(SamPredictor)
         self.supported_split_points = Split_Points
 
         assert "splits" in kwargs, "Split layer ids must be provided"
@@ -106,18 +116,31 @@ def __init__(self, device: str, **kwargs):
             zip(self.split_layer_list, [None] * len(self.split_layer_list))
         )
 
-        self.annotation_file = "/o/projects/proj-river/ctc_sequences/vcm_testdata/samtest/annotations/mpeg-oiv6-segmentation-coco_fortest.json"
-
     @property
     def SPLIT_IMGENC(self):
         return str(self.supported_split_points.ImageEncoder)
 
-    def input_to_features(self, x, device: str) -> Dict:
+    @staticmethod
+    def prompt_inputs(file_name):
+        # [TODO] should be improved...
+        prompt_link = file_name.replace("/images/", "/prompts/").replace(".jpg", ".txt")
+
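+        # assumed layout: one whitespace-separated line per file; the first two
+        # tokens are an (x, y) point prompt and the last token is the class id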
+        with open(prompt_link, "r") as f:
+            line = f.readline()
+            parts = line.strip().split()
+            prompts = list(map(int, parts[:2]))
+            object_classes = [int(parts[-1])]
+
+        return prompts, object_classes
+
+    def input_to_features(self, x: list, device: str) -> Dict:
         """Computes deep features at the intermediate layer(s) all the way from the input"""
         self.model = self.model.to(device).eval()
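+        # only single-image batches are supported for now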
+        assert isinstance(x, list) and len(x) == 1
 
         if self.split_id == self.SPLIT_IMGENC:
-            return self._input_to_image_encoder(x)
+            return self._input_to_image_encoder(x, device)
         else:
             self.logger.error(f"Not supported split point {self.split_id}")
 
@@ -129,48 +152,37 @@ def features_to_output(self, x: Dict, device: str):
         self.model = self.model.to(device).eval()
 
         if self.split_id == self.SPLIT_IMGENC:
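+            # point prompts are looked up from a side file derived from the
+            # image's file name (see prompt_inputs above)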
+            assert "file_name" in x
+
+            prompts, object_classes = self.prompt_inputs(x["file_name"])
+
             return self._image_encoder_to_output(
                 x["data"],
                 x["org_input_size"],
                 x["input_size"],
-                x["prompts"],
-                x["object_classes"],
+                prompts,
+                object_classes,
+                device,
             )
         else:
             self.logger.error(f"Not supported split point {self.split_id}")
 
         raise NotImplementedError
 
     @torch.no_grad()
-    def _input_to_image_encoder(self, x):
+    def _input_to_image_encoder(self, x, device):
         """Computes and returns the encoded image all the way from the input"""
-        # TODO pre_processing
-        # print("AAAAA _input_to_image_encoder", x, '\n')
-        # imgs = ImageList(x)
-        imgs = x[0]["image"]
-        feature = {}
-        feature["backbone"] = self.backbone(imgs)
-
-        prompt_link = (
-            x[0]["file_name"].replace("/images/", "/prompts/").replace(".jpg", ".txt")
-        )
-        # print("AAAAA prompt_link", prompt_link)
-
-        with open(prompt_link, "r") as f:
-            line = f.readline()
-            # first_two = list(map(int, line.strip().split()[:2]))
-            parts = line.strip().split()
-            prompts = list(map(int, parts[:2]))
-            object_classes = [int(line.strip().split()[-1])]
+        assert len(x) == 1
 
-        image_sizes = [x[0]["height"], x[0]["width"]]
-        # print("AAAAA image_sizes", image_sizes, int(image_sizes[0]) * int(image_sizes[1])),
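+        # keep the pre-padding (H, W) so the padded mask can be cropped back
+        # later; SAM's preprocess normalizes pixels and pads the image square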
+        img = x[0]["image"].to(device)
+        input_size = list(img.size()[2:])
+        feature = {}
+        input_img = self.model.preprocess(img)
+        feature["backbone"] = self.image_encoder(input_img)
 
         return {
             "data": feature,
-            "input_size": image_sizes,
-            "prompts": prompts,
-            "object_classes": object_classes,
+            "input_size": input_size,
         }
 
     @torch.no_grad()
@@ -181,45 +193,6 @@ def get_input_size(self, x):
         image_sizes = [x[0]["height"], x[0]["width"]]
         return image_sizes  # [1024, 1024]
 
-    @torch.no_grad()
-    def get_prompts(self, x):
-        """Computes prompts"""
-        prompt_link = (
-            x[0]["file_name"].replace("/images/", "/prompts/").replace(".jpg", ".txt")
-        )
-        # print("AAAAA prompt_link", prompt_link)
-
-        with open(prompt_link, "r") as f:
-            line = f.readline()
-            # first_two = list(map(int, line.strip().split()[:2]))
-            parts = line.strip().split()
-            prompts = list(map(int, parts[:2]))
-            object_classes = [int(line.strip().split()[-1])]
-
-        image_sizes = [x[0]["height"], x[0]["width"]]
-        # print("AAAAA image_sizes", image_sizes, int(image_sizes[0]) * int(image_sizes[1])),
-
-        return prompts
-
-    @torch.no_grad()
-    def get_object_classes(self, x):
-        """Computes input image size to the network"""
-        prompt_link = (
-            x[0]["file_name"].replace("/images/", "/prompts/").replace(".jpg", ".txt")
-        )
-        # print("AAAAA prompt_link", prompt_link)
-
-        with open(prompt_link, "r") as f:
-            line = f.readline()
-            # first_two = list(map(int, line.strip().split()[:2]))
-            parts = line.strip().split()
-            prompts = list(map(int, parts[:2]))
-            object_classes = [int(line.strip().split()[-1])]
-
-        image_sizes = [x[0]["height"], x[0]["width"]]
-        # print("AAAAA image_sizes", image_sizes, int(image_sizes[0]) * int(image_sizes[1])),
-        return object_classes
-
     @torch.no_grad()
     def _image_encoder_to_output(
         self,
@@ -228,6 +201,7 @@ def _image_encoder_to_output(
         input_img_size: List,
         prompts: List,
         object_classes: List,
+        device,
     ):
         """
         Performs the downstream task using the encoded image feature
@@ -237,7 +211,7 @@ def _image_encoder_to_output(
 
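+        # reshape the single (x, y) prompt into the (B, N, 2) point / (B, N)
+        # label layout that SAM's prompt encoder expects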
         input_points = [prompts]  # [[469, 295]]  # prompts["points"]
         input_points = np.array(input_points)
-        input_points_ = torch.tensor(input_points)
+        input_points_ = torch.tensor(input_points, device=device)
         input_points_ = input_points_.unsqueeze(-1)
         input_points_ = input_points_.permute(2, 0, 1)
 
@@ -246,7 +220,7 @@
         input_labels_ = input_labels_.unsqueeze(-1)
         input_labels_ = input_labels_.permute(1, 0)
 
-        points = (torch.tensor(input_points_), torch.tensor(input_labels_))
+        points = (input_points_, torch.tensor(input_labels_, device=device))
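+        # the prompt encoder returns a (sparse, dense) embedding pair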
         prompt_feature = self.prompt_encoder(points=points, boxes=None, masks=None)
         image_pe = self.prompt_encoder.get_dense_pe()
 
@@ -261,7 +235,7 @@
         # post process mask
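+        # upsample the low-res mask logits to the padded encoder resolution,
+        # crop away the padding, then resize to the original image size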
         masks = F.interpolate(
             low_res_masks,
-            (1024, 1024),
+            (self.image_encoder.img_size, self.image_encoder.img_size),
             mode="bilinear",
             align_corners=False,
         )
@@ -270,7 +244,7 @@
         ]  # [..., : 793, : 1024]
         masks = F.interpolate(
             masks,
-            (input_img_size[0], input_img_size[1]),
+            (org_img_size["height"], org_img_size["width"]),
             mode="bilinear",
             align_corners=False,
         )
@@ -314,14 +288,26 @@ def _image_encoder_to_output(
     def forward(self, x):
         """Completes the downstream task in an end-to-end manner, all the way from the input"""
         # test
-        enc = self._input_to_image_encoder(self, x)
-        dec = self._image_encoder_to_output(enc)
+        enc_res = self._input_to_image_encoder([x], self.device)
 
-        return dec
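+        # re-key the raw features with the configured split-layer names and
+        # move each tensor onto the wrapper's device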
+        # (assumes dict ordering matches the split-layer list)
+        enc_res["data"] = {
+            k: v.to(device=self.device)
+            for k, v in zip(self.split_layer_list, enc_res["data"].values())
+        }
+
+        prompts, object_classes = self.prompt_inputs(x["file_name"])
+
+        dec_res = self._image_encoder_to_output(
+            enc_res["data"],
+            {"height": x["height"], "width": x["width"]},
+            enc_res["input_size"],
+            prompts,
+            object_classes,
+            device=self.device,
+        )
 
-    # @property
-    # def cfg(self):
-    #     return self._cfg
+        return dec_res
 
 
 @register_vision_model("sam_vit_h_4b8939")