
Commit a60de07

Use jagged layout NT through the MaskDecoder (#45)
1 parent 92efc1d commit a60de07

3 files changed: +18 -30 lines changed
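
For context: a jagged-layout NestedTensor (NT) batches tensors that agree on every
dimension except one ragged dimension, which here is the per-image number of prompts.
The sketch below is not code from this repo; it only illustrates what such a tensor
looks like, assuming a recent PyTorch build with torch.nested jagged-layout support
and invented shapes.

    import torch

    # Two images with 3 and 5 prompt embeddings respectively (256-dim each).
    per_image = [torch.randn(3, 256), torch.randn(5, 256)]
    nt = torch.nested.nested_tensor(per_image, layout=torch.jagged)

    # dim 0 is the image batch, dim 1 is ragged (3 vs. 5), dim 2 is the embedding.
    print(nt.shape)  # prints a symbolic size along the lines of torch.Size([2, j1, 256])

The changes below let a tensor of this kind flow through the MaskDecoder end to end,
instead of being unpacked into its dense buffer and re-wrapped by hand at the end.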

segment_anything_fast/modeling/common.py

Lines changed: 2 additions & 2 deletions
@@ -36,8 +36,8 @@ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
         self.eps = eps
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        u = x.mean(1, keepdim=True)
-        s = (x - u).pow(2).mean(1, keepdim=True)
+        u = x.mean(-3, keepdim=True)
+        s = (x - u).pow(2).mean(-3, keepdim=True)
         x = (x - u) / torch.sqrt(s + self.eps)
         x = self.weight[:, None, None] * x + self.bias[:, None, None]
         return x
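
The only change in LayerNorm2d is moving the channel reduction from dim 1 to dim -3,
i.e. counting the channel dimension from the end. A short sketch (not repo code; plain
dense tensors with invented shapes) of why that is equivalent for the usual
B x C x H x W input and stays correct once an extra leading, possibly ragged,
prompt dimension is present:

    import torch

    # 4-D case: dim 1 and dim -3 name the same (channel) dimension.
    x4 = torch.randn(2, 8, 16, 16)  # B x C x H x W
    assert torch.allclose(x4.mean(1, keepdim=True), x4.mean(-3, keepdim=True))

    # 5-D case (extra leading prompt dim): dim -3 is still the channel dim,
    # whereas dim 1 would now wrongly average over prompts.
    x5 = torch.randn(2, 3, 8, 16, 16)  # B x N_prompts x C x H x W
    u = x5.mean(-3, keepdim=True)
    assert u.shape == (2, 3, 1, 16, 16)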

segment_anything_fast/modeling/mask_decoder.py

Lines changed: 10 additions & 19 deletions
@@ -180,38 +180,29 @@ def predict_masks_nested(
             torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0))
         tokens = torch.cat([output_tokens, sparse_prompt_embeddings], dim=2)
 
-        # TODO: remove this and make sure offsets are propagated
-        offsets = tokens.offsets()
-
         src = dense_prompt_embeddings + image_embeddings.unsqueeze(1)
         pos_src = torch.zeros_like(src) + image_pe
-        b, c, h, w = src.values().shape
+        h, w = src.shape[-2:]
 
         # Run the transformer
-        # TODO: Run the full NTs through instead of just the buffers
-        hs, src = self.transformer(src.values(), pos_src.values(), tokens.values())
-        iou_token_out = hs[:, 0, :]
-        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
+        hs, src = self.transformer(src, pos_src, tokens)
+        iou_token_out = hs[..., 0, :]
+        mask_tokens_out = hs[..., 1 : (1 + self.num_mask_tokens), :]
 
         # Upscale mask embeddings and predict masks using the mask tokens
-        src = src.transpose(1, 2).view(b, c, h, w)
+        src = src.transpose(-2, -1).unflatten(-1, (h, w))
         upscaled_embedding = self.output_upscaling(src)
         hyper_in_list: List[torch.Tensor] = []
         for i in range(self.num_mask_tokens):
-            hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
-        hyper_in = torch.stack(hyper_in_list, dim=1)
-        b, c, h, w = upscaled_embedding.shape
-        masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
+            hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[..., i, :]))
+        hyper_in = torch.stack(hyper_in_list, dim=-2)
+        h, w = upscaled_embedding.shape[-2:]
+        masks = (hyper_in @ upscaled_embedding.flatten(-2)).unflatten(-1, (h, w))
 
         # Generate mask quality predictions
         iou_pred = self.iou_prediction_head(iou_token_out)
 
-        # TODO: No need to create NT by hand once we propagate it properly through Transformer
-        from torch.nested._internal.nested_tensor import NestedTensor
-        num_tensors = offsets.shape[0] - 1
-        masks_nt = NestedTensor(masks, offsets)
-        iou_pred_nt = NestedTensor(iou_pred, offsets)
-        return masks_nt, iou_pred_nt
+        return masks, iou_pred
 
 
 # Lightly adapted from
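
This is the core of the change: tokens, src and pos_src now stay nested through the
transformer, the manual offsets round-trip via .values() and a hand-built NestedTensor
is gone, and each shape-sensitive op is rewritten with ellipsis or negative dims so it
does not care how many leading batch dims precede it. A sketch of that indexing pattern
on dense tensors (not repo code; shapes invented for illustration):

    import torch

    # Ellipsis indexing selects the same slice with or without extra leading dims.
    hs = torch.randn(2, 7, 256)      # B x N_tokens x C
    assert torch.equal(hs[:, 0, :], hs[..., 0, :])

    hs5 = torch.randn(2, 3, 7, 256)  # B x N_prompts x N_tokens x C
    iou_token_out = hs5[..., 0, :]   # the same expression handles the extra dim
    mask_tokens_out = hs5[..., 1:5, :]
    assert iou_token_out.shape == (2, 3, 256)

    # transpose(-2, -1) + unflatten(-1, (h, w)) replaces the old
    # transpose(1, 2).view(b, c, h, w) and also tolerates leading dims.
    src = torch.randn(2, 3, 64 * 64, 256)                # ... x HW x C
    src = src.transpose(-2, -1).unflatten(-1, (64, 64))  # ... x C x H x W
    assert src.shape == (2, 3, 256, 64, 64)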

segment_anything_fast/modeling/transformer.py

Lines changed: 6 additions & 9 deletions
@@ -79,9 +79,8 @@ def forward(
           torch.Tensor: the processed image_embedding
         """
         # BxCxHxW -> BxHWxC == B x N_image_tokens x C
-        bs, c, h, w = image_embedding.shape
-        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
-        image_pe = image_pe.flatten(2).permute(0, 2, 1)
+        image_embedding = image_embedding.flatten(-2).transpose(-1, -2)
+        image_pe = image_pe.flatten(-2).transpose(-1, -2)
 
         # Prepare queries
         queries = point_embedding
@@ -206,14 +205,12 @@ def __init__(
         self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
 
     def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
-        b, n, c = x.shape
-        x = x.reshape(b, n, num_heads, c // num_heads)
-        return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head
+        x = x.unflatten(-1, (num_heads, -1))
+        return x.transpose(-3, -2)  # B... x N_heads x N_tokens x C_per_head
 
     def _recombine_heads(self, x: Tensor) -> Tensor:
-        b, n_heads, n_tokens, c_per_head = x.shape
-        x = x.transpose(1, 2)
-        return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C
+        x = x.transpose(-3, -2)
+        return x.flatten(-2)
 
     def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
         # Input projections
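
With the head split and merge written via unflatten/flatten and negative dims, the same
two helpers serve both the dense path and inputs carrying extra leading batch dims.
A small round-trip sketch of the new formulation on a dense tensor (not repo code;
shapes invented):

    import torch

    def separate_heads(x, num_heads):
        # ... x N_tokens x C -> ... x N_heads x N_tokens x C_per_head
        return x.unflatten(-1, (num_heads, -1)).transpose(-3, -2)

    def recombine_heads(x):
        # ... x N_heads x N_tokens x C_per_head -> ... x N_tokens x C
        return x.transpose(-3, -2).flatten(-2)

    x = torch.randn(2, 7, 256)             # B x N_tokens x C
    heads = separate_heads(x, num_heads=8)
    assert heads.shape == (2, 8, 7, 32)
    assert torch.equal(recombine_heads(heads), x)  # exact round trip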
