77import json
88import os
99import typing
10+ from traceback import format_exc
1011from urllib .parse import urlparse
1112
1213import niquests
2223from starlette .requests import HTTPConnection , Request
2324from starlette .types import ASGIApp , Receive , Scope , Send
2425
26+ from .._exceptions import ModelFetchError
2527from .._misc import get_username_secret_from_headers
2628from ..nextcloud import AsyncNextcloudApp , NextcloudApp
2729from ..talk_bot import TalkBotMessage
28- from .defs import LogLvl
2930from .misc import persistent_storage
3031
3132
@@ -70,9 +71,24 @@ def set_handlers(
7071
7172 .. note:: When this parameter is ``False``, the provision of ``models_to_fetch`` is not allowed.
7273
73- :param models_to_fetch: Dictionary describing which models should be downloaded during `init`.
74+ :param models_to_fetch: Dictionary describing which models should be downloaded during `init` of the form:
75+ .. code-block:: python
76+ {
77+ "model_url_1": {
78+ "save_path": "path_or_filename_to_save_the_model_to",
79+ },
80+ "huggingface_model_name_1": {
81+ "max_workers": 4,
82+ "cache_dir": "path_to_cache_dir",
83+ "revision": "revision_to_fetch",
84+ ...
85+ },
86+ ...
87+ }
88+
7489
7590 .. note:: ``huggingface_hub`` package should be present for automatic models fetching.
91+ All model options are optional and can be left empty.
7692
7793 :param map_app_static: Should be folders ``js``, ``css``, ``l10n``, ``img`` automatically mounted in FastAPI or not.
7894
@@ -121,74 +137,98 @@ def __map_app_static_folders(fast_api_app: FastAPI):
121137
122138
def fetch_models_task(nc: NextcloudApp, models: dict[str, dict], progress_init_start_value: int) -> None:
    """Use for cases when you want to define a custom `/init` but still need to easily download models.

    Keys of ``models`` starting with ``http://`` or ``https://`` are downloaded as single files, every other key
    is fetched as a snapshot. After a successful download the returned location is stored in that
    model's options dictionary under the ``"path"`` key.

    :param nc: NextcloudApp instance.
    :param models: Dictionary describing which models should be downloaded of the form:

        .. code-block:: python

            {
                "model_url_1": {
                    "save_path": "path_or_filename_to_save_the_model_to",
                },
                "huggingface_model_name_1": {
                    "max_workers": 4,
                    "cache_dir": "path_to_cache_dir",
                    "revision": "revision_to_fetch",
                    ...
                },
                ...
            }

        .. note:: ``huggingface_hub`` package should be present for automatic models fetching.
            All model options are optional and can be left empty.

    :param progress_init_start_value: Integer value defining from which percent the progress should start.

    :raises ModelFetchError: in case of a model download error.
    :raises NextcloudException: in case of a network error reaching the Nextcloud server.
    """
    if models:
        current_progress = progress_init_start_value
        # Split the remaining progress range evenly between the models.
        percent_for_each = min(int((100 - progress_init_start_value) / len(models)), 99)
        for model, options in models.items():
            try:
                if model.startswith(("http://", "https://")):
                    options["path"] = __fetch_model_as_file(current_progress, percent_for_each, nc, model, options)
                else:
                    options["path"] = __fetch_model_as_snapshot(current_progress, percent_for_each, nc, model, options)
                current_progress += percent_for_each
            # BaseException is caught so that the failure status is reported for every kind of failure
            # before the error is re-raised as ModelFetchError (the original exception is chained).
            except BaseException as e:  # noqa pylint: disable=broad-exception-caught
                nc.set_init_status(current_progress, f"Downloading of '{model}' failed: {e}: {format_exc()}")
                raise ModelFetchError(f"Downloading of '{model}' failed.") from e
    nc.set_init_status(100)
139184
140185
def __fetch_model_as_file(
    current_progress: int, progress_for_task: int, nc: NextcloudApp, model_path: str, download_options: dict
) -> str:
    """Download a single model file from a URL and return the path it was saved to.

    The download is skipped when a file of the announced size already exists at the target path and
    its SHA-256 digest matches the ETag announced by the server.

    :param current_progress: Progress value (in percent) at which this download starts.
    :param progress_for_task: Share of the progress (in percent) assigned to this download.
    :param nc: NextcloudApp instance used to report the init progress.
    :param model_path: URL of the file to download.
    :param download_options: Options of this model; ``save_path`` is removed from it and used as the
        target path, defaulting to the last segment of the URL path.

    :returns: Path of the downloaded (or already present and verified) file.

    :raises ModelFetchError: when the server answers with a non-successful status code.
    """
    result_path = download_options.pop("save_path", urlparse(model_path).path.split("/")[-1])
    with niquests.get(model_path, stream=True) as response:
        if not response.ok:
            raise ModelFetchError(
                f"Downloading of '{model_path}' failed, returned ({response.status_code}) {response.text}"
            )
        downloaded_size = 0
        # Prefer the "X-Linked-ETag" of the first redirect that carries one, then fall back to the final response.
        linked_etag = ""
        for each_history in response.history:
            linked_etag = each_history.headers.get("X-Linked-ETag", "")
            if linked_etag:
                break
        if not linked_etag:
            linked_etag = response.headers.get("X-Linked-ETag", response.headers.get("ETag", ""))
        # "Content-Length" may be absent; treat that as an unknown size (0) instead of failing on int(None).
        total_size = int(response.headers.get("Content-Length") or 0)
        try:
            existing_size = os.path.getsize(result_path)
        except OSError:
            existing_size = 0
        # Verify the existing file only when the size is known: with an unknown size a missing file
        # would also report 0 and opening it for hashing would fail.
        if linked_etag and total_size and total_size == existing_size:
            with builtins.open(result_path, "rb") as file:
                sha256_hash = hashlib.sha256()
                for byte_block in iter(lambda: file.read(4096), b""):
                    sha256_hash.update(byte_block)
                if f'"{sha256_hash.hexdigest()}"' == linked_etag:
                    nc.set_init_status(min(current_progress + progress_for_task, 99))
                    return result_path

        with builtins.open(result_path, "wb") as file:
            last_progress = current_progress
            for chunk in response.iter_raw(-1):
                downloaded_size += file.write(chunk)
                if total_size:
                    new_progress = min(current_progress + int(progress_for_task * downloaded_size / total_size), 99)
                    # Report only when the integer percentage actually changes.
                    if new_progress != last_progress:
                        nc.set_init_status(new_progress)
                        last_progress = new_progress

    return result_path
188228
189229
190230def __fetch_model_as_snapshot (
191- current_progress : int , progress_for_task , nc : NextcloudApp , mode_name : str , download_options : dict
231+ current_progress : int , progress_for_task , nc : NextcloudApp , model_name : str , download_options : dict
192232) -> str :
193233 from huggingface_hub import snapshot_download # noqa isort:skip pylint: disable=C0415 disable=E0401
194234 from tqdm import tqdm # noqa isort:skip pylint: disable=C0415 disable=E0401
@@ -201,7 +241,7 @@ def display(self, msg=None, pos=None):
201241 workers = download_options .pop ("max_workers" , 2 )
202242 cache = download_options .pop ("cache_dir" , persistent_storage ())
203243 return snapshot_download (
204- mode_name , tqdm_class = TqdmProgress , ** download_options , max_workers = workers , cache_dir = cache
244+ model_name , tqdm_class = TqdmProgress , ** download_options , max_workers = workers , cache_dir = cache
205245 )
206246
207247
0 commit comments