@@ -140,6 +140,7 @@ def _load_repository_from_hf(
140140
141141 if framework is None :
142142 framework = _get_framework ()
143+ logging .info (f"Framework: { framework } " )
143144
144145 if isinstance (target_dir , str ):
145146 target_dir = Path (target_dir )
@@ -149,22 +150,24 @@ def _load_repository_from_hf(
149150 target_dir .mkdir (parents = True )
150151
151152 # check if safetensors weights are available
152- if framework == "pytorch" :
153- files = HfApi ().model_info (repository_id ).siblings
154- if any (f .rfilename .endswith ("safetensors" ) for f in files ):
155- framework = "safetensors"
153+ # if framework == "pytorch":
154+ # files = HfApi().model_info(repository_id).siblings
155+ # if any(f.rfilename.endswith("safetensors") for f in files):
156+ # framework = "safetensors"
156157
157158 # create regex to only include the framework specific weights
158159 ignore_regex = create_artifact_filter (framework )
160+ logging .info (f"ignore_regex: { ignore_regex } " )
161+ logging .info (f"Framework after filtering: { framework } " )
159162 logger .info (f"Ignore regex pattern for files, which are not downloaded: { ', ' .join (ignore_regex ) } " )
160163
161164 # Download the repository to the workdir and filter out non-framework specific weights
162165 snapshot_download (
163- repository_id ,
164- revision = revision ,
165- local_dir = str (target_dir ),
166- local_dir_use_symlinks = False ,
167- ignore_patterns = ignore_regex ,
166+ repo_id = repository_id ,
167+ revision = revision ,
168+ local_dir = str (target_dir ),
169+ local_dir_use_symlinks = False ,
170+ ignore_patterns = ignore_regex ,
168171 )
169172
170173 return target_dir
@@ -223,7 +226,12 @@ def get_device():
223226 return - 1
224227
225228
226- def get_pipeline (task : str , model_dir : Path , ** kwargs ) -> Pipeline :
229+ def get_pipeline (
230+ task : str ,
231+ model_dir : Path ,
232+ framework = "pytorch" ,
233+ ** kwargs ,
234+ ) -> Pipeline :
227235 """
228236 create pipeline class for a specific task based on local saved model
229237 """
@@ -244,6 +252,12 @@ def get_pipeline(task: str, model_dir: Path, **kwargs) -> Pipeline:
244252 "zero-shot-image-classification" ,
245253 }:
246254 kwargs ["feature_extractor" ] = model_dir
255+ hf_pipeline = pipeline (
256+ task = task ,
257+ model = model_dir ,
258+ device = device ,
259+ ** kwargs
260+ )
247261 elif task in {"image-to-text" }:
248262 pass
249263 else :
@@ -265,12 +279,20 @@ def get_pipeline(task: str, model_dir: Path, **kwargs) -> Pipeline:
265279 logging .info (f"Model: { model_dir } " )
266280 logging .info (f"Device: { device } " )
267281 logging .info (f"Args: { kwargs } " )
268- hf_pipeline = pipeline (task = task , model = model_dir , device = device , ** kwargs )
282+ hf_pipeline = pipeline (
283+ task = task ,
284+ model = model_dir ,
285+ device = device ,
286+ ** kwargs
287+ )
269288
270289 # wrap specific pipeline to support better UX
271290 if task == "conversational" :
272291 hf_pipeline = wrap_conversation_pipeline (hf_pipeline )
273- elif task == "automatic-speech-recognition" and isinstance (hf_pipeline .model , WhisperForConditionalGeneration ):
292+ elif task == "automatic-speech-recognition" and isinstance (
293+ hf_pipeline .model ,
294+ WhisperForConditionalGeneration
295+ ):
274296 # set chunk length to 30s for whisper to enable long audio files
275297 hf_pipeline ._preprocess_params ["chunk_length_s" ] = 30
276298 hf_pipeline ._preprocess_params ["ignore_warning" ] = True
0 commit comments