3232
3333
3434class SamGeo :
35+ """The main class for segmenting geospatial data with the Segment Anything Model (SAM). See
36+ https://github.com/facebookresearch/segment-anything
37+ """
38+
3539 def __init__ (
3640 self ,
3741 checkpoint = "sam_vit_h_4b8939.pth" ,
@@ -41,7 +45,17 @@ def __init__(
4145 mask_multiplier = 255 ,
4246 sam_kwargs = None ,
4347 ):
44-
48+ """Initialize the class.
49+
50+ Args:
51+ checkpoint (str, optional): The path to the checkpoint. It can be one of the following:
52+ sam_vit_h_4b8939.pth, sam_vit_l_0b3195.pth, sam_vit_b_01ec64.pth. Defaults to "sam_vit_h_4b8939.pth".
53+ model_type (str, optional): The model type. It can be one of the following: vit_h, vit_l, vit_l. Defaults to 'vit_h'.
54+ device (str, optional): The device to use. It can be one of the following: cpu, cuda. Defaults to 'cpu'.
55+ erosion_kernel (tuple, optional): The erosion kernel. Defaults to (3, 3).
56+ mask_multiplier (int, optional): The mask multiplier. Defaults to 255.
57+ sam_kwargs (_type_, optional): The arguments for the SAM model. Defaults to None.
58+ """
4559 if not os .path .exists (checkpoint ):
4660 print (f'Checkpoint { checkpoint } does not exist.' )
4761 download_checkpoint (output = checkpoint )
@@ -89,6 +103,13 @@ def __call__(self, image):
89103 return resulting_mask_with_borders * self .mask_multiplier
90104
91105 def generate (self , in_path , out_path , ** kwargs ):
106+ """Segment the input image and save the result to the output path.
107+
108+ Args:
109+ in_path (str): The path to the input image.
110+ out_path (str): The path to the output image.
111+ """
112+
92113 return tiff_to_tiff (in_path , out_path , self , ** kwargs )
93114
94115 def image_to_image (self , image , ** kwargs ):
@@ -98,7 +119,43 @@ def download_tms_as_tiff(self, source, pt1, pt2, zoom, dist):
98119 image = draw_tile (source , pt1 [0 ], pt1 [1 ], pt2 [0 ], pt2 [1 ], zoom , dist )
99120 return image
100121
101- def tiff_to_gpkg (self , tiff_path , gpkg_path , simplify_tolerance = None ):
122+ def tiff_to_gpkg (self , tiff_path , gpkg_path , simplify_tolerance = None , ** kwargs ):
123+ """Convert a tiff file to a gpkg file.
124+
125+ Args:
126+ tiff_path (str): The path to the tiff file.
127+ gpkg_path (str): The path to the gpkg file.
128+ simplify_tolerance (_type_, optional): The simplify tolerance. Defaults to None.
129+ """
130+
131+ with rasterio .open (tiff_path ) as src :
132+ band = src .read ()
133+
134+ mask = band != 0
135+ shapes = features .shapes (band , mask = mask , transform = src .transform )
136+
137+ fc = [
138+ {"geometry" : shapely .geometry .shape (shape ), "properties" : {"value" : value }}
139+ for shape , value in shapes
140+ ]
141+ if simplify_tolerance is not None :
142+ for i in fc :
143+ i ["geometry" ] = i ["geometry" ].simplify (tolerance = simplify_tolerance )
144+
145+ gdf = gpd .GeoDataFrame .from_features (fc )
146+ gdf .set_crs (epsg = src .crs .to_epsg (), inplace = True )
147+ gdf .to_file (gpkg_path , driver = 'GPKG' , ** kwargs )
148+
149+
150+ def tiff_to_vector (self , tiff_path , output , simplify_tolerance = None , ** kwargs ):
151+ """Convert a tiff file to a gpkg file.
152+
153+ Args:
154+ tiff_path (str): The path to the tiff file.
155+ output (str): The path to the vector file.
156+ simplify_tolerance (_type_, optional): The simplify tolerance. Defaults to None.
157+ """
158+
102159 with rasterio .open (tiff_path ) as src :
103160 band = src .read ()
104161
@@ -115,4 +172,4 @@ def tiff_to_gpkg(self, tiff_path, gpkg_path, simplify_tolerance=None):
115172
116173 gdf = gpd .GeoDataFrame .from_features (fc )
117174 gdf .set_crs (epsg = src .crs .to_epsg (), inplace = True )
118- gdf .to_file (gpkg_path , driver = 'GPKG' )
175+ gdf .to_file (output , ** kwargs )
0 commit comments