Commit 71eedb2: Added docstrings
1 parent c1cb059

3 files changed (+280, -101 lines)

tests/models/test_arch_sam.py

Lines changed: 9 additions & 12 deletions
```diff
@@ -4,10 +4,10 @@
 from typing import Callable
 
 from tiatoolbox.models import SAM
-from tiatoolbox.models.architecture.sam import SAMPrompts
+from tiatoolbox.utils import env_detection as toolbox_env
 from tiatoolbox.utils import imread
 
-ON_GPU = False  # TODO: Use Environment variable to set this to True
+ON_GPU = toolbox_env.has_gpu()
 
 # Test pretrained Model =============================
 
```
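The hard-coded flag is replaced by runtime detection, resolving the old TODO: the tests now exercise the GPU path whenever one is visible. As a rough sketch, the helper amounts to a thin wrapper over torch (the actual implementation lives in tiatoolbox.utils.env_detection and may differ):

```python
import torch

def has_gpu() -> bool:
    """Report whether a CUDA-capable device is visible to torch."""
    return torch.cuda.is_available()
```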
```diff
@@ -26,17 +26,14 @@ def test_functional_sam(
     # test inference
     # create prompts
 
-    prompts1 = SAMPrompts(point_coords=[[64, 64]])
-    prompts2 = SAMPrompts(point_coords=[[64, 64]], point_labels=[1])
-    prompts3 = SAMPrompts(box_coords=[[64, 64, 128, 128]])
-    prompts4 = SAMPrompts(
-        point_coords=[[64, 64]], point_labels=[1], box_coords=[[64, 64, 128, 128]]
-    )
+    points1 = [[[64, 64]]]
+    points2 = [[[64, 64], [128, 128]]]
+    boxes1 = [[[64, 64, 128, 128]]]
 
     model = SAM()
 
     _ = model.infer_batch(model, img, on_gpu=ON_GPU)  # no prompts
-    _ = model.infer_batch(model, img, prompts=prompts1, on_gpu=ON_GPU)
-    _ = model.infer_batch(model, img, prompts=prompts2, on_gpu=ON_GPU)
-    _ = model.infer_batch(model, img, prompts=prompts3, on_gpu=ON_GPU)
-    _ = model.infer_batch(model, img, prompts=prompts4, on_gpu=ON_GPU)
+    _ = model.infer_batch(model, img, points1, on_gpu=ON_GPU)
+    _ = model.infer_batch(model, img, points2, on_gpu=ON_GPU)
+    _ = model.infer_batch(model, img, box_coords=boxes1, on_gpu=ON_GPU)
+    _ = model.infer_batch(model, img, points2, boxes1, on_gpu=ON_GPU)
```
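With the SAMPrompts wrapper gone, prompts are plain nested lists: the outer level indexes images in the batch, the inner level lists that image's prompts. A minimal sketch of the intended shapes (the img fixture is not shown in this hunk, so the array below is an illustrative assumption):

```python
import numpy as np

# Assumed stand-in for the test fixture: a batch of one 128 x 128 RGB patch.
img = np.zeros((1, 128, 128, 3), dtype=np.uint8)

points = [[[64, 64], [128, 128]]]  # one image -> two (x, y) point prompts
boxes = [[[64, 64, 128, 128]]]     # one image -> one (x0, y0, x1, y1) box prompt
assert len(points) == len(boxes) == len(img)  # outer length matches batch size
```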

tiatoolbox/models/architecture/sam.py

Lines changed: 93 additions & 12 deletions
```diff
@@ -9,6 +9,7 @@
 from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
 from sam2.build_sam import build_sam2, build_sam2_hf
 from sam2.sam2_image_predictor import SAM2ImagePredictor
+from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
 
 from tiatoolbox.models.models_abc import ModelABC
 
```
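This import brings the original segment-anything package in alongside sam2, so both backends can be built. For reference, sam_model_registry is that package's documented entry point, mapping a model-type key to a constructor (the checkpoint path below is a placeholder):

```python
from segment_anything import SamPredictor, sam_model_registry

# Keys "vit_b", "vit_l" and "vit_h" select the backbone size.
sam = sam_model_registry["vit_b"](checkpoint="path/to/sam_vit_b.pth")
predictor = SamPredictor(sam)
```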
```diff
@@ -17,33 +18,112 @@
 
 
 class SAM(ModelABC):
-    """SAM architecture."""
+    """Segment Anything Model (SAM) Architecture.
+
+    Meta AI's zero-shot segmentation model.
+    SAM is used for interactive general-purpose segmentation.
+
+    Currently supports both SAM and SAM2, each of which requires
+    different model checkpoints and configuration files.
+
+    SAM accepts an RGB image patch along with a list of point and bounding
+    box coordinates as prompts.
+
+    Args:
+        model_type (str):
+            Model type. Currently supported: vit_b, vit_l, vit_h.
+            Required for SAM.
+        checkpoint_path (str):
+            Path to the model checkpoint.
+            Required for both SAM and SAM2.
+        model_cfg_path (str):
+            Path to the model configuration file.
+            Required for SAM2.
+        model_hf_path (str):
+            Huggingface path for the pretrained SAM2 model.
+            If provided, it will override checkpoint_path and model_cfg_path.
+            Default is "facebook/sam2-hiera-tiny".
+        device (str):
+            Device to run inference on.
+        use_sam2 (bool):
+            Whether to use SAM2 or not. Default is True.
+
+    Examples:
+        >>> # instantiate SAM with checkpoint path and model type
+        >>> sam = SAM(
+        ...     model_type="vit_b",
+        ...     checkpoint_path="path/to/sam_checkpoint.pth",
+        ...     use_sam2=False
+        ... )
+        >>> # instantiate SAM2 with checkpoint and config path
+        >>> sam2 = SAM(
+        ...     checkpoint_path="path/to/sam2_checkpoint.pth",
+        ...     model_cfg_path="path/to/sam2_config.yaml"
+        ... )
+        >>> # instantiate SAM2 with Huggingface path
+        >>> sam2 = SAM(
+        ...     model_hf_path="facebook/sam2-hiera-tiny"
+        ... )
+    """
 
     def __init__(
         self: SAM,
-        model_hf_path: str | None = "facebook/sam2-hiera-tiny",
+        model_type: str | None = None,
         checkpoint_path: str | None = None,
         model_cfg_path: str | None = None,
+        model_hf_path: str = "facebook/sam2-hiera-tiny",
+        *,
+        device: str = "cpu",
+        use_sam2: bool = True,
     ) -> None:
         """Initialize :class:`SAM`."""
         super().__init__()
+        self.use_sam2 = use_sam2
         self.net_name = "SAM"
 
-        if checkpoint_path is None or model_cfg_path is None:
-            self.model = build_sam2_hf(model_hf_path, device="cpu")
+        if self.use_sam2:
+            # Load SAM2
+            if checkpoint_path is None or model_cfg_path is None:
+                self.model = build_sam2_hf(model_hf_path, device=device)
+            else:
+                self.model = build_sam2(model_cfg_path, checkpoint_path)
+            self.predictor = SAM2ImagePredictor(self.model)
+            self.generator = SAM2AutomaticMaskGenerator(self.model)
         else:
-            self.model = build_sam2(model_cfg_path, checkpoint_path)
-
-        self.predictor = SAM2ImagePredictor(self.model)
-        self.generator = SAM2AutomaticMaskGenerator(self.model)
+            # Load original SAM
+            if checkpoint_path is None:
+                msg = "You must provide a checkpoint path for SAM."
+                raise ValueError(msg)
+            self.model = sam_model_registry[model_type](checkpoint=checkpoint_path).to(
+                device
+            )
+            self.predictor = SamPredictor(self.model)
+            self.generator = SamAutomaticMaskGenerator(self.model)
 
     def forward(
         self: SAM,
         imgs: list,
         point_coords: list[list[IntPair]] | None = None,
         box_coords: list[list[IntBounds]] | None = None,
     ) -> np.ndarray:
-        """Torch method, this contains logic for using layers defined in init."""
+        """Torch method. Defines the forward pass on each image in the batch.
+
+        Note: This architecture only uses a single layer, so only one forward pass
+        is needed.
+
+        Args:
+            imgs (list):
+                List of images to process, of the shape NHWC.
+            point_coords (list):
+                List of point coordinates for each image.
+            box_coords (list):
+                List of bounding box coordinates for each image.
+
+        Returns:
+            list:
+                List of masks and scores for each image.
+
+        """
         batch_masks, batch_scores = [], []
 
         for i, image in enumerate(imgs):
```
96176
batch_data (list):
97177
A batch of data generated by
98178
`torch.utils.data.DataLoader`.
99-
prompts (SAMPrompts):
100-
Prompts for SAM model.
179+
point_coords (list):
180+
Point coordinates for each image in the batch.
181+
box_coords (list):
182+
Bounding box coordinates for each image in the batch.
101183
device (str):
102184
Device to run inference on.
103185
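Matching the test updates, infer_batch now takes the two prompt lists directly instead of a SAMPrompts object. A hypothetical sketch of the dispatch (the committed body is not shown in this diff; the docstring documents a device argument while the tests pass on_gpu=, so device selection is simplified here):

```python
import torch

def infer_batch(model, batch_data, point_coords=None, box_coords=None, *, on_gpu=False):
    # Defined as a @staticmethod on SAM in the real module.
    model.to("cuda" if on_gpu else "cpu")
    model.eval()
    with torch.inference_mode():
        return model(batch_data, point_coords, box_coords)
```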
```diff
@@ -115,7 +197,6 @@ def _encode_image(self: SAM, image: np.ndarray) -> np.ndarray:
         """Encodes the image for feature extraction."""
         self.predictor.set_image(image)
 
-    @staticmethod
     def load_weights(self: SAM, checkpoint_path: str) -> None:
         """Loads model weights from specified checkpoint."""
         self.model.load_state_dict(
```

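Dropping @staticmethod here is a correctness fix rather than cleanup: the method declares self and reads self.model, so the decorator never matched the body. The load_state_dict call is truncated in this view; a typical completion might look like the following (the torch.load arguments are an assumption, not part of the diff):

```python
import torch

def load_weights(self, checkpoint_path: str) -> None:
    """Loads model weights from specified checkpoint."""
    self.model.load_state_dict(
        torch.load(checkpoint_path, map_location="cpu"),  # assumed completion
    )
```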