@@ -66,6 +66,136 @@ def temp_file_path(extension):
6666 return file_path
6767
6868
69+ def github_raw_url (url ):
70+ """Get the raw URL for a GitHub file.
71+
72+ Args:
73+ url (str): The GitHub URL.
74+ Returns:
75+ str: The raw URL.
76+ """
77+ if isinstance (url , str ) and url .startswith ("https://github.com/" ) and "blob" in url :
78+ url = url .replace ("github.com" , "raw.githubusercontent.com" ).replace (
79+ "blob/" , ""
80+ )
81+ return url
82+
83+
84+ def download_file (
85+ url = None ,
86+ output = None ,
87+ quiet = False ,
88+ proxy = None ,
89+ speed = None ,
90+ use_cookies = True ,
91+ verify = True ,
92+ id = None ,
93+ fuzzy = False ,
94+ resume = False ,
95+ unzip = True ,
96+ overwrite = False ,
97+ subfolder = False ,
98+ ):
99+ """Download a file from URL, including Google Drive shared URL.
100+
101+ Args:
102+ url (str, optional): Google Drive URL is also supported. Defaults to None.
103+ output (str, optional): Output filename. Default is basename of URL.
104+ quiet (bool, optional): Suppress terminal output. Default is False.
105+ proxy (str, optional): Proxy. Defaults to None.
106+ speed (float, optional): Download byte size per second (e.g., 256KB/s = 256 * 1024). Defaults to None.
107+ use_cookies (bool, optional): Flag to use cookies. Defaults to True.
108+ verify (bool | str, optional): Either a bool, in which case it controls whether the server's TLS certificate is verified, or a string,
109+ in which case it must be a path to a CA bundle to use. Default is True.. Defaults to True.
110+ id (str, optional): Google Drive's file ID. Defaults to None.
111+ fuzzy (bool, optional): Fuzzy extraction of Google Drive's file Id. Defaults to False.
112+ resume (bool, optional): Resume the download from existing tmp file if possible. Defaults to False.
113+ unzip (bool, optional): Unzip the file. Defaults to True.
114+ overwrite (bool, optional): Overwrite the file if it already exists. Defaults to False.
115+ subfolder (bool, optional): Create a subfolder with the same name as the file. Defaults to False.
116+
117+ Returns:
118+ str: The output file path.
119+ """
120+ import zipfile
121+ try :
122+ import gdown
123+ except ImportError :
124+ print (
125+ "The gdown package is required for this function. Use `pip install gdown` to install it."
126+ )
127+ return
128+
129+ if output is None :
130+ if isinstance (url , str ) and url .startswith ("http" ):
131+ output = os .path .basename (url )
132+
133+ out_dir = os .path .abspath (os .path .dirname (output ))
134+ if not os .path .exists (out_dir ):
135+ os .makedirs (out_dir )
136+
137+ if isinstance (url , str ):
138+ if os .path .exists (os .path .abspath (output )) and (not overwrite ):
139+ print (
140+ f"{ output } already exists. Skip downloading. Set overwrite=True to overwrite."
141+ )
142+ return os .path .abspath (output )
143+ else :
144+ url = github_raw_url (url )
145+
146+ if "https://drive.google.com/file/d/" in url :
147+ fuzzy = True
148+
149+ output = gdown .download (
150+ url , output , quiet , proxy , speed , use_cookies , verify , id , fuzzy , resume
151+ )
152+
153+ if unzip and output .endswith (".zip" ):
154+ with zipfile .ZipFile (output , "r" ) as zip_ref :
155+ if not quiet :
156+ print ("Extracting files..." )
157+ if subfolder :
158+ basename = os .path .splitext (os .path .basename (output ))[0 ]
159+
160+ output = os .path .join (out_dir , basename )
161+ if not os .path .exists (output ):
162+ os .makedirs (output )
163+ zip_ref .extractall (output )
164+ else :
165+ zip_ref .extractall (os .path .dirname (output ))
166+
167+ return os .path .abspath (output )
168+
169+
170+ def download_checkpoint (url = None , output = None , overwrite = False , ** kwargs ):
171+ """Download a checkpoint from URL. It can be one of the following: sam_vit_h_4b8939.pth, sam_vit_l_0b3195.pth, sam_vit_b_01ec64.pth.
172+
173+ Args:
174+ url (str, optional): The checkpoint URL. Defaults to None.
175+ output (str, optional): The output file path. Defaults to None.
176+ overwrite (bool, optional): Overwrite the file if it already exists. Defaults to False.
177+
178+ Returns:
179+ str: The output file path.
180+ """
181+ checkpoints = {
182+ 'sam_vit_h_4b8939.pth' : "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth" ,
183+ 'sam_vit_l_0b3195.pth' : "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth" ,
184+ 'sam_vit_b_01ec64.pth' : "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth" ,
185+ }
186+
187+ if isinstance (url , str ) and url in checkpoints :
188+ url = checkpoints [url ]
189+
190+ if url is None :
191+ url = checkpoints ['sam_vit_h_4b8939.pth' ]
192+
193+ if output is None :
194+ output = os .path .basename (url )
195+
196+ return download_file (url , output ,overwrite = overwrite , ** kwargs )
197+
198+
69199def image_to_cog (source , dst_path = None , profile = "deflate" , ** kwargs ):
70200 """Converts an image to a COG file.
71201
0 commit comments