1- from typing import Any , Coroutine , Dict , Optional , List , TYPE_CHECKING
2- from aiconfig import ParameterizedModelParser , InferenceOptions , AIConfig
1+ from typing import Any , Dict , Literal , Optional , List , TYPE_CHECKING
2+ from aiconfig import ParameterizedModelParser , InferenceOptions
3+ from aiconfig .callback import CallbackEvent
4+ from pydantic import BaseModel
5+ import torch
6+ from aiconfig .schema import Prompt , Output , ExecuteResult , Attachment
37
4- from aiconfig .schema import Prompt , Output
5- from transformers import Pipeline
8+ from transformers import pipeline , Pipeline
69
710if TYPE_CHECKING :
811 from aiconfig import AIConfigRuntime
@@ -24,7 +27,7 @@ def __init__(self):
2427 config.register_model_parser(parser)
2528 """
2629 super ().__init__ ()
27- self .generators : dict [str , Pipeline ] = {}
30+ self .pipelines : dict [str , Pipeline ] = {}
2831
2932 def id (self ) -> str :
3033 """
@@ -38,10 +41,10 @@ async def serialize(
3841 data : Any ,
3942 ai_config : "AIConfigRuntime" ,
4043 parameters : Optional [Dict [str , Any ]] = None ,
41- ** completion_params ,
4244 ) -> List [Prompt ]:
4345 """
4446 Defines how a prompt and model inference settings get serialized in the .aiconfig.
47+ Assume input in the form of input(s) being passed into an already constructed pipeline.
4548
4649 Args:
4750 prompt (str): The prompt to be serialized.
@@ -52,14 +55,203 @@ async def serialize(
5255 Returns:
5356 str: Serialized representation of the prompt and inference settings.
5457 """
58+ raise NotImplementedError ("serialize is not implemented for HuggingFaceAutomaticSpeechRecognition" )
5559
5660 async def deserialize (
5761 self ,
5862 prompt : Prompt ,
59- aiconfig : "AIConfig " ,
63+ aiconfig : "AIConfigRuntime " ,
6064 params : Optional [Dict [str , Any ]] = {},
6165 ) -> Dict [str , Any ]:
62- pass
66+ await aiconfig .callback_manager .run_callbacks (CallbackEvent ("on_deserialize_start" , __name__ , {"prompt" : prompt , "params" : params }))
67+
68+ # Build Completion data
69+ model_settings = self .get_model_settings (prompt , aiconfig )
70+ [_pipeline_creation_params , unfiltered_completion_params ] = refine_pipeline_creation_params (model_settings )
71+ completion_data = refine_asr_completion_params (unfiltered_completion_params )
72+
73+ # ASR Pipeline supports input types of bytes, file path, and a dict containing raw sampled audio. Also supports multiple input
74+ # For now, support multiple or single uri's as input
75+ # TODO: Support or figure out if other input types are needed (base64, bytes), as well as the sampled audio dict
76+ # See api docs for more info:
77+ # - https://github.com/huggingface/transformers/blob/v4.36.1/src/transformers/pipelines/automatic_speech_recognition.py#L313-L317
78+ # - https://huggingface.co/docs/transformers/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline
79+ inputs = validate_and_retrieve_audio_from_attachments (prompt )
80+
81+ completion_data ["inputs" ] = inputs
82+
83+ await aiconfig .callback_manager .run_callbacks (CallbackEvent ("on_deserialize_complete" , __name__ , {"output" : completion_data }))
84+ return completion_data
6385
6486 async def run_inference (self , prompt : Prompt , aiconfig : "AIConfigRuntime" , options : InferenceOptions , parameters : Dict [str , Any ]) -> list [Output ]:
65- pass
87+ await aiconfig .callback_manager .run_callbacks (
88+ CallbackEvent (
89+ "on_run_start" ,
90+ __name__ ,
91+ {"prompt" : prompt , "options" : options , "parameters" : parameters },
92+ )
93+ )
94+
95+ model_settings = self .get_model_settings (prompt , aiconfig )
96+ [pipeline_creation_data , _ ] = refine_pipeline_creation_params (model_settings )
97+ model_name = aiconfig .get_model_name (prompt )
98+
99+ if isinstance (model_name , str ) and model_name not in self .pipelines :
100+ device = self ._get_device ()
101+ if pipeline_creation_data .get ("device" , None ) is None :
102+ pipeline_creation_data ["device" ] = device
103+ self .pipelines [model_name ] = pipeline (task = "automatic-speech-recognition" , ** pipeline_creation_data )
104+
105+ asr_pipeline = self .pipelines [model_name ]
106+ completion_data = await self .deserialize (prompt , aiconfig , parameters )
107+
108+ response = asr_pipeline (** completion_data )
109+
110+ # response is a list of text outputs. This can be tested by running an asr pipeline and noticing the outputs are a list of text.
111+ outputs = construct_outputs (response )
112+
113+ prompt .outputs = outputs
114+ await aiconfig .callback_manager .run_callbacks (CallbackEvent ("on_run_complete" , __name__ , {"result" : prompt .outputs }))
115+ return prompt .outputs
116+
117+ def _get_device (self ) -> str :
118+ if torch .cuda .is_available ():
119+ return "cuda"
120+ # Mps backend is not supported for all asr models. Seen when spinning up a default asr pipeline which uses facebook/wav2vec2-base-960h 55bb623
121+ return "cpu"
122+
123+ def get_output_text (self , response : dict [str , Any ]) -> str :
124+ raise NotImplementedError ("get_output_text is not implemented for HuggingFaceAutomaticSpeechRecognition" )
125+
126+
def validate_attachment_type_is_audio(attachment: "Attachment"):
    """
    Ensure the attachment declares an audio/* mime type.

    Raises:
        ValueError: If the attachment has no (or an empty/None) mime type, or if
            the mime type is not an audio one.
    """
    # getattr-with-default also rejects a present-but-None mime_type field, which
    # the previous hasattr() check let through (pydantic models define the field
    # even when it is unset), crashing later on .startswith with AttributeError.
    mime_type = getattr(attachment, "mime_type", None)
    if not mime_type:
        raise ValueError("Attachment has no mime type. Specify the audio mimetype in the aiconfig")

    if not mime_type.startswith("audio/"):
        raise ValueError(f"Invalid attachment mimetype {mime_type}. Expected audio mimetype.")
