Skip to content

Commit 9104c3b

Browse files
committed
Add docstrings
1 parent 2b87cea commit 9104c3b

File tree

3 files changed

+97
-20
lines changed

3 files changed

+97
-20
lines changed

docs/examples/satellite.ipynb

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
"[![image](https://img.shields.io/badge/Open-Planetary%20Computer-black?style=flat&logo=microsoft)](https://pccompute.westeurope.cloudapp.azure.com/compute/hub/user-redirect/git-pull?repo=https://github.com/opengeos/segment-geospatial&urlpath=lab/tree/segment-geospatial/docs/examples/satellite.ipynb&branch=main)\n",
1414
"[![image](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/opengeos/segment-geospatial/blob/main/docs/examples/satellite.ipynb)\n",
1515
"\n",
16-
"This notebook shows how to use segment satellite imagery using the Segment Anything Model (SAM) with a few lines of code. "
16+
"This notebook shows how to use segment satellite imagery using the Segment Anything Model (SAM) with a few lines of code. \n",
17+
"\n",
18+
"Make sure you use GPU runtime for this notebook. For Google Colab, go to `Runtime` -> `Change runtime type` and select `GPU` as the hardware accelerator. "
1719
]
1820
},
1921
{
@@ -118,18 +120,7 @@
118120
"metadata": {},
119121
"outputs": [],
120122
"source": [
121-
"out_dir = os.path.join(os.path.expanduser('~'), 'Downloads')\n",
122-
"if not os.path.exists(out_dir):\n",
123-
" os.makedirs(out_dir)"
124-
]
125-
},
126-
{
127-
"cell_type": "code",
128-
"execution_count": null,
129-
"metadata": {},
130-
"outputs": [],
131-
"source": [
132-
"image = os.path.join(out_dir, 'satellite.tif')\n",
123+
"image = 'satellite.tif'\n",
133124
"# image = '/path/to/your/own/image.tif'"
134125
]
135126
},
@@ -174,7 +165,16 @@
174165
"metadata": {},
175166
"outputs": [],
176167
"source": [
177-
"checkpoint = os.path.join(out_dir, 'sam_vit_h_4b8939.pth')\n",
168+
"out_dir = os.path.join(os.path.expanduser('~'), 'Downloads')\n",
169+
"checkpoint = os.path.join(out_dir, 'sam_vit_h_4b8939.pth')"
170+
]
171+
},
172+
{
173+
"cell_type": "code",
174+
"execution_count": null,
175+
"metadata": {},
176+
"outputs": [],
177+
"source": [
178178
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
179179
"sam = SamGeo(checkpoint=checkpoint,\n",
180180
" model_type='vit_h',\n",
@@ -207,7 +207,9 @@
207207
"cell_type": "markdown",
208208
"metadata": {},
209209
"source": [
210-
"## Polygonize the raster data"
210+
"## Polygonize the raster data\n",
211+
"\n",
212+
"Save the segmentation results as a GeoPackage file."
211213
]
212214
},
213215
{
@@ -216,10 +218,28 @@
216218
"metadata": {},
217219
"outputs": [],
218220
"source": [
219-
"vector = os.path.join(out_dir, 'segment.gpkg')\n",
221+
"vector = 'segment.gpkg'\n",
220222
"sam.tiff_to_gpkg(mask, vector, simplify_tolerance=None)"
221223
]
222224
},
225+
{
226+
"attachments": {},
227+
"cell_type": "markdown",
228+
"metadata": {},
229+
"source": [
230+
"You can also save the segmentation results as any vector data format supported by GeoPandas."
231+
]
232+
},
233+
{
234+
"cell_type": "code",
235+
"execution_count": null,
236+
"metadata": {},
237+
"outputs": [],
238+
"source": [
239+
"shapefile = 'segment.shp'\n",
240+
"sam.tiff_to_vector(mask, vector)"
241+
]
242+
},
223243
{
224244
"attachments": {},
225245
"cell_type": "markdown",

mkdocs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ plugins:
2424
- mkdocs-jupyter:
2525
include_source: True
2626
execute: False
27-
# ignore_h1_titles: True
27+
ignore_h1_titles: True
2828
# execute_ignore: "*.ipynb"
2929

3030
markdown_extensions:

samgeo/samgeo.py

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@
3232

3333

3434
class SamGeo:
35+
"""The main class for segmenting geospatial data with the Segment Anything Model (SAM). See
36+
https://github.com/facebookresearch/segment-anything
37+
"""
38+
3539
def __init__(
3640
self,
3741
checkpoint="sam_vit_h_4b8939.pth",
@@ -41,7 +45,17 @@ def __init__(
4145
mask_multiplier=255,
4246
sam_kwargs=None,
4347
):
44-
48+
"""Initialize the class.
49+
50+
Args:
51+
checkpoint (str, optional): The path to the checkpoint. It can be one of the following:
52+
sam_vit_h_4b8939.pth, sam_vit_l_0b3195.pth, sam_vit_b_01ec64.pth. Defaults to "sam_vit_h_4b8939.pth".
53+
model_type (str, optional): The model type. It can be one of the following: vit_h, vit_l, vit_l. Defaults to 'vit_h'.
54+
device (str, optional): The device to use. It can be one of the following: cpu, cuda. Defaults to 'cpu'.
55+
erosion_kernel (tuple, optional): The erosion kernel. Defaults to (3, 3).
56+
mask_multiplier (int, optional): The mask multiplier. Defaults to 255.
57+
sam_kwargs (_type_, optional): The arguments for the SAM model. Defaults to None.
58+
"""
4559
if not os.path.exists(checkpoint):
4660
print(f'Checkpoint {checkpoint} does not exist.')
4761
download_checkpoint(output=checkpoint)
@@ -89,6 +103,13 @@ def __call__(self, image):
89103
return resulting_mask_with_borders * self.mask_multiplier
90104

91105
def generate(self, in_path, out_path, **kwargs):
106+
"""Segment the input image and save the result to the output path.
107+
108+
Args:
109+
in_path (str): The path to the input image.
110+
out_path (str): The path to the output image.
111+
"""
112+
92113
return tiff_to_tiff(in_path, out_path, self, **kwargs)
93114

94115
def image_to_image(self, image, **kwargs):
@@ -98,7 +119,43 @@ def download_tms_as_tiff(self, source, pt1, pt2, zoom, dist):
98119
image = draw_tile(source, pt1[0], pt1[1], pt2[0], pt2[1], zoom, dist)
99120
return image
100121

101-
def tiff_to_gpkg(self, tiff_path, gpkg_path, simplify_tolerance=None):
122+
def tiff_to_gpkg(self, tiff_path, gpkg_path, simplify_tolerance=None, **kwargs):
123+
"""Convert a tiff file to a gpkg file.
124+
125+
Args:
126+
tiff_path (str): The path to the tiff file.
127+
gpkg_path (str): The path to the gpkg file.
128+
simplify_tolerance (_type_, optional): The simplify tolerance. Defaults to None.
129+
"""
130+
131+
with rasterio.open(tiff_path) as src:
132+
band = src.read()
133+
134+
mask = band != 0
135+
shapes = features.shapes(band, mask=mask, transform=src.transform)
136+
137+
fc = [
138+
{"geometry": shapely.geometry.shape(shape), "properties": {"value": value}}
139+
for shape, value in shapes
140+
]
141+
if simplify_tolerance is not None:
142+
for i in fc:
143+
i["geometry"] = i["geometry"].simplify(tolerance=simplify_tolerance)
144+
145+
gdf = gpd.GeoDataFrame.from_features(fc)
146+
gdf.set_crs(epsg=src.crs.to_epsg(), inplace=True)
147+
gdf.to_file(gpkg_path, driver='GPKG', **kwargs)
148+
149+
150+
def tiff_to_vector(self, tiff_path, output, simplify_tolerance=None, **kwargs):
151+
"""Convert a tiff file to a gpkg file.
152+
153+
Args:
154+
tiff_path (str): The path to the tiff file.
155+
output (str): The path to the vector file.
156+
simplify_tolerance (_type_, optional): The simplify tolerance. Defaults to None.
157+
"""
158+
102159
with rasterio.open(tiff_path) as src:
103160
band = src.read()
104161

@@ -115,4 +172,4 @@ def tiff_to_gpkg(self, tiff_path, gpkg_path, simplify_tolerance=None):
115172

116173
gdf = gpd.GeoDataFrame.from_features(fc)
117174
gdf.set_crs(epsg=src.crs.to_epsg(), inplace=True)
118-
gdf.to_file(gpkg_path, driver='GPKG')
175+
gdf.to_file(output, **kwargs)

0 commit comments

Comments
 (0)