44from pathlib import Path
55from typing import Optional , Union
66
7- from huggingface_hub import login , snapshot_download
7+ from huggingface_hub import HfApi , login , snapshot_download
88from transformers import WhisperForConditionalGeneration , pipeline
99from transformers .file_utils import is_tf_available , is_torch_available
1010from transformers .pipelines import Conversation , Pipeline
1111
1212from huggingface_inference_toolkit .const import HF_DEFAULT_PIPELINE_NAME , HF_MODULE_NAME
1313from huggingface_inference_toolkit .diffusers_utils import (
14- check_supported_pipeline ,
1514 get_diffusers_pipeline ,
1615 is_diffusers_available ,
1716)
@@ -46,11 +45,12 @@ def is_optimum_available():
4645 "pt" : "pytorch*" ,
4746 "flax" : "flax*" ,
4847 "rust" : "rust*" ,
49- "onnx" : "*onnx" ,
48+ "onnx" : "*onnx* " ,
5049 "safetensors" : "*safetensors" ,
5150 "coreml" : "*mlmodel" ,
5251 "tflite" : "*tflite" ,
5352 "savedmodel" : "*tar.gz" ,
53+ "openvino" : "*openvino*" ,
5454 "ckpt" : "*ckpt" ,
5555}
5656
@@ -59,18 +59,8 @@ def create_artifact_filter(framework):
5959 """
6060 Returns a list of regex pattern based on the DL Framework. which will be to used to ignore files when downloading
6161 """
62- ignore_regex_list = [
63- "pytorch*" ,
64- "tf*" ,
65- "flax*" ,
66- "rust*" ,
67- "*onnx" ,
68- "*safetensors" ,
69- "*mlmodel" ,
70- "*tflite" ,
71- "*tar.gz" ,
72- "*ckpt" ,
73- ]
62+ ignore_regex_list = list (framework2weight .values ())
63+
7464 pattern = framework2weight .get (framework , None )
7565 if pattern in ignore_regex_list :
7666 ignore_regex_list .remove (pattern )
@@ -157,6 +147,13 @@ def _load_repository_from_hf(
157147 if not target_dir .exists ():
158148 target_dir .mkdir (parents = True )
159149
150+ # check if safetensors weights are available
151+ if framework == "pytorch" :
152+ files = HfApi ().model_info (repository_id ).siblings
153+ if any (f .rfilename .endswith ("safetensors" ) for f in files ):
154+ framework = "safetensors"
155+
156+
160157 # create regex to only include the framework specific weights
161158 ignore_regex = create_artifact_filter (framework )
162159 logger .info (f"Ignore regex pattern for files, which are not downloaded: { ', ' .join (ignore_regex ) } " )
@@ -259,7 +256,7 @@ def get_pipeline(task: str, model_dir: Path, **kwargs) -> Pipeline:
259256 "sentence-ranking" ,
260257 ]:
261258 hf_pipeline = get_sentence_transformers_pipeline (task = task , model_dir = model_dir , device = device , ** kwargs )
262- elif is_diffusers_available () and check_supported_pipeline ( model_dir ) and task == "text-to-image" :
259+ elif is_diffusers_available () and task == "text-to-image" :
263260 hf_pipeline = get_diffusers_pipeline (task = task , model_dir = model_dir , device = device , ** kwargs )
264261 else :
265262 hf_pipeline = pipeline (task = task , model = model_dir , device = device , ** kwargs )
0 commit comments