99from sam2 .automatic_mask_generator import SAM2AutomaticMaskGenerator
1010from sam2 .build_sam import build_sam2 , build_sam2_hf
1111from sam2 .sam2_image_predictor import SAM2ImagePredictor
12+ from segment_anything import SamAutomaticMaskGenerator , SamPredictor , sam_model_registry
1213
1314from tiatoolbox .models .models_abc import ModelABC
1415
1718
1819
1920class SAM (ModelABC ):
20- """SAM architecture."""
21+ """Segment Anything Model (SAM) Architecture.
22+
23+ Meta AI's zero-shot segmentation model.
24+ SAM is used for interactive general-purpose segmentation.
25+
26+ Currently supports both SAM and SAM2, each of which requires
27+ different model checkpoints and configuration files.
28+
29+ SAM accepts an RGB image patch along with a list of point and bounding
30+ box coordinates as prompts.
31+
32+ Args:
33+ model_type (str):
34+ Model type. Currently supported: vit_b, vit_l, vit_h.
35+ Required for SAM.
36+ checkpoint_path (str):
37+ Path to the model checkpoint.
38+ Required for both SAM and SAM2.
39+ model_cfg_path (str):
40+ Path to the model configuration file.
41+ Required for SAM2.
42+ model_hf_path (str):
43+ Huggingface path for the pretrained SAM2 model.
44+ Used only when checkpoint_path or model_cfg_path is not provided.
45+ Default is "facebook/sam2-hiera-tiny".
46+ device (str):
47+ Device to run inference on.
48+ use_sam2 (bool):
49+ Whether to use SAM2 or not. Default is True.
50+
51+ Examples:
52+ >>> # instantiate SAM with checkpoint path and model type
53+ >>> sam = SAM(
54+ ... model_type="vit_b",
55+ ... checkpoint_path="path/to/sam_checkpoint.pth",
56+ ... use_sam2=False
57+ ... )
58+ >>> # instantiate SAM2 with checkpoint and config path
59+ >>> sam2 = SAM(
60+ ... checkpoint_path="path/to/sam2_checkpoint.pth",
61+ ... model_cfg_path="path/to/sam2_config.yaml"
62+ ... )
63+ >>> # instantiate SAM2 with Huggingface path
64+ >>> sam2 = SAM(
65+ ... model_hf_path="facebook/sam2-hiera-tiny"
66+ ... )
67+ """
2168
2269 def __init__ (
2370 self : SAM ,
24- model_hf_path : str | None = "facebook/sam2-hiera-tiny" ,
71+ model_type : str | None = None ,
2572 checkpoint_path : str | None = None ,
2673 model_cfg_path : str | None = None ,
74+ model_hf_path : str = "facebook/sam2-hiera-tiny" ,
75+ * ,
76+ device : str = "cpu" ,
77+ use_sam2 : bool = True ,
2778 ) -> None :
2879 """Initialize :class:`SAM`."""
2980 super ().__init__ ()
81+ self .use_sam2 = use_sam2
3082 self .net_name = "SAM"
3183
32- if checkpoint_path is None or model_cfg_path is None :
33- self .model = build_sam2_hf (model_hf_path , device = "cpu" )
84+ if self .use_sam2 :
85+ # Load SAM2
86+ if checkpoint_path is None or model_cfg_path is None :
87+ self .model = build_sam2_hf (model_hf_path , device = device )
88+ else :
89+ self .model = build_sam2 (model_cfg_path , checkpoint_path )
90+ self .predictor = SAM2ImagePredictor (self .model )
91+ self .generator = SAM2AutomaticMaskGenerator (self .model )
3492 else :
35- self .model = build_sam2 (model_cfg_path , checkpoint_path )
36-
37- self .predictor = SAM2ImagePredictor (self .model )
38- self .generator = SAM2AutomaticMaskGenerator (self .model )
93+ # Load original SAM
94+ if checkpoint_path is None :
95+ msg = "You must provide a checkpoint path for SAM."
96+ raise ValueError (msg )
97+ self .model = sam_model_registry [model_type ](checkpoint = checkpoint_path ).to (
98+ device
99+ )
100+ self .predictor = SamPredictor (self .model )
101+ self .generator = SamAutomaticMaskGenerator (self .model )
39102
40103 def forward (
41104 self : SAM ,
42105 imgs : list ,
43106 point_coords : list [list [IntPair ]] | None = None ,
44107 box_coords : list [list [IntBounds ]] | None = None ,
45108 ) -> np .ndarray :
46- """Torch method, this contains logic for using layers defined in init."""
109+ """Torch method. Defines forward pass on each image in the batch.
110+
111+ Note: This architecture only uses a single layer, so only one forward pass
112+ is needed.
113+
114+ Args:
115+ imgs (list):
116+ List of images to process, of the shape NHWC.
117+ point_coords (list):
118+ List of point coordinates for each image.
119+ box_coords (list):
120+ List of bounding box coordinates for each image.
121+
122+ Returns:
123+ list:
124+ List of masks and scores for each image.
125+
126+ """
47127 batch_masks , batch_scores = [], []
48128
49129 for i , image in enumerate (imgs ):
@@ -96,8 +176,10 @@ def infer_batch(
96176 batch_data (list):
97177 A batch of data generated by
98178 `torch.utils.data.DataLoader`.
99- prompts (SAMPrompts):
100- Prompts for SAM model.
179+ point_coords (list):
180+ Point coordinates for each image in the batch.
181+ box_coords (list):
182+ Bounding box coordinates for each image in the batch.
101183 device (str):
102184 Device to run inference on.
103185
@@ -115,7 +197,6 @@ def _encode_image(self: SAM, image: np.ndarray) -> np.ndarray:
115197 """Encodes the image for feature extraction."""
116198 self .predictor .set_image (image )
117199
118- @staticmethod
119200 def load_weights (self : SAM , checkpoint_path : str ) -> None :
120201 """Loads model weights from specified checkpoint."""
121202 self .model .load_state_dict (
0 commit comments