
Commit f7c45f9

Added support for input prompts (#30)
* Add support for input prompts
* Add show_points and show_box functions
* Update notebooks
* Add functions for coordinate transformation
* Improve predict function
* Add save_prediction method
* Add SAM map GUI
* Add interactive segmentation GUI
* Add input prompt notebook
* Add demo to notebook
1 parent ade624d commit f7c45f9
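
The example notebook added in this commit (docs/examples/input_prompts.ipynb, shown in full below) walks through the new prompt-based workflow. As a quick orientation, here is a minimal sketch condensed from that notebook; the checkpoint path, bounding box, and point coordinates are the notebook's illustrative values, not required settings.

```python
import os
from samgeo import SamGeo, tms_to_geotiff

# Download a sample GeoTIFF for the area of interest (example bbox from the notebook).
bbox = [-122.1497, 37.6311, -122.1203, 37.6458]
image = "satellite.tif"
tms_to_geotiff(output=image, bbox=bbox, zoom=16, source="Satellite", overwrite=True)

# automatic=False switches SamGeo from SamAutomaticMaskGenerator to SamPredictor,
# which accepts input prompts.
checkpoint = os.path.join(os.path.expanduser("~"), "Downloads", "sam_vit_h_4b8939.pth")
sam = SamGeo(model_type="vit_h", checkpoint=checkpoint, automatic=False, sam_kwargs=None)
sam.set_image(image)

# Point prompts may be given in lon/lat; point_crs triggers the new coordinate
# transformation so they are converted to image (col, row) before prediction.
sam.predict(
    [[-122.1419, 37.6383]],
    point_labels=1,
    point_crs="EPSG:4326",
    output="mask1.tif",
)

# Or launch the new interactive GUI: click points on the map, then press Segment.
m = sam.show_map()
m
```

The notebook below also shows how to load the saved mask back onto a leafmap map for inspection.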

9 files changed: +943 additions, -76 deletions

README.md

Lines changed: 11 additions & 0 deletions
@@ -24,13 +24,24 @@ The **segment-geospatial** package draws its inspiration from [segment-anything-
 
 - [Segmenting satellite imagery](https://samgeo.gishub.org/examples/satellite)
 - [Automatically generating object masks](https://samgeo.gishub.org/examples/automatic_mask_generator)
+- [Segmenting satellite imagery with input prompts](https://samgeo.gishub.org/examples/input_prompts)
 
 ## Demos
 
 - Automatic mask generator
 
 ![](https://i.imgur.com/I1IhDgz.gif)
 
+- Interactive segmentation with input prompts
+
+![](https://i.imgur.com/GV7Rzxt.gif)
+
+## Tutorials
+
+Video tutorials are available on my [YouTube Channel](https://youtube.com/@giswqs).
+
+[![Alt text](https://img.youtube.com/vi/YHA_-QMB8_U/0.jpg)](https://www.youtube.com/playlist?list=PLAxJ4-o7ZoPcrg5RnZjkB_KY6tv96WO2h)
+
 ## Acknowledgements
 
 This package was made possible by the following open source projects. Credit goes to the developers of these projects.

docs/examples/input_prompts.ipynb

Lines changed: 315 additions & 0 deletions
@@ -0,0 +1,315 @@
+{
+"cells": [
+{
+"attachments": {},
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"# Generating object masks from input prompts with SAM\n",
+"\n",
+"[![image](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/opengeos/segment-geospatial/blob/main/docs/examples/input_prompts.ipynb)\n",
+"[![image](https://img.shields.io/badge/Open-Planetary%20Computer-black?style=flat&logo=microsoft)](https://pccompute.westeurope.cloudapp.azure.com/compute/hub/user-redirect/git-pull?repo=https://github.com/opengeos/segment-geospatial&urlpath=lab/tree/segment-geospatial/docs/examples/input_prompts.ipynb&branch=main)\n",
+"[![image](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/opengeos/segment-geospatial/blob/main/docs/examples/input_prompts.ipynb)\n",
+"\n",
+"This notebook shows how to generate object masks from input prompts with the Segment Anything Model (SAM). \n",
+"\n",
+"Make sure you use GPU runtime for this notebook. For Google Colab, go to `Runtime` -> `Change runtime type` and select `GPU` as the hardware accelerator. \n",
+"\n",
+"The notebook is adapted from [segment-anything/notebooks/predictor_example.ipynb](https://github.com/opengeos/segment-anything/blob/pypi/notebooks/predictor_example.ipynb), but I have made it much easier to save the segmentation results and visualize them."
+]
+},
+{
+"attachments": {},
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"## Install dependencies\n",
+"\n",
+"Uncomment and run the following cell to install the required dependencies."
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"# %pip install segment-geospatial leafmap localtileserver"
+]
+},
+{
+"attachments": {},
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"If you encounter any issues, please try to uncomment the following cell and run it to update the package from GitHub. New features are usually added to GitHub first before they are released to PyPI."
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"# %pip install git+https://github.com/opengeos/segment-geospatial.git"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"import os\n",
+"import leafmap\n",
+"from samgeo import SamGeo, tms_to_geotiff"
+]
+},
+{
+"attachments": {},
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"## Create an interactive map"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"m = leafmap.Map(center=[37.6412, -122.1353], zoom=15, height=\"800px\")\n",
+"m.add_basemap(\"SATELLITE\")\n",
+"m"
+]
+},
+{
+"attachments": {},
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"## Download a sample image\n",
+"\n",
+"Pan and zoom the map to select the area of interest. Use the draw tools to draw a polygon or rectangle on the map."
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"if m.user_roi is not None:\n",
+" bbox = m.user_roi_bounds()\n",
+"else:\n",
+" bbox = [-122.1497, 37.6311, -122.1203, 37.6458]"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"image = \"satellite.tif\"\n",
+"tms_to_geotiff(output=image, bbox=bbox, zoom=16, source=\"Satellite\", overwrite=True)"
+]
+},
+{
+"attachments": {},
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"You can also use your own image. Uncomment and run the following cell to use your own image."
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"# image = '/path/to/your/own/image.tif'"
+]
+},
+{
+"attachments": {},
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"Display the downloaded image on the map."
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"m.layers[-1].visible = False\n",
+"m.add_raster(image, layer_name=\"Image\")\n",
+"m"
+]
+},
+{
+"attachments": {},
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"## Initialize SAM class\n",
+"\n",
+"Specify the file path to the model checkpoint. If it is not specified, the model will be downloaded to the working directory."
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"out_dir = os.path.join(os.path.expanduser(\"~\"), \"Downloads\")\n",
+"checkpoint = os.path.join(out_dir, \"sam_vit_h_4b8939.pth\")"
+]
+},
+{
+"attachments": {},
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"Set `automatic=False` to disable the `SamAutomaticMaskGenerator` and enable the `SamPredictor`."
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"sam = SamGeo(\n",
+" model_type=\"vit_h\",\n",
+" checkpoint=checkpoint,\n",
+" automatic=False,\n",
+" sam_kwargs=None,\n",
+")"
+]
+},
+{
+"attachments": {},
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"Specify the image to segment."
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"sam.set_image(image)"
+]
+},
+{
+"attachments": {},
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"## Image segmentation with input points\n",
+"\n",
+"A single point can be used to segment an object. The point can be specified as a tuple of (x, y), such as (col, row) or (lon, lat). The points can also be specified as a file path to a vector dataset. For input points that are not in (col, row) coordinates, specify the `point_crs` parameter, which will automatically transform the points to the image column and row coordinates.\n",
+"\n",
+"Try a single point input:"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"point_coords = [[-122.1419, 37.6383]]\n",
+"sam.predict(point_coords, point_labels=1, point_crs=\"EPSG:4326\", output='mask1.tif')\n",
+"m.add_raster('mask1.tif', layer_name=\"Mask1\", nodata=0, cmap=\"Blues\", opacity=1)\n",
+"m"
+]
+},
+{
+"attachments": {},
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"Try multiple point inputs:"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"point_coords = [[-122.1464, 37.6431], [-122.1449, 37.6415], [-122.1451, 37.6395]]\n",
+"sam.predict(point_coords, point_labels=1, point_crs=\"EPSG:4326\", output='mask2.tif')\n",
+"m.add_raster('mask2.tif', layer_name=\"Mask2\", nodata=0, cmap=\"Greens\", opacity=1)\n",
+"m"
+]
+},
+{
+"attachments": {},
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"## Interactive segmentation\n",
+"\n",
+"Display the interactive map and use the marker tool to draw points on the map. Then click on the `Segment` button to segment the objects. The results will be added to the map automatically. Click on the `Reset` button to clear the points and the results."
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"m = sam.show_map()\n",
+"m"
+]
+},
+{
+"attachments": {},
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"![](https://i.imgur.com/GV7Rzxt.gif)"
+]
+},
+{
+"attachments": {},
+"cell_type": "markdown",
+"metadata": {},
+"source": []
+}
+],
+"metadata": {
+"kernelspec": {
+"display_name": "sam",
+"language": "python",
+"name": "python3"
+},
+"language_info": {
+"codemirror_mode": {
+"name": "ipython",
+"version": 3
+},
+"file_extension": ".py",
+"mimetype": "text/x-python",
+"name": "python",
+"nbconvert_exporter": "python",
+"pygments_lexer": "ipython3",
+"version": "3.9.16"
+},
+"orig_nbformat": 4
+},
+"nbformat": 4,
+"nbformat_minor": 2
+}
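
The `point_crs` option used in the notebook relies on the coordinate-transformation functions added in this commit. Their exact names and signatures are not visible on this page, so the following is only a generic sketch, using rasterio rather than samgeo internals, of how lon/lat prompts end up as image (col, row) indices.

```python
import rasterio
from rasterio.crs import CRS
from rasterio.warp import transform as warp_transform

def lonlat_to_colrow(raster_path, coords, crs="EPSG:4326"):
    """Convert [[lon, lat], ...] prompts to [[col, row], ...] pixel indices.

    Hypothetical helper for illustration only; samgeo's own implementation may differ.
    """
    with rasterio.open(raster_path) as src:
        # Reproject the prompt coordinates from their CRS into the raster's CRS.
        xs, ys = warp_transform(
            CRS.from_user_input(crs), src.crs,
            [c[0] for c in coords], [c[1] for c in coords],
        )
        # DatasetReader.index maps coordinates in the raster's CRS to (row, col).
        return [[col, row] for row, col in (src.index(x, y) for x, y in zip(xs, ys))]

# Example with the single-point prompt from the notebook:
# lonlat_to_colrow("satellite.tif", [[-122.1419, 37.6383]])
```

SAM's predictor expects pixel coordinates in (x, y) order, i.e. (col, row), which is why the column comes first in the returned pairs.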

docs/examples/satellite-predictor.ipynb

Lines changed: 4 additions & 10 deletions
@@ -55,8 +55,7 @@
 "source": [
 "import os\n",
 "import leafmap\n",
-"import torch\n",
-"from samgeo import SamGeo, SamGeoPredictor, tms_to_geotiff, get_basemaps\n",
+"from samgeo import SamGeoPredictor, tms_to_geotiff, get_basemaps\n",
 "from segment_anything import sam_model_registry"
 ]
 },
@@ -152,7 +151,9 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"tms_to_geotiff(output=image, bbox=bbox, zoom=zoom + 1, source=\"Satellite\")"
+"tms_to_geotiff(\n",
+" output=image, bbox=bbox, zoom=zoom + 1, source=\"Satellite\", overwrite=True\n",
+")"
 ]
 },
 {
@@ -286,13 +287,6 @@
 "m.add_vector(vector, layer_name=\"Vector\", style=style)\n",
 "m"
 ]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {},
-"outputs": [],
-"source": []
 }
 ],
 "metadata": {
