Skip to content

Commit 54d8174

Browse files
committed
Added download file functions
1 parent 8b956bd commit 54d8174

File tree

3 files changed

+139
-2
lines changed

3 files changed

+139
-2
lines changed

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@ onnxruntime; python_version < '3.11'
66
onnx
77
geopandas
88
rasterio
9-
tqdm
9+
tqdm
10+
gdown

samgeo/common.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,136 @@ def temp_file_path(extension):
6666
return file_path
6767

6868

69+
def github_raw_url(url):
70+
"""Get the raw URL for a GitHub file.
71+
72+
Args:
73+
url (str): The GitHub URL.
74+
Returns:
75+
str: The raw URL.
76+
"""
77+
if isinstance(url, str) and url.startswith("https://github.com/") and "blob" in url:
78+
url = url.replace("github.com", "raw.githubusercontent.com").replace(
79+
"blob/", ""
80+
)
81+
return url
82+
83+
84+
def download_file(
85+
url=None,
86+
output=None,
87+
quiet=False,
88+
proxy=None,
89+
speed=None,
90+
use_cookies=True,
91+
verify=True,
92+
id=None,
93+
fuzzy=False,
94+
resume=False,
95+
unzip=True,
96+
overwrite=False,
97+
subfolder=False,
98+
):
99+
"""Download a file from URL, including Google Drive shared URL.
100+
101+
Args:
102+
url (str, optional): Google Drive URL is also supported. Defaults to None.
103+
output (str, optional): Output filename. Default is basename of URL.
104+
quiet (bool, optional): Suppress terminal output. Default is False.
105+
proxy (str, optional): Proxy. Defaults to None.
106+
speed (float, optional): Download byte size per second (e.g., 256KB/s = 256 * 1024). Defaults to None.
107+
use_cookies (bool, optional): Flag to use cookies. Defaults to True.
108+
verify (bool | str, optional): Either a bool, in which case it controls whether the server's TLS certificate is verified, or a string,
109+
in which case it must be a path to a CA bundle to use. Default is True.. Defaults to True.
110+
id (str, optional): Google Drive's file ID. Defaults to None.
111+
fuzzy (bool, optional): Fuzzy extraction of Google Drive's file Id. Defaults to False.
112+
resume (bool, optional): Resume the download from existing tmp file if possible. Defaults to False.
113+
unzip (bool, optional): Unzip the file. Defaults to True.
114+
overwrite (bool, optional): Overwrite the file if it already exists. Defaults to False.
115+
subfolder (bool, optional): Create a subfolder with the same name as the file. Defaults to False.
116+
117+
Returns:
118+
str: The output file path.
119+
"""
120+
import zipfile
121+
try:
122+
import gdown
123+
except ImportError:
124+
print(
125+
"The gdown package is required for this function. Use `pip install gdown` to install it."
126+
)
127+
return
128+
129+
if output is None:
130+
if isinstance(url, str) and url.startswith("http"):
131+
output = os.path.basename(url)
132+
133+
out_dir = os.path.abspath(os.path.dirname(output))
134+
if not os.path.exists(out_dir):
135+
os.makedirs(out_dir)
136+
137+
if isinstance(url, str):
138+
if os.path.exists(os.path.abspath(output)) and (not overwrite):
139+
print(
140+
f"{output} already exists. Skip downloading. Set overwrite=True to overwrite."
141+
)
142+
return os.path.abspath(output)
143+
else:
144+
url = github_raw_url(url)
145+
146+
if "https://drive.google.com/file/d/" in url:
147+
fuzzy = True
148+
149+
output = gdown.download(
150+
url, output, quiet, proxy, speed, use_cookies, verify, id, fuzzy, resume
151+
)
152+
153+
if unzip and output.endswith(".zip"):
154+
with zipfile.ZipFile(output, "r") as zip_ref:
155+
if not quiet:
156+
print("Extracting files...")
157+
if subfolder:
158+
basename = os.path.splitext(os.path.basename(output))[0]
159+
160+
output = os.path.join(out_dir, basename)
161+
if not os.path.exists(output):
162+
os.makedirs(output)
163+
zip_ref.extractall(output)
164+
else:
165+
zip_ref.extractall(os.path.dirname(output))
166+
167+
return os.path.abspath(output)
168+
169+
170+
def download_checkpoint(url=None, output=None, overwrite=False, **kwargs):
171+
"""Download a checkpoint from URL. It can be one of the following: sam_vit_h_4b8939.pth, sam_vit_l_0b3195.pth, sam_vit_b_01ec64.pth.
172+
173+
Args:
174+
url (str, optional): The checkpoint URL. Defaults to None.
175+
output (str, optional): The output file path. Defaults to None.
176+
overwrite (bool, optional): Overwrite the file if it already exists. Defaults to False.
177+
178+
Returns:
179+
str: The output file path.
180+
"""
181+
checkpoints = {
182+
'sam_vit_h_4b8939.pth': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
183+
'sam_vit_l_0b3195.pth': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
184+
'sam_vit_b_01ec64.pth': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
185+
}
186+
187+
if isinstance(url, str) and url in checkpoints:
188+
url = checkpoints[url]
189+
190+
if url is None:
191+
url = checkpoints['sam_vit_h_4b8939.pth']
192+
193+
if output is None:
194+
output = os.path.basename(url)
195+
196+
return download_file(url, output,overwrite=overwrite, **kwargs)
197+
198+
69199
def image_to_cog(source, dst_path=None, profile="deflate", **kwargs):
70200
"""Converts an image to a COG file.
71201

samgeo/samgeo.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
The source code is adapted from https://github.com/aliaksandr960/segment-anything-eo. Credit to the author Aliaksandr Hancharenka.
33
"""
44

5-
5+
import os
66
import numpy as np
77
import cv2
88
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
@@ -41,6 +41,12 @@ def __init__(
4141
mask_multiplier=255,
4242
sam_kwargs=None,
4343
):
44+
45+
if not os.path.exists(checkpoint):
46+
print(f'Checkpoint {checkpoint} does not exist.')
47+
download_checkpoint(output=checkpoint)
48+
49+
4450
self.checkpoint = checkpoint
4551
self.model_type = model_type
4652
self.device = device

0 commit comments

Comments
 (0)