Skip to content

Commit 331bc73

Browse files
authored
Fixed the batch processing bug (#29)
1 parent c0ce782 commit 331bc73

File tree

4 files changed

+50
-27
lines changed

4 files changed

+50
-27
lines changed

docs/examples/satellite-predictor.ipynb

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
"cells": [
33
{
44
"cell_type": "markdown",
5-
"metadata": {
6-
},
5+
"metadata": {},
76
"source": [
87
"# Segment Anything Model for Geospatial Data \n",
98
"\n",
@@ -18,8 +17,7 @@
1817
},
1918
{
2019
"cell_type": "markdown",
21-
"metadata": {
22-
},
20+
"metadata": {},
2321
"source": [
2422
"## Install dependencies\n",
2523
"\n",
@@ -318,7 +316,7 @@
318316
"name": "python",
319317
"nbconvert_exporter": "python",
320318
"pygments_lexer": "ipython3",
321-
"version": "3.11.3"
319+
"version": "3.9.16"
322320
}
323321
},
324322
"nbformat": 4,

docs/examples/satellite.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,6 @@
210210
"sam = SamGeo(\n",
211211
" model_type='vit_h',\n",
212212
" checkpoint=checkpoint,\n",
213-
" erosion_kernel=(3, 3),\n",
214-
" mask_multiplier=255,\n",
215213
" sam_kwargs=None,\n",
216214
")"
217215
]
@@ -232,7 +230,9 @@
232230
"outputs": [],
233231
"source": [
234232
"mask = 'segment.tiff'\n",
235-
"sam.generate(image, mask, batch=True)"
233+
"sam.generate(\n",
234+
" image, mask, batch=True, foreground=True, erosion_kernel=(3, 3), mask_multiplier=255\n",
235+
")"
236236
]
237237
},
238238
{

samgeo/common.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -701,6 +701,10 @@ def tiff_to_tiff(
701701
sample_size=(512, 512),
702702
sample_resize=None,
703703
bound=128,
704+
foreground=True,
705+
erosion_kernel=(3, 3),
706+
mask_multiplier=255,
707+
**kwargs,
704708
):
705709
with rasterio.open(src_fp) as src:
706710
profile = src.profile
@@ -720,6 +724,9 @@ def tiff_to_tiff(
720724
profile["count"] = 1
721725
profile["dtype"] = "uint8"
722726

727+
if erosion_kernel is not None:
728+
erosion_kernel = np.ones(erosion_kernel, np.uint8)
729+
723730
with rasterio.open(dst_fp, "w", **profile) as dst:
724731
for b in tqdm(sample_grid):
725732
# Read each tile from the source
@@ -733,7 +740,13 @@ def tiff_to_tiff(
733740
)
734741

735742
# Run the model, call the __call__ method of SamGeo class
736-
uin8_out = func(uint8_rgb_in)
743+
uin8_out = func(
744+
uint8_rgb_in,
745+
foreground=foreground,
746+
erosion_kernel=erosion_kernel,
747+
mask_multiplier=mask_multiplier,
748+
**kwargs,
749+
)
737750

738751
if resize_hw is not None:
739752
uin8_out = cv2.resize(

samgeo/samgeo.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@ def __init__(
2323
checkpoint="sam_vit_h_4b8939.pth",
2424
automatic=True,
2525
device=None,
26-
erosion_kernel=None,
27-
mask_multiplier=255,
2826
sam_kwargs=None,
2927
):
3028
"""Initialize the class.
@@ -39,10 +37,6 @@ def __init__(
3937
The automatic mask generator will segment the entire image, while the input prompts will segment selected objects.
4038
device (str, optional): The device to use. It can be one of the following: cpu, cuda.
4139
Defaults to None, which will use cuda if available.
42-
erosion_kernel (tuple, optional): The erosion kernel for filtering object masks and extract borders.
43-
Set to None to disable it. Defaults to (3, 3).
44-
mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
45-
You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255.
4640
sam_kwargs (dict, optional): Optional arguments for fine-tuning the SAM model. Defaults to None.
4741
The available arguments with default values are listed below. See https://bit.ly/410RV0v for more details.
4842
@@ -96,38 +90,48 @@ def __init__(
9690
# Segment selected objects using input prompts
9791
self.predictor = SamPredictor(self.sam, **sam_kwargs)
9892

99-
# Apply the erosion filter to the mask to extract borders
100-
self.erosion_kernel = erosion_kernel
101-
if self.erosion_kernel is not None:
102-
self.erosion_kernel = np.ones(erosion_kernel, np.uint8)
93+
# # Apply the erosion filter to the mask to extract borders
94+
# self.erosion_kernel = erosion_kernel
95+
# if self.erosion_kernel is not None:
96+
# self.erosion_kernel = np.ones(erosion_kernel, np.uint8)
10397

104-
# Rescale the binary mask to a larger range, for example, from [0, 1] to [0, 255].
105-
self.mask_multiplier = mask_multiplier
98+
# # Rescale the binary mask to a larger range, for example, from [0, 1] to [0, 255].
99+
# self.mask_multiplier = mask_multiplier
106100

107-
def __call__(self, image):
101+
def __call__(
102+
self,
103+
image,
104+
foreground=True,
105+
erosion_kernel=(3, 3),
106+
mask_multiplier=255,
107+
**kwargs,
108+
):
108109
# Segment each image tile
109110
h, w, _ = image.shape
110111

111112
masks = self.mask_generator.generate(image)
112113

113-
resulting_mask = np.ones((h, w), dtype=np.uint8)
114+
if foreground: # Extract foreground objects only
115+
resulting_mask = np.zeros((h, w), dtype=np.uint8)
116+
else:
117+
resulting_mask = np.ones((h, w), dtype=np.uint8)
114118
resulting_borders = np.zeros((h, w), dtype=np.uint8)
115119

116120
for m in masks:
117121
mask = (m["segmentation"] > 0).astype(np.uint8)
118122
resulting_mask += mask
119123

120124
# Apply erosion to the mask
121-
if self.erosion_kernel is not None:
122-
mask_erode = cv2.erode(mask, self.erosion_kernel, iterations=1)
125+
if erosion_kernel is not None:
126+
mask_erode = cv2.erode(mask, erosion_kernel, iterations=1)
123127
mask_erode = (mask_erode > 0).astype(np.uint8)
124128
edge_mask = mask - mask_erode
125129
resulting_borders += edge_mask
126130

127131
resulting_mask = (resulting_mask > 0).astype(np.uint8)
128132
resulting_borders = (resulting_borders > 0).astype(np.uint8)
129133
resulting_mask_with_borders = resulting_mask - resulting_borders
130-
return resulting_mask_with_borders * self.mask_multiplier
134+
return resulting_mask_with_borders * mask_multiplier
131135

132136
def generate(
133137
self,
@@ -165,7 +169,15 @@ def generate(
165169
raise ValueError(f"Input path {source} does not exist.")
166170

167171
if batch: # Subdivide the image into tiles and segment each tile
168-
return tiff_to_tiff(source, output, self, **kwargs)
172+
return tiff_to_tiff(
173+
source,
174+
output,
175+
self,
176+
foreground=foreground,
177+
erosion_kernel=erosion_kernel,
178+
mask_multiplier=mask_multiplier,
179+
**kwargs,
180+
)
169181

170182
image = cv2.imread(source)
171183
elif isinstance(source, np.ndarray):

0 commit comments

Comments
 (0)