133+
134+
def validate_and_retrieve_audio_from_attachments(prompt: Prompt) -> list[str]:
    """
    Collect the audio uri of every attachment on the prompt input.

    Throws an exception if
    - attachment is not audio
    - attachment data is not a uri
    - no attachments are found
    - operation fails for any reason
    """
    if not hasattr(prompt.input, "attachments") or len(prompt.input.attachments) == 0:
        raise ValueError(f"No attachments found in input for prompt {prompt.name}. Please add an audio attachment to the prompt input.")

    audio_uris: list[str] = []
    for index, attachment in enumerate(prompt.input.attachments):
        validate_attachment_type_is_audio(attachment)

        # See todo above, but for now only support uri's
        if not isinstance(attachment.data, str):
            raise ValueError(f"Attachment #{index} data is not a uri. Please specify a uri for the audio attachment in prompt {prompt.name}.")

        audio_uris.append(attachment.data)

    return audio_uris
162+
163+
def refine_pipeline_creation_params(model_settings: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Refines the pipeline creation params for the HF automatic speech recognition api.
    Defers unsupported params as completion params, where they can get processed in
    `refine_asr_completion_params()`. The supported keys were found by looking at
    the HF Pipelines AutomaticSpeechRecognition API:
    https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline

    Note that this is not the same as the asr completion params, which are passed to
    the pipeline later to perform the transcription:
    https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline.__call__

    Returns:
        [pipeline_creation_params, completion_params]
    """
    supported_keys = {
        "model",
        "chunk_length_s",
        "decoder",
        "device",
        "framework",
        "feature_extractor",
        # BUG FIX: a missing comma previously fused these two entries into the
        # single bogus key "stride_length_stokenizer", so stride_length_s and
        # tokenizer settings were misrouted to completion params and dropped.
        "stride_length_s",
        "tokenizer",
    }

    pipeline_creation_params: Dict[str, Any] = {}
    completion_params: Dict[str, Any] = {}
    for key in model_settings:
        lowered = key.lower()
        if lowered in supported_keys:
            pipeline_creation_params[lowered] = model_settings[key]
        elif lowered == "kwargs" and isinstance(model_settings[key], Dict):
            # A nested "kwargs" dict is flattened into the completion params.
            completion_params.update(model_settings[key])
        else:
            completion_params[lowered] = model_settings[key]

    return [pipeline_creation_params, completion_params]
199+
200+
def refine_asr_completion_params(unfiltered_completion_params: Dict[str, Any]) -> Dict[str, Any]:
    """
    Refines the ASR params for the HF asr generation api after a
    pipeline has been created via `refine_pipeline_creation_params`. Removes any
    unsupported params. The supported keys were found by looking at the HF asr
    API for asr pipelines:
    https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline

    Note that this is not the same as the pipeline completion params, which were passed
    earlier to generate the pipeline:
    https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline.__call__

    Note: This doesn't support base pipeline params like `num_workers`
    TODO: Figure out how to find which params are supported.
    """
    supported_keys = {
        # inputs
        "return_timestamps",
        "generate_kwargs",
        "max_new_tokens",
    }

    # Keep only the supported keys, normalizing each to lowercase.
    return {
        key.lower(): value
        for key, value in unfiltered_completion_params.items()
        if key.lower() in supported_keys
    }
230+
231+
def construct_outputs(response: list[Any]) -> list[Output]:
    """
    Constructs outputs from the response of the HF ASR pipeline.

    Response from pipeline could contain multiple outputs and time stamps
    """
    # A lone result means the pipeline ran on a single input; normalize to a list.
    # response is expected to be a dict containing the text output and timestamps if specified. Could not find docs for this.
    results = response if isinstance(response, list) else [response]

    outputs: list[Output] = []
    for execution_count, result in enumerate(results):
        outputs.append(
            ExecuteResult(
                **{
                    "output_type": "execute_result",
                    "data": result.get("text"),
                    "execution_count": execution_count,
                    # may contain timestamps and chunks, for now pass result
                    "metadata": {"result": result} if result.get("chunks", False) else {},
                }
            )
        )

    return outputs
0 commit comments