Skip to content

Commit 2361b1f

Browse files
author
jakmro
committed
Refactor EfficientSAM initialization and clean up imports
1 parent cb17083 commit 2361b1f

File tree

6 files changed

+16
-19
lines changed

6 files changed

+16
-19
lines changed

examples/models/efficient_sam/efficient_sam_core/__init__.py

Lines changed: 0 additions & 7 deletions
This file was deleted.

examples/models/efficient_sam/efficient_sam_core/efficient_sam.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,16 @@
44
# This source code is licensed under the license found in the
55
# LICENSE file in the same directory.
66

7-
import math
8-
from typing import Any, List, Tuple, Type
7+
from typing import List, Tuple
98

109
import torch
1110
import torch.nn.functional as F
1211

13-
from torch import nn, Tensor
12+
from torch import nn
1413

1514
from .efficient_sam_decoder import MaskDecoder, PromptEncoder
1615
from .efficient_sam_encoder import ImageEncoderViT
17-
from .two_way_transformer import TwoWayAttentionBlock, TwoWayTransformer
16+
from .two_way_transformer import TwoWayTransformer
1817

1918

2019
class EfficientSam(nn.Module):
@@ -27,8 +26,8 @@ def __init__(
2726
prompt_encoder: PromptEncoder,
2827
decoder_max_num_input_points: int,
2928
mask_decoder: MaskDecoder,
30-
pixel_mean: List[float] = [0.485, 0.456, 0.406],
31-
pixel_std: List[float] = [0.229, 0.224, 0.225],
29+
pixel_mean: List[float] = None,
30+
pixel_std: List[float] = None,
3231
) -> None:
3332
"""
3433
SAM predicts object masks from an image and input prompts.
@@ -47,6 +46,10 @@ def __init__(
4746
self.prompt_encoder = prompt_encoder
4847
self.decoder_max_num_input_points = decoder_max_num_input_points
4948
self.mask_decoder = mask_decoder
49+
if pixel_mean is None:
50+
pixel_mean = [0.485, 0.456, 0.406]
51+
if pixel_std is None:
52+
pixel_std = [0.229, 0.224, 0.225]
5053
self.register_buffer(
5154
"pixel_mean", torch.Tensor(pixel_mean).view(1, 3, 1, 1), False
5255
)

examples/models/efficient_sam/efficient_sam_core/efficient_sam_decoder.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import numpy as np
1010
import torch
1111
import torch.nn as nn
12-
import torch.nn.functional as F
1312

1413
from .mlp import MLPBlock
1514

examples/models/efficient_sam/efficient_sam_core/efficient_sam_encoder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# LICENSE file in the same directory.
66

77
import math
8-
from typing import List, Optional, Tuple, Type
8+
from typing import List, Type
99

1010
import torch
1111
import torch.nn as nn
@@ -216,7 +216,7 @@ def __init__(
216216
num_positions = num_patches + 1
217217
self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, patch_embed_dim))
218218
self.blocks = nn.ModuleList()
219-
for i in range(depth):
219+
for _ in range(depth):
220220
vit_block = Block(patch_embed_dim, num_heads, mlp_ratio, True)
221221
self.blocks.append(vit_block)
222222
self.neck = nn.Sequential(

examples/models/efficient_sam/efficient_sam_core/two_way_transformer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import math
22
from typing import Tuple, Type
3+
34
import torch
45
from torch import nn, Tensor
6+
57
from .mlp import MLPBlock
68

79

@@ -84,7 +86,7 @@ def forward(
8486
keys = image_embedding
8587

8688
# Apply transformer blocks and final layernorm
87-
for idx, layer in enumerate(self.layers):
89+
for _, layer in enumerate(self.layers):
8890
queries, keys = layer(
8991
queries=queries,
9092
keys=keys,

examples/models/efficient_sam/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88

99
import torch
1010

11-
from .efficient_sam_core.build_efficient_sam import build_efficient_sam_vitt
12-
1311
from ..model_base import EagerModelBase
1412

13+
from .efficient_sam_core.build_efficient_sam import build_efficient_sam_vitt
14+
1515

1616
class EfficientSAM(EagerModelBase):
1717
def __init__(self):

0 commit comments

Comments
 (0)