Skip to content

Commit bca0216

Browse files
authored
Added support for custom projection bbox (#39)
* Add bbox_to_xy function * Add support for bbox custom projection
1 parent 6013784 commit bca0216

File tree

2 files changed

+108
-5
lines changed

2 files changed

+108
-5
lines changed

samgeo/common.py

Lines changed: 105 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -525,7 +525,9 @@ def draw_tile(
525525
for k, (fut, corner_xy) in enumerate(zip(futures, corners), 1):
526526
bigim = paste_tile(bigim, base_size, fut.result(), corner_xy, bbox)
527527
if not quiet:
528-
print("Downloaded image %d/%d" % (k, totalnum))
528+
print(
529+
f"Downloaded image {str(k).zfill(len(str(totalnum)))}/{totalnum}"
530+
)
529531

530532
if not quiet:
531533
print("Saving GeoTIFF. Please wait...")
@@ -592,6 +594,7 @@ def get_crs(src_fp):
592594

593595
def get_features(src_fp, bidx=1):
594596
from rasterio import features
597+
595598
with rasterio.open(src_fp) as src:
596599
features = features.dataset_features(
597600
src,
@@ -736,6 +739,93 @@ def coords_to_xy(
736739
return result
737740

738741

742+
def bbox_to_xy(
743+
src_fp: str, coords: list, coord_crs: str = "epsg:4326", **kwargs
744+
) -> list:
745+
"""Converts a list of coordinates to pixel coordinates, i.e., (col, row) coordinates.
746+
747+
Args:
748+
src_fp (str): The source raster file path.
749+
coords (list): A list of coordinates in the format of [[minx, miny, maxx, maxy], [minx, miny, maxx, maxy], ...]
750+
coord_crs (str, optional): The coordinate CRS of the input coordinates. Defaults to "epsg:4326".
751+
752+
Returns:
753+
list: A list of pixel coordinates in the format of [[minx, miny, maxx, maxy], ...]
754+
"""
755+
756+
if isinstance(coords, str):
757+
gdf = gpd.read_file(coords)
758+
coords = gdf.geometry.bounds.values.tolist()
759+
if gdf.crs is not None:
760+
coord_crs = f"epsg:{gdf.crs.to_epsg()}"
761+
elif isinstance(coords, np.ndarray):
762+
coords = coords.tolist()
763+
if isinstance(coords, dict):
764+
import json
765+
766+
geojson = json.dumps(coords)
767+
gdf = gpd.read_file(geojson, driver="GeoJSON")
768+
coords = gdf.geometry.bounds.values.tolist()
769+
770+
elif not isinstance(coords, list):
771+
raise ValueError("coords must be a list of coordinates.")
772+
773+
if not isinstance(coords[0], list):
774+
coords = [coords]
775+
776+
new_coords = []
777+
778+
with rasterio.open(src_fp) as src:
779+
width = src.width
780+
height = src.height
781+
782+
for coord in coords:
783+
minx, miny, maxx, maxy = coord
784+
785+
if coord_crs != src.crs:
786+
minx, miny = transform_coords(minx, miny, coord_crs, src.crs, **kwargs)
787+
maxx, maxy = transform_coords(maxx, maxy, coord_crs, src.crs, **kwargs)
788+
789+
rows1, cols1 = rasterio.transform.rowcol(
790+
src.transform, minx, miny, **kwargs
791+
)
792+
rows2, cols2 = rasterio.transform.rowcol(
793+
src.transform, maxx, maxy, **kwargs
794+
)
795+
796+
new_coords.append([cols1, rows1, cols2, rows2])
797+
798+
else:
799+
new_coords.append([minx, miny, maxx, maxy])
800+
801+
result = []
802+
803+
for coord in new_coords:
804+
minx, miny, maxx, maxy = coord
805+
806+
if (
807+
minx >= 0
808+
and miny >= 0
809+
and maxx >= 0
810+
and maxy >= 0
811+
and minx < width
812+
and miny < height
813+
and maxx < width
814+
and maxy < height
815+
):
816+
result.append(coord)
817+
818+
if len(result) == 0:
819+
print("No valid pixel coordinates found.")
820+
return None
821+
elif len(result) == 1:
822+
return result[0]
823+
elif len(result) < len(coords):
824+
print("Some coordinates are out of the image boundary.")
825+
826+
return result
827+
828+
739829
def geojson_to_xy(
740830
src_fp: str, geojson: str, coord_crs: str = "epsg:4326", **kwargs
741831
) -> list:
@@ -1586,10 +1676,21 @@ def segment_button_click(change):
15861676
try:
15871677
if m.user_rois is not None:
15881678
filename = f"masks_{random_string()}.tif"
1589-
sam.predict(point_coords=m.user_rois, point_crs='EPSG:4326', output=filename)
1679+
sam.predict(
1680+
point_coords=m.user_rois,
1681+
point_crs="EPSG:4326",
1682+
output=filename,
1683+
)
15901684
if m.find_layer("Masks") is not None:
15911685
m.remove_layer(m.find_layer("Masks"))
1592-
m.add_raster(filename, nodata=0, cmap='Blues', opacity=opacity_slider.value, layer_name="Masks", zoom_to_layer=False)
1686+
m.add_raster(
1687+
filename,
1688+
nodata=0,
1689+
cmap="Blues",
1690+
opacity=opacity_slider.value,
1691+
layer_name="Masks",
1692+
zoom_to_layer=False,
1693+
)
15931694
output.clear_output()
15941695
segment_button.value = False
15951696
sam.prediction_fp = filename
@@ -1633,4 +1734,4 @@ def random_string(string_length=6):
16331734

16341735
# random.seed(1001)
16351736
letters = string.ascii_lowercase
1636-
return "".join(random.choice(letters) for i in range(string_length))
1737+
return "".join(random.choice(letters) for i in range(string_length))

samgeo/samgeo.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,9 @@ def predict(
503503
)
504504
point_labels = np.array(point_labels)
505505

506+
if isinstance(box, list) and point_crs is not None:
507+
box = np.array(bbox_to_xy(self.image, box, point_crs))
508+
506509
predictor = self.predictor
507510
masks, scores, logits = predictor.predict(
508511
point_coords, point_labels, box, mask_input, multimask_output, return_logits
@@ -518,7 +521,6 @@ def predict(
518521
return masks, scores, logits
519522

520523
def show_map(self, **kwargs):
521-
522524
return sam_map_gui(self, **kwargs)
523525

524526
def image_to_image(self, image, **kwargs):

0 commit comments

Comments
 (0)