@@ -23,8 +23,6 @@ def __init__(
2323 checkpoint = "sam_vit_h_4b8939.pth" ,
2424 automatic = True ,
2525 device = None ,
26- erosion_kernel = None ,
27- mask_multiplier = 255 ,
2826 sam_kwargs = None ,
2927 ):
3028 """Initialize the class.
@@ -39,10 +37,6 @@ def __init__(
3937 The automatic mask generator will segment the entire image, while the input prompts will segment selected objects.
4038 device (str, optional): The device to use. It can be one of the following: cpu, cuda.
4139 Defaults to None, which will use cuda if available.
42- erosion_kernel (tuple, optional): The erosion kernel for filtering object masks and extract borders.
43- Set to None to disable it. Defaults to (3, 3).
44- mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
45- You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255.
4640 sam_kwargs (dict, optional): Optional arguments for fine-tuning the SAM model. Defaults to None.
4741 The available arguments with default values are listed below. See https://bit.ly/410RV0v for more details.
4842
@@ -96,38 +90,48 @@ def __init__(
9690 # Segment selected objects using input prompts
9791 self .predictor = SamPredictor (self .sam , ** sam_kwargs )
9892
99- # Apply the erosion filter to the mask to extract borders
100- self .erosion_kernel = erosion_kernel
101- if self .erosion_kernel is not None :
102- self .erosion_kernel = np .ones (erosion_kernel , np .uint8 )
93+ # # Apply the erosion filter to the mask to extract borders
94+ # self.erosion_kernel = erosion_kernel
95+ # if self.erosion_kernel is not None:
96+ # self.erosion_kernel = np.ones(erosion_kernel, np.uint8)
10397
104- # Rescale the binary mask to a larger range, for example, from [0, 1] to [0, 255].
105- self .mask_multiplier = mask_multiplier
98+ # # Rescale the binary mask to a larger range, for example, from [0, 1] to [0, 255].
99+ # self.mask_multiplier = mask_multiplier
106100
107- def __call__ (self , image ):
101+ def __call__ (
102+ self ,
103+ image ,
104+ foreground = True ,
105+ erosion_kernel = (3 , 3 ),
106+ mask_multiplier = 255 ,
107+ ** kwargs ,
108+ ):
108109 # Segment each image tile
109110 h , w , _ = image .shape
110111
111112 masks = self .mask_generator .generate (image )
112113
113- resulting_mask = np .ones ((h , w ), dtype = np .uint8 )
114+ if foreground : # Extract foreground objects only
115+ resulting_mask = np .zeros ((h , w ), dtype = np .uint8 )
116+ else :
117+ resulting_mask = np .ones ((h , w ), dtype = np .uint8 )
114118 resulting_borders = np .zeros ((h , w ), dtype = np .uint8 )
115119
116120 for m in masks :
117121 mask = (m ["segmentation" ] > 0 ).astype (np .uint8 )
118122 resulting_mask += mask
119123
120124 # Apply erosion to the mask
121- if self . erosion_kernel is not None :
122- mask_erode = cv2 .erode (mask , self . erosion_kernel , iterations = 1 )
125+ if erosion_kernel is not None :
126+ mask_erode = cv2 .erode (mask , erosion_kernel , iterations = 1 )
123127 mask_erode = (mask_erode > 0 ).astype (np .uint8 )
124128 edge_mask = mask - mask_erode
125129 resulting_borders += edge_mask
126130
127131 resulting_mask = (resulting_mask > 0 ).astype (np .uint8 )
128132 resulting_borders = (resulting_borders > 0 ).astype (np .uint8 )
129133 resulting_mask_with_borders = resulting_mask - resulting_borders
130- return resulting_mask_with_borders * self . mask_multiplier
134+ return resulting_mask_with_borders * mask_multiplier
131135
132136 def generate (
133137 self ,
@@ -165,7 +169,15 @@ def generate(
165169 raise ValueError (f"Input path { source } does not exist." )
166170
167171 if batch : # Subdivide the image into tiles and segment each tile
168- return tiff_to_tiff (source , output , self , ** kwargs )
172+ return tiff_to_tiff (
173+ source ,
174+ output ,
175+ self ,
176+ foreground = foreground ,
177+ erosion_kernel = erosion_kernel ,
178+ mask_multiplier = mask_multiplier ,
179+ ** kwargs ,
180+ )
169181
170182 image = cv2 .imread (source )
171183 elif isinstance (source , np .ndarray ):
0 commit comments