22
33from __future__ import annotations
44
5- from collections import OrderedDict
6-
75import numpy as np
86import torch
9- import torch .nn .functional as F # noqa: N812
10- from skimage import morphology
11- from torch import nn
12-
13- from tiatoolbox .utils import misc
14- from tiatoolbox .models .models_abc import ModelABC
15-
7+ from sam2 .automatic_mask_generator import SAM2AutomaticMaskGenerator
168from sam2 .build_sam import build_sam2 , build_sam2_hf
179from sam2 .sam2_image_predictor import SAM2ImagePredictor
18- from sam2 .automatic_mask_generator import SAM2AutomaticMaskGenerator
10+
11+ from tiatoolbox .models .models_abc import ModelABC
1912
2013
21- class SAMPrompts () :
14+ class SAMPrompts :
2215 """Structure of prompts for SAM."""
23- def __init__ (self , point_coords = None , point_labels = None , box_coords = None ):
16+
17+ def __init__ (self , point_coords = None , point_labels = None , box_coords = None ):
2418 self .point_coords = None if point_coords == [] else point_coords
2519 self .box_coords = None if box_coords == [] else box_coords
26- if ( point_coords and point_labels is None ) :
20+ if point_coords and point_labels is None :
2721 self .point_labels = [1 ] * len (point_coords )
28- else :
22+ else :
2923 self .point_labels = point_labels
3024
3125
@@ -44,15 +38,15 @@ def __init__(
4438 self .model = build_sam2_hf (model_hf_path , device = "cpu" )
4539 else :
4640 self .model = build_sam2 (model_cfg_path , checkpoint_path )
47-
41+
4842 self .predictor = SAM2ImagePredictor (self .model )
4943 self .generator = SAM2AutomaticMaskGenerator (self .model )
5044
5145 def forward (self : SAM , image : np .ndarray , prompts : SAMPrompts = None ) -> np .ndarray :
5246 """Torch method, this contains logic for using layers defined in init."""
5347 mask = self .generate_mask (self , image , prompts )
5448 return mask
55-
49+
5650 @staticmethod
5751 def infer_batch (
5852 model : torch .nn .Module ,
@@ -78,7 +72,9 @@ def infer_batch(
7872 model .eval ()
7973 model = model .to (device )
8074
81- if isinstance (batch_data , torch .Tensor ): # Move the tensor to the CPU if it's a PyTorch tensor
75+ if isinstance (
76+ batch_data , torch .Tensor
77+ ): # Move the tensor to the CPU if it's a PyTorch tensor
8278 batch_data = batch_data .to (device ).type (torch .float32 )
8379 batch_data = batch_data .cpu ().numpy ()
8480
@@ -92,7 +88,7 @@ def infer_batch(
9288 def encode_image (self , image : np .ndarray ) -> np .ndarray :
9389 """Encodes the image for feature extraction."""
9490 self .predictor .set_image (image )
95-
91+
9692 @staticmethod
9793 def generate_mask (self , features : np .ndarray , prompts : SAMPrompts ) -> np .ndarray :
9894 """Generates a segmentation mask using SAM 2, optionally guided by a prompt."""
@@ -109,24 +105,33 @@ def generate_mask(self, features: np.ndarray, prompts: SAMPrompts) -> np.ndarray
109105 scores = np .around (scores [sorted_ind ], 2 )
110106 else :
111107 masks = self .generator .generate (features )
112- scores = np .array ([mask [' predicted_iou' ] for mask in masks ])
108+ scores = np .array ([mask [" predicted_iou" ] for mask in masks ])
113109 return masks , scores
114-
110+
115111 @staticmethod
116112 def load_weights (self , checkpoint_path : str ) -> None :
117113 """Loads model weights from specified checkpoint."""
118- self .model .load_state_dict (torch .load (checkpoint_path , map_location = self .device ))
114+ self .model .load_state_dict (
115+ torch .load (checkpoint_path , map_location = self .device )
116+ )
119117
120118 @staticmethod
121119 def preproc (image : np .ndarray ) -> np .ndarray :
122120 """Pre-processes images - Converts them into a format accepted by SAM (HWC) from NCHW."""
123- if isinstance (image , torch .Tensor ): # Move the tensor to the CPU if it's a PyTorch tensor
121+ if isinstance (
122+ image , torch .Tensor
123+ ): # Move the tensor to the CPU if it's a PyTorch tensor
124124 image = image .cpu ().numpy ()
125-
125+
126126 # Handle different shapes
127- if image .ndim == 4 and image .shape == (1 ,512 ,512 ,3 ): # Case 1: (N, H, W, C)
127+ if image .ndim == 4 and image .shape == (1 , 512 , 512 , 3 ): # Case 1: (N, H, W, C)
128128 image = np .squeeze (image , axis = 0 ) # Remove batch dimension
129- elif image .ndim == 4 and image .shape == (1 ,3 ,512 ,512 ): # Case 2: (N, C, H, W)
129+ elif image .ndim == 4 and image .shape == (
130+ 1 ,
131+ 3 ,
132+ 512 ,
133+ 512 ,
134+ ): # Case 2: (N, C, H, W)
130135 image = np .squeeze (image , axis = 0 ) # Remove batch dimension
131136 image = np .transpose (image , (1 , 2 , 0 )) # (C, H, W) -> (H, W, C)
132137
@@ -136,4 +141,4 @@ def preproc(image: np.ndarray) -> np.ndarray:
136141 @staticmethod
137142 def postproc (image : np .ndarray ) -> np .ndarray :
138143 """Define the post-processing of this class of model."""
139- return image
144+ return image
0 commit comments