Skip to content

Commit 6286657

Browse files
authored
Added support for segmenting non-georeferenced imagery (#66)
1 parent 2a6b638 commit 6286657

File tree

2 files changed

+132
-8
lines changed

2 files changed

+132
-8
lines changed

samgeo/common.py

Lines changed: 81 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1400,12 +1400,24 @@ def show_box(image, box, ax):
14001400
def overlay_images(
14011401
image1,
14021402
image2,
1403-
opacity=0.5,
1403+
alpha=0.5,
14041404
backend="TkAgg",
14051405
height_ratios=[10, 1],
14061406
show_args1={},
14071407
show_args2={},
14081408
):
1409+
"""Overlays two images using a slider to control the opacity of the top image.
1410+
1411+
Args:
1412+
image1 (str | np.ndarray): The first input image at the bottom represented as a NumPy array or the path to the image.
1413+
image2 (_type_): The second input image on top represented as a NumPy array or the path to the image.
1414+
alpha (float, optional): The alpha value of the top image. Defaults to 0.5.
1415+
backend (str, optional): The backend of the matplotlib plot. Defaults to "TkAgg".
1416+
height_ratios (list, optional): The height ratios of the two subplots. Defaults to [10, 1].
1417+
show_args1 (dict, optional): The keyword arguments to pass to the imshow() function for the first image. Defaults to {}.
1418+
show_args2 (dict, optional): The keyword arguments to pass to the imshow() function for the second image. Defaults to {}.
1419+
1420+
"""
14091421
import sys
14101422
import matplotlib
14111423
import matplotlib.widgets as mpwidgets
@@ -1440,17 +1452,15 @@ def overlay_images(
14401452
# Create the plot
14411453
fig, (ax0, ax1) = plt.subplots(2, 1, gridspec_kw={"height_ratios": height_ratios})
14421454
img0 = ax0.imshow(x, **show_args1)
1443-
img1 = ax0.imshow(y, alpha=opacity, **show_args2)
1455+
img1 = ax0.imshow(y, alpha=alpha, **show_args2)
14441456

14451457
# Define the update function
14461458
def update(value):
14471459
img1.set_alpha(value)
14481460
fig.canvas.draw_idle()
14491461

14501462
# Create the slider
1451-
slider0 = mpwidgets.Slider(
1452-
ax=ax1, label="opacity", valmin=0, valmax=1, valinit=opacity
1453-
)
1463+
slider0 = mpwidgets.Slider(ax=ax1, label="alpha", valmin=0, valmax=1, valinit=alpha)
14541464
slider0.on_changed(update)
14551465

14561466
# Display the plot
@@ -2109,3 +2119,69 @@ def coords_to_geojson(coords, output=None):
21092119
f.write(geojson_str)
21102120
else:
21112121
return geojson_str
2122+
2123+
2124+
def show_canvas(image, fg_color=(0, 255, 0), bg_color=(0, 0, 255), radius=5):
2125+
"""Show a canvas to collect foreground and background points.
2126+
2127+
Args:
2128+
image (str | np.ndarray): The input image.
2129+
fg_color (tuple, optional): The color for the foreground points. Defaults to (0, 255, 0).
2130+
bg_color (tuple, optional): The color for the background points. Defaults to (0, 0, 255).
2131+
radius (int, optional): The radius of the points. Defaults to 5.
2132+
2133+
Returns:
2134+
tuple: A tuple of two lists of foreground and background points.
2135+
"""
2136+
if isinstance(image, str):
2137+
if image.startswith("http"):
2138+
image = download_file(image)
2139+
2140+
image = cv2.imread(image)
2141+
elif isinstance(image, np.ndarray):
2142+
pass
2143+
else:
2144+
raise ValueError("Input image must be a URL or a NumPy array.")
2145+
2146+
# Create an empty list to store the mouse click coordinates
2147+
left_clicks = []
2148+
right_clicks = []
2149+
2150+
# Create a mouse callback function
2151+
def get_mouse_coordinates(event, x, y, flags, param):
2152+
if event == cv2.EVENT_LBUTTONDOWN:
2153+
# Append the coordinates to the mouse_clicks list
2154+
left_clicks.append((x, y))
2155+
2156+
# Draw a green circle at the mouse click coordinates
2157+
cv2.circle(image, (x, y), radius, fg_color, -1)
2158+
2159+
# Show the updated image with the circle
2160+
cv2.imshow("Image", image)
2161+
2162+
elif event == cv2.EVENT_RBUTTONDOWN:
2163+
# Append the coordinates to the mouse_clicks list
2164+
right_clicks.append((x, y))
2165+
2166+
# Draw a red circle at the mouse click coordinates
2167+
cv2.circle(image, (x, y), radius, bg_color, -1)
2168+
2169+
# Show the updated image with the circle
2170+
cv2.imshow("Image", image)
2171+
2172+
# Create a window to display the image
2173+
cv2.namedWindow("Image")
2174+
2175+
# Set the mouse callback function for the window
2176+
cv2.setMouseCallback("Image", get_mouse_coordinates)
2177+
2178+
# Display the image in the window
2179+
cv2.imshow("Image", image)
2180+
2181+
# Wait for a key press to exit
2182+
cv2.waitKey(0)
2183+
2184+
# Destroy the window
2185+
cv2.destroyAllWindows()
2186+
2187+
return left_clicks, right_clicks

samgeo/samgeo.py

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,6 @@ def show_masks(
309309
if self.batch:
310310
self.objects = cv2.imread(self.masks)
311311
else:
312-
313312
if self.objects is None:
314313
self.save_masks(foreground=foreground, **kwargs)
315314

@@ -319,7 +318,13 @@ def show_masks(
319318
plt.show()
320319

321320
def show_anns(
322-
self, figsize=(12, 10), axis="off", alpha=0.35, output=None, blend=True, **kwargs
321+
self,
322+
figsize=(12, 10),
323+
axis="off",
324+
alpha=0.35,
325+
output=None,
326+
blend=True,
327+
**kwargs,
323328
):
324329
"""Show the annotations (objects with random color) on the input image.
325330
@@ -376,7 +381,9 @@ def show_anns(
376381

377382
if output is not None:
378383
if blend:
379-
array = blend_images(self.annotations, self.image, alpha=alpha, show=False)
384+
array = blend_images(
385+
self.annotations, self.image, alpha=alpha, show=False
386+
)
380387
else:
381388
array = self.annotations
382389
array_to_image(array, output, self.source)
@@ -494,6 +501,12 @@ def predict(
494501
if isinstance(point_coords, dict):
495502
point_coords = geojson_to_coords(point_coords)
496503

504+
if hasattr(self, "point_coords"):
505+
point_coords = self.point_coords
506+
507+
if hasattr(self, "point_labels"):
508+
point_labels = self.point_labels
509+
497510
if point_crs is not None:
498511
point_coords = coords_to_xy(self.image, point_coords, point_crs)
499512

@@ -533,10 +546,45 @@ def predict(
533546
return masks, scores, logits
534547

535548
def show_map(self, basemap="SATELLITE", repeat_mode=True, out_dir=None, **kwargs):
549+
"""Show the interactive map.
550+
551+
Args:
552+
basemap (str, optional): The basemap. It can be one of the following: SATELLITE, ROADMAP, TERRAIN, HYBRID.
553+
repeat_mode (bool, optional): Whether to use the repeat mode for draw control. Defaults to True.
554+
out_dir (str, optional): The path to the output directory. Defaults to None.
555+
556+
Returns:
557+
leafmap.Map: The map object.
558+
"""
536559
return sam_map_gui(
537560
self, basemap=basemap, repeat_mode=repeat_mode, out_dir=out_dir, **kwargs
538561
)
539562

563+
def show_canvas(self, fg_color=(0, 255, 0), bg_color=(0, 0, 255), radius=5):
564+
"""Show a canvas to collect foreground and background points.
565+
566+
Args:
567+
image (str | np.ndarray): The input image.
568+
fg_color (tuple, optional): The color for the foreground points. Defaults to (0, 255, 0).
569+
bg_color (tuple, optional): The color for the background points. Defaults to (0, 0, 255).
570+
radius (int, optional): The radius of the points. Defaults to 5.
571+
572+
Returns:
573+
tuple: A tuple of two lists of foreground and background points.
574+
"""
575+
576+
if self.image is None:
577+
raise ValueError("Please run set_image() first.")
578+
579+
image = self.image
580+
fg_points, bg_points = show_canvas(image, fg_color, bg_color, radius)
581+
self.fg_points = fg_points
582+
self.bg_points = bg_points
583+
point_coords = fg_points + bg_points
584+
point_labels = [1] * len(fg_points) + [0] * len(bg_points)
585+
self.point_coords = point_coords
586+
self.point_labels = point_labels
587+
540588
def image_to_image(self, image, **kwargs):
541589
return image_to_image(image, self, **kwargs)
542590

0 commit comments

Comments
 (0)