@@ -180,38 +180,29 @@ def predict_masks_nested(
180
180
torch .cat ([self .iou_token .weight , self .mask_tokens .weight ], dim = 0 ))
181
181
tokens = torch .cat ([output_tokens , sparse_prompt_embeddings ], dim = 2 )
182
182
183
- # TODO: remove this and make sure offsets are propagated
184
- offsets = tokens .offsets ()
185
-
186
183
src = dense_prompt_embeddings + image_embeddings .unsqueeze (1 )
187
184
pos_src = torch .zeros_like (src ) + image_pe
188
- b , c , h , w = src .values (). shape
185
+ h , w = src .shape [ - 2 :]
189
186
190
187
# Run the transformer
191
- # TODO: Run the full NTs through instead of just the buffers
192
- hs , src = self .transformer (src .values (), pos_src .values (), tokens .values ())
193
- iou_token_out = hs [:, 0 , :]
194
- mask_tokens_out = hs [:, 1 : (1 + self .num_mask_tokens ), :]
188
+ hs , src = self .transformer (src , pos_src , tokens )
189
+ iou_token_out = hs [..., 0 , :]
190
+ mask_tokens_out = hs [..., 1 : (1 + self .num_mask_tokens ), :]
195
191
196
192
# Upscale mask embeddings and predict masks using the mask tokens
197
- src = src .transpose (1 , 2 ). view ( b , c , h , w )
193
+ src = src .transpose (- 2 , - 1 ). unflatten ( - 1 , ( h , w ) )
198
194
upscaled_embedding = self .output_upscaling (src )
199
195
hyper_in_list : List [torch .Tensor ] = []
200
196
for i in range (self .num_mask_tokens ):
201
- hyper_in_list .append (self .output_hypernetworks_mlps [i ](mask_tokens_out [: , i , :]))
202
- hyper_in = torch .stack (hyper_in_list , dim = 1 )
203
- b , c , h , w = upscaled_embedding .shape
204
- masks = (hyper_in @ upscaled_embedding .view ( b , c , h * w )).view ( b , - 1 , h , w )
197
+ hyper_in_list .append (self .output_hypernetworks_mlps [i ](mask_tokens_out [... , i , :]))
198
+ hyper_in = torch .stack (hyper_in_list , dim = - 2 )
199
+ h , w = upscaled_embedding .shape [ - 2 :]
200
+ masks = (hyper_in @ upscaled_embedding .flatten ( - 2 )).unflatten ( - 1 , ( h , w ) )
205
201
206
202
# Generate mask quality predictions
207
203
iou_pred = self .iou_prediction_head (iou_token_out )
208
204
209
- # TODO: No need to create NT by hand once we propagate it properly through Transformer
210
- from torch .nested ._internal .nested_tensor import NestedTensor
211
- num_tensors = offsets .shape [0 ] - 1
212
- masks_nt = NestedTensor (masks , offsets )
213
- iou_pred_nt = NestedTensor (iou_pred , offsets )
214
- return masks_nt , iou_pred_nt
205
+ return masks , iou_pred
215
206
216
207
217
208
# Lightly adapted from
0 commit comments