Skip to content

Commit 4d69480

Browse files
authored
Add crs and transform for array input (#400)
* Add crs and transform for array input * Update show_anns
1 parent 2dc971c commit 4d69480

File tree

3 files changed

+20
-16
lines changed

3 files changed

+20
-16
lines changed

samgeo/common.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1458,12 +1458,16 @@ def array_to_image(
14581458
array = cv2.imread(array)
14591459
array = cv2.cvtColor(array, cv2.COLOR_BGR2RGB)
14601460

1461-
if output.endswith(".tif") and source is not None:
1462-
with rasterio.open(source) as src:
1463-
crs = src.crs
1464-
transform = src.transform
1465-
if compress is None:
1466-
compress = src.compression
1461+
if output.endswith(".tif"):
1462+
if source is not None:
1463+
with rasterio.open(source) as src:
1464+
crs = src.crs
1465+
transform = src.transform
1466+
if compress is None:
1467+
compress = src.compression
1468+
else:
1469+
crs = kwargs.get("crs", None)
1470+
transform = kwargs.get("transform", None)
14671471

14681472
# Determine the minimum and maximum values in the array
14691473

samgeo/samgeo.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -425,11 +425,11 @@ def show_anns(
425425
img[m] = color_mask
426426
ax.imshow(img)
427427

428-
if "dpi" not in kwargs:
429-
kwargs["dpi"] = 100
428+
# if "dpi" not in kwargs:
429+
# kwargs["dpi"] = 100
430430

431-
if "bbox_inches" not in kwargs:
432-
kwargs["bbox_inches"] = "tight"
431+
# if "bbox_inches" not in kwargs:
432+
# kwargs["bbox_inches"] = "tight"
433433

434434
plt.axis(axis)
435435

@@ -442,7 +442,7 @@ def show_anns(
442442
)
443443
else:
444444
array = self.annotations
445-
array_to_image(array, output, self.source)
445+
array_to_image(array, output, self.source, **kwargs)
446446

447447
def set_image(self, image, image_format="RGB"):
448448
"""Set the input image as a numpy array.

samgeo/samgeo2.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -456,11 +456,11 @@ def show_anns(
456456
img[m] = color_mask
457457
ax.imshow(img)
458458

459-
if "dpi" not in kwargs:
460-
kwargs["dpi"] = 100
459+
# if "dpi" not in kwargs:
460+
# kwargs["dpi"] = 100
461461

462-
if "bbox_inches" not in kwargs:
463-
kwargs["bbox_inches"] = "tight"
462+
# if "bbox_inches" not in kwargs:
463+
# kwargs["bbox_inches"] = "tight"
464464

465465
plt.axis(axis)
466466

@@ -473,7 +473,7 @@ def show_anns(
473473
)
474474
else:
475475
array = self.annotations
476-
common.array_to_image(array, output, self.source)
476+
common.array_to_image(array, output, self.source, **kwargs)
477477

478478
@torch.no_grad()
479479
def set_image(

0 commit comments

Comments
 (0)