Skip to content

Commit e3ab2c7

Browse files
authored
Made show_masks method compatible with batch mode (#54)
1 parent de82bb3 commit e3ab2c7

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

samgeo/samgeo.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,9 @@ def generate(
176176
raise ValueError(f"Input path {source} does not exist.")
177177

178178
if batch: # Subdivide the image into tiles and segment each tile
179+
self.batch = True
180+
self.source = source
181+
self.masks = output
179182
return tiff_to_tiff(
180183
source,
181184
output,
@@ -197,6 +200,7 @@ def generate(
197200
mask_generator = self.mask_generator # The automatic mask generator
198201
masks = mask_generator.generate(image) # Segment the input image
199202
self.masks = masks # Store the masks as a list of dictionaries
203+
self.batch = False
200204

201205
if output is not None:
202206
# Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values.
@@ -302,8 +306,12 @@ def show_masks(
302306

303307
import matplotlib.pyplot as plt
304308

305-
if self.objects is None:
306-
self.save_masks(foreground=foreground, **kwargs)
309+
if self.batch:
310+
self.objects = cv2.imread(self.masks)
311+
else:
312+
313+
if self.objects is None:
314+
self.save_masks(foreground=foreground, **kwargs)
307315

308316
plt.figure(figsize=figsize)
309317
plt.imshow(self.objects, cmap=cmap)

0 commit comments

Comments
 (0)