# Transformers / Sentence Transformers utilities. These imports are heavy (they consume a lot
# of memory), so this module should be imported as late as possible to reduce a worker's
# footprint. We deliberately skip un-caching / garbage collection here because late imports are
# combined with idle unloading: an unused gunicorn worker simply terminates itself, freeing the
# memory as desired.
from pathlib import Path
from typing import Optional, Union

from huggingface_hub import HfApi, login, snapshot_download

from transformers import WhisperForConditionalGeneration, pipeline
from transformers.file_utils import is_tf_available, is_torch_available
from transformers.pipelines import Pipeline

from huggingface_inference_toolkit.diffusers_utils import (
    get_diffusers_pipeline,
    is_diffusers_available,
)
from huggingface_inference_toolkit.logging import logger
from huggingface_inference_toolkit.optimum_utils import (
    get_optimum_neuron_pipeline,
    is_optimum_neuron_available,
)
from huggingface_inference_toolkit.sentence_transformers_utils import (
    get_sentence_transformers_pipeline,
    is_sentence_transformers_available,
)
from huggingface_inference_toolkit.utils import create_artifact_filter

def load_repository_from_hf(
    repository_id: Optional[str] = None,
    target_dir: Optional[Union[str, Path]] = None,
    framework: Optional[str] = None,
    revision: Optional[str] = None,
    hf_hub_token: Optional[str] = None,
):
    """
    Download a model repository from the Hugging Face Hub into `target_dir`.
    """
    if repository_id is None or target_dir is None:
        raise ValueError("Both `repository_id` and `target_dir` must be provided.")

    if hf_hub_token is not None:
        login(token=hf_hub_token)

    if framework is None:
        framework = _get_framework()

    if isinstance(target_dir, str):
        target_dir = Path(target_dir)

    # create the working directory if it does not exist yet
    if not target_dir.exists():
        target_dir.mkdir(parents=True)

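    # safetensors weights load faster than pickled ".bin" weights and cannot
    # execute arbitrary code when loaded, so they are preferred when present: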
    # check if safetensors weights are available
    if framework == "pytorch":
        files = HfApi().model_info(repository_id).siblings
        if any(f.rfilename.endswith("safetensors") for f in files):
            framework = "safetensors"

    # create ignore-patterns so that only the framework-specific weights are downloaded
    ignore_regex = create_artifact_filter(framework)
    logger.info(f"Ignore patterns for files that will not be downloaded: {', '.join(ignore_regex)}")

    # Download the repository to the workdir and filter out non-framework
    # specific weights
    snapshot_download(
        repo_id=repository_id,
        revision=revision,
        local_dir=str(target_dir),
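        # NOTE: local_dir_use_symlinks is deprecated in recent huggingface_hub
        # releases (the option is ignored there); kept for older versions.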
        local_dir_use_symlinks=False,
        ignore_patterns=ignore_regex,
    )

    return target_dir
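

# A minimal usage sketch (the repository id and target path below are illustrative
# values, not defaults of this toolkit):
#
#     storage_dir = load_repository_from_hf(
#         repository_id="distilbert-base-uncased",
#         target_dir="/tmp/model",
#     )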


def get_device():
    """
    Return the pipeline device index for the DL framework: 0 for the first GPU, -1 for CPU.
    """
    gpu = _is_gpu_available()

    if gpu:
        return 0
    else:
        return -1


if is_tf_available():
    import tensorflow as tf

if is_torch_available():
    import torch


def _is_gpu_available():
    """
    Check whether a GPU is available for the installed DL framework.
    """
    if is_tf_available():
        return len(tf.config.list_physical_devices("GPU")) > 0
    elif is_torch_available():
        return torch.cuda.is_available()
    else:
        raise RuntimeError(
            "At least one of TensorFlow 2.0 or PyTorch should be installed. "
            "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
            "To install PyTorch, read the instructions at https://pytorch.org/."
        )


def _get_framework():
    """
    Determine which DL framework to use for inference; if both are installed, PyTorch is preferred.
    """
    if is_torch_available():
        return "pytorch"
    elif is_tf_available():
        return "tensorflow"
    else:
        raise RuntimeError(
            "At least one of TensorFlow 2.0 or PyTorch should be installed. "
            "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
            "To install PyTorch, read the instructions at https://pytorch.org/."
        )


def get_pipeline(
    task: Union[str, None],
    model_dir: Path,
    **kwargs,
) -> Pipeline:
    """
    Create a pipeline for a specific task based on a locally saved model.
    """
    if task is None:
        raise EnvironmentError(
            "The task for this model is not set: Please set one: https://huggingface.co/docs#how-is-a-models-type-of-inference-api-and-widget-determined"
        )

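    # the dedicated "conversational" pipeline task was deprecated in transformers;
    # conversational models are routed through "text-generation" instead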
    if task == "conversational":
        task = "text-generation"

    if is_optimum_neuron_available():
        logger.info("Using device Neuron")
        return get_optimum_neuron_pipeline(task=task, model_dir=model_dir)

    device = get_device()
    logger.info(f"Using device {'GPU' if device == 0 else 'CPU'}")

    # pass the model directory as tokenizer or feature extractor so that the
    # pipeline is loaded correctly
    if task in {
        "automatic-speech-recognition",
        "image-segmentation",
        "image-classification",
        "audio-classification",
        "object-detection",
        "zero-shot-image-classification",
    }:
        kwargs["feature_extractor"] = model_dir
    elif task not in {"image-text-to-text", "image-to-text", "text-to-image"}:
        kwargs["tokenizer"] = model_dir

    if is_sentence_transformers_available() and task in [
        "sentence-similarity",
        "sentence-embeddings",
        "sentence-ranking",
    ]:
        hf_pipeline = get_sentence_transformers_pipeline(task=task, model_dir=model_dir, device=device, **kwargs)
    elif is_diffusers_available() and task == "text-to-image":
        hf_pipeline = get_diffusers_pipeline(task=task, model_dir=model_dir, device=device, **kwargs)
    else:
        hf_pipeline = pipeline(task=task, model=model_dir, device=device, **kwargs)

    if task == "automatic-speech-recognition" and isinstance(hf_pipeline.model, WhisperForConditionalGeneration):
        # set the chunk length to 30s for Whisper so long audio files can be transcribed
        hf_pipeline._preprocess_params["chunk_length_s"] = 30
        hf_pipeline.model.config.forced_decoder_ids = hf_pipeline.tokenizer.get_decoder_prompt_ids(
            language="english", task="transcribe"
        )
    return hf_pipeline  # type: ignore
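

# A minimal usage sketch (illustrative ids/paths; assumes the model was first
# downloaded with load_repository_from_hf):
#
#     model_dir = load_repository_from_hf(
#         repository_id="distilbert-base-uncased-finetuned-sst-2-english",
#         target_dir="/tmp/model",
#     )
#     classifier = get_pipeline(task="text-classification", model_dir=model_dir)
#     classifier("I love this!")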