1- from typing import Any , Coroutine , Dict , Optional , List , TYPE_CHECKING
2- from aiconfig import ParameterizedModelParser , InferenceOptions , AIConfig
1+ from typing import Any , Dict , Literal , Optional , List , TYPE_CHECKING
2+ from aiconfig import ParameterizedModelParser , InferenceOptions
3+ from aiconfig .callback import CallbackEvent
4+ from pydantic import BaseModel
5+ import torch
6+ from aiconfig .schema import Prompt , Output , ExecuteResult , Attachment
37
4- from aiconfig .schema import Prompt , Output
5- from transformers import Pipeline
8+ from transformers import pipeline , Pipeline
69
710if TYPE_CHECKING :
811 from aiconfig import AIConfigRuntime
1114"""
1215
1316
14- class HuggingFaceAutomaticSpeechRecognition (ParameterizedModelParser ):
17+ class HuggingFaceAutomaticSpeechRecognitionTransformer (ParameterizedModelParser ):
1518 def __init__ (self ):
1619 """
1720 Returns:
@@ -24,24 +27,24 @@ def __init__(self):
2427 config.register_model_parser(parser)
2528 """
2629 super ().__init__ ()
27- self .generators : dict [str , Pipeline ] = {}
30+ self .pipelines : dict [str , Pipeline ] = {}
2831
2932 def id (self ) -> str :
3033 """
3134 Returns an identifier for the Model Parser
3235 """
33- return "HuggingFaceAutomaticSpeechRecognition "
36+ return "HuggingFaceAutomaticSpeechRecognitionTransformer "
3437
3538 async def serialize (
3639 self ,
3740 prompt_name : str ,
3841 data : Any ,
3942 ai_config : "AIConfigRuntime" ,
4043 parameters : Optional [Dict [str , Any ]] = None ,
41- ** completion_params ,
4244 ) -> List [Prompt ]:
4345 """
4446 Defines how a prompt and model inference settings get serialized in the .aiconfig.
47+ Assume input in the form of input(s) being passed into an already constructed pipeline.
4548
4649 Args:
4750 prompt (str): The prompt to be serialized.
@@ -52,14 +55,226 @@ async def serialize(
5255 Returns:
5356 str: Serialized representation of the prompt and inference settings.
5457 """
58+ raise NotImplementedError ("serialize is not implemented for HuggingFaceAutomaticSpeechRecognition" )
5559
5660 async def deserialize (
5761 self ,
5862 prompt : Prompt ,
59- aiconfig : "AIConfig " ,
63+ aiconfig : "AIConfigRuntime " ,
6064 params : Optional [Dict [str , Any ]] = {},
6165 ) -> Dict [str , Any ]:
62- pass
66+ await aiconfig .callback_manager .run_callbacks (CallbackEvent ("on_deserialize_start" , __name__ , {"prompt" : prompt , "params" : params }))
67+
68+ # Build Completion data
69+ model_settings = self .get_model_settings (prompt , aiconfig )
70+ [_pipeline_creation_params , unfiltered_completion_params ] = refine_pipeline_creation_params (model_settings )
71+ completion_data = refine_asr_completion_params (unfiltered_completion_params )
72+
73+ # ASR Pipeline supports input types of bytes, file path, and a dict containing raw sampled audio. Also supports multiple input
74+ # For now, support multiple or single uri's as input
75+ # TODO: Support or figure out if other input types are needed (base64, bytes), as well as the sampled audio dict
76+ # See api docs for more info:
77+ # - https://github.com/huggingface/transformers/blob/v4.36.1/src/transformers/pipelines/automatic_speech_recognition.py#L313-L317
78+ # - https://huggingface.co/docs/transformers/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline
79+ inputs = validate_and_retrieve_audio_from_attachments (prompt )
80+
81+ completion_data ["inputs" ] = inputs
82+
83+ await aiconfig .callback_manager .run_callbacks (CallbackEvent ("on_deserialize_complete" , __name__ , {"output" : completion_data }))
84+ return completion_data
6385
6486 async def run_inference (self , prompt : Prompt , aiconfig : "AIConfigRuntime" , options : InferenceOptions , parameters : Dict [str , Any ]) -> list [Output ]:
65- pass
87+ await aiconfig .callback_manager .run_callbacks (
88+ CallbackEvent (
89+ "on_run_start" ,
90+ __name__ ,
91+ {"prompt" : prompt , "options" : options , "parameters" : parameters },
92+ )
93+ )
94+
95+ model_settings = self .get_model_settings (prompt , aiconfig )
96+ [pipeline_creation_data , _ ] = refine_pipeline_creation_params (model_settings )
97+ model_name = aiconfig .get_model_name (prompt )
98+
99+ if isinstance (model_name , str ) and model_name not in self .pipelines :
100+ device = self ._get_device ()
101+ if pipeline_creation_data .get ("device" , None ) is None :
102+ pipeline_creation_data ["device" ] = device
103+ self .pipelines [model_name ] = pipeline (task = "automatic-speech-recognition" , ** pipeline_creation_data )
104+
105+ asr_pipeline = self .pipelines [model_name ]
106+ completion_data = await self .deserialize (prompt , aiconfig , parameters )
107+
108+ response = asr_pipeline (** completion_data )
109+
110+ # response is a list of text outputs. This can be tested by running an asr pipeline and noticing the outputs are a list of text.
111+ outputs = construct_outputs (response )
112+
113+ prompt .outputs = outputs
114+ await aiconfig .callback_manager .run_callbacks (CallbackEvent ("on_run_complete" , __name__ , {"result" : prompt .outputs }))
115+ return prompt .outputs
116+
117+ def _get_device (self ) -> str :
118+ if torch .cuda .is_available ():
119+ return "cuda"
120+ # Mps backend is not supported for all asr models. Seen when spinning up a default asr pipeline which uses facebook/wav2vec2-base-960h 55bb623
121+ return "cpu"
122+
123+ def get_output_text (
124+ self ,
125+ prompt : Prompt ,
126+ aiconfig : "AIConfigRuntime" ,
127+ output : Optional [Output ] = None ,
128+ ) -> str :
129+ if output is None :
130+ output = aiconfig .get_latest_output (prompt )
131+
132+ if output is None :
133+ return ""
134+
135+ # TODO (rossdanlm): Handle multiple outputs in list
136+ # https://github.com/lastmile-ai/aiconfig/issues/467
137+ if output .output_type == "execute_result" :
138+ output_data = output .data
139+ if isinstance (output_data , str ):
140+ return output_data
141+ return ""
142+
143+
def validate_attachment_type_is_audio(attachment: Attachment):
    """Raise ValueError unless the attachment declares an audio/* mime type."""
    if not hasattr(attachment, "mime_type"):
        raise ValueError("Attachment has no mime type. Specify the audio mimetype in the aiconfig")

    if not attachment.mime_type.startswith("audio/"):
        raise ValueError(f"Invalid attachment mimetype {attachment.mime_type}. Expected audio mimetype.")
150+
151+
def validate_and_retrieve_audio_from_attachments(prompt: Prompt) -> list[str]:
    """
    Retrieves the audio uri's from each attachment in the prompt input.

    Throws an exception if
    - attachment is not audio
    - attachment data is not a uri
    - no attachments are found
    - operation fails for any reason
    """
    attachments = getattr(prompt.input, "attachments", None)
    if not attachments:
        raise ValueError(f"No attachments found in input for prompt {prompt.name}. Please add an audio attachment to the prompt input.")

    audio_uris: list[str] = []
    for index, attachment in enumerate(attachments):
        validate_attachment_type_is_audio(attachment)

        if not isinstance(attachment.data, str):
            # See todo above, but for now only support uri's
            raise ValueError(f"Attachment #{index} data is not a uri. Please specify a uri for the audio attachment in prompt {prompt.name}.")

        audio_uris.append(attachment.data)

    return audio_uris
178+
179+
def refine_pipeline_creation_params(model_settings: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Refines the pipeline creation params for the HF automatic speech recognition api.
    Defers unsupported params as completion params, where they can get processed in
    `refine_asr_completion_params()`. The supported keys were found by looking at
    the HF Pipelines AutomaticSpeechRecognition API:
    https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline

    Note that this is not the same as the completion params, which are passed to
    the pipeline later to generate the transcription:
    https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline.__call__

    Args:
        model_settings (dict): Raw model settings from the prompt/config.

    Returns:
        list[dict]: [pipeline_creation_params, completion_params].

    TODO: Distinguish pipeline creation and refine completion https://github.com/lastmile-ai/aiconfig/issues/825 https://github.com/lastmile-ai/aiconfig/issues/824
    """

    supported_keys = {
        "model",
        "chunk_length_s",
        "decoder",
        "device",
        "framework",
        "feature_extractor",
        # BUG FIX: a missing comma previously fused these two entries into the
        # single key "stride_length_stokenizer", so neither was ever matched.
        "stride_length_s",
        "tokenizer",
    }

    pipeline_creation_params: Dict[str, Any] = {}
    completion_params: Dict[str, Any] = {}
    for key, value in model_settings.items():
        lowered = key.lower()
        if lowered in supported_keys:
            pipeline_creation_params[lowered] = value
        elif lowered == "kwargs" and isinstance(value, dict):
            # Flatten a nested kwargs dict into the completion params.
            completion_params.update(value)
        else:
            completion_params[lowered] = value

    return [pipeline_creation_params, completion_params]
217+
218+
def refine_asr_completion_params(unfiltered_completion_params: Dict[str, Any]) -> Dict[str, Any]:
    """
    Refines the ASR params for the HF asr generation api after a
    pipeline has been created via `refine_pipeline_creation_params`. Removes any
    unsupported params. The supported keys were found by looking at the HF asr
    API for asr pipelines:
    https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline

    Note that this is not the same as the pipeline completion params, which were passed
    earlier to generate the pipeline:
    https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline.__call__

    Note: This doesn't support base pipeline params like `num_workers`
    TODO: Figure out how to find which params are supported.

    TODO: Distinguish pipeline creation and refine completion
    https://github.com/lastmile-ai/aiconfig/issues/825
    https://github.com/lastmile-ai/aiconfig/issues/824
    """

    supported_keys = {
        # inputs
        "return_timestamps",
        "generate_kwargs",
        "max_new_tokens",
    }

    # Keep only recognized keys, normalized to lowercase.
    return {
        key.lower(): value
        for key, value in unfiltered_completion_params.items()
        if key.lower() in supported_keys
    }
252+
253+
def construct_outputs(response: list[Any]) -> list[Output]:
    """
    Constructs outputs from the response of the HF ASR pipeline.

    Response from pipeline could contain multiple outputs and time stamps. No docs
    found for this; behavior was determined empirically.

    Args:
        response: A result dict (or list of result dicts) from the pipeline call.

    Returns:
        list[Output]: One ExecuteResult per pipeline result.
    """
    if not isinstance(response, list):
        # response contains a single output. Found by testing variations of the asr pipeline
        response = [response]

    outputs: list[Output] = []
    for i, result in enumerate(response):
        # result is expected to be a dict containing the text output and timestamps
        # if specified. Could not find docs for this.
        is_dict = isinstance(result, dict)
        # BUG FIX: the isinstance check must come first — the original evaluated
        # `"text" in result` before checking the type, which raises TypeError for
        # non-container results; `result.get("chunks", ...)` was also unguarded
        # and raised AttributeError for non-dict results.
        text_output = result.get("text") if is_dict and "text" in result else result
        # may contain timestamps and chunks, for now pass result through as metadata
        metadata = {"result": result} if is_dict and result.get("chunks", False) else {}
        outputs.append(
            ExecuteResult(
                output_type="execute_result",
                data=text_output,
                execution_count=i,
                metadata=metadata,
            )
        )

    return outputs
0 commit comments