1- from typing import Any , Coroutine , Dict , Optional , List , TYPE_CHECKING
2- from aiconfig import ParameterizedModelParser , InferenceOptions , AIConfig
1+ from typing import Any , Dict , Literal , Optional , List , TYPE_CHECKING
2+ from aiconfig import ParameterizedModelParser , InferenceOptions
3+ from aiconfig .callback import CallbackEvent
4+ from pydantic import BaseModel
5+ import torch
6+ from aiconfig .schema import Prompt , Output , ExecuteResult , Attachment
37
4- from aiconfig .schema import Prompt , Output
5- from transformers import Pipeline
8+ from transformers import pipeline , Pipeline
69
710if TYPE_CHECKING :
811 from aiconfig import AIConfigRuntime
@@ -24,7 +27,7 @@ def __init__(self):
2427 config.register_model_parser(parser)
2528 """
2629 super ().__init__ ()
27- self .generators : dict [str , Pipeline ] = {}
30+ self .pipelines : dict [str , Pipeline ] = {}
2831
2932 def id (self ) -> str :
3033 """
@@ -38,10 +41,10 @@ async def serialize(
3841 data : Any ,
3942 ai_config : "AIConfigRuntime" ,
4043 parameters : Optional [Dict [str , Any ]] = None ,
41- ** completion_params ,
4244 ) -> List [Prompt ]:
4345 """
4446 Defines how a prompt and model inference settings get serialized in the .aiconfig.
47+ Assume input in the form of input(s) being passed into an already constructed pipeline.
4548
4649 Args:
4750 prompt (str): The prompt to be serialized.
@@ -52,14 +55,224 @@ async def serialize(
5255 Returns:
5356 str: Serialized representation of the prompt and inference settings.
5457 """
58+ raise NotImplementedError ("serialize is not implemented for HuggingFaceAutomaticSpeechRecognition" )
5559
5660 async def deserialize (
5761 self ,
5862 prompt : Prompt ,
59- aiconfig : "AIConfig " ,
63+ aiconfig : "AIConfigRuntime " ,
6064 params : Optional [Dict [str , Any ]] = {},
6165 ) -> Dict [str , Any ]:
62- pass
66+ await aiconfig .callback_manager .run_callbacks (CallbackEvent ("on_deserialize_start" , __name__ , {"prompt" : prompt , "params" : params }))
67+
68+ # Build Completion data
69+ model_settings = self .get_model_settings (prompt , aiconfig )
70+ [_pipeline_creation_params , unfiltered_completion_params ] = refine_pipeline_creation_params (model_settings )
71+ completion_data = refine_asr_completion_params (unfiltered_completion_params )
72+
73+ # ASR Pipeline supports input types of bytes, file path, and a dict containing raw sampled audio. Also supports multiple input
74+ # For now, support multiple or single uri's as input
75+ # TODO: Support or figure out if other input types are needed (base64, bytes), as well as the sampled audio dict
76+ # See api docs for more info:
77+ # - https://github.com/huggingface/transformers/blob/v4.36.1/src/transformers/pipelines/automatic_speech_recognition.py#L313-L317
78+ # - https://huggingface.co/docs/transformers/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline
79+ inputs = validate_and_retrieve_audio_from_attachments (prompt )
80+
81+ completion_data ["inputs" ] = inputs
82+
83+ await aiconfig .callback_manager .run_callbacks (CallbackEvent ("on_deserialize_complete" , __name__ , {"output" : completion_data }))
84+ return completion_data
6385
6486 async def run_inference (self , prompt : Prompt , aiconfig : "AIConfigRuntime" , options : InferenceOptions , parameters : Dict [str , Any ]) -> list [Output ]:
65- pass
87+ await aiconfig .callback_manager .run_callbacks (
88+ CallbackEvent (
89+ "on_run_start" ,
90+ __name__ ,
91+ {"prompt" : prompt , "options" : options , "parameters" : parameters },
92+ )
93+ )
94+
95+ model_settings = self .get_model_settings (prompt , aiconfig )
96+ [pipeline_creation_data , _ ] = refine_pipeline_creation_params (model_settings )
97+ model_name = aiconfig .get_model_name (prompt )
98+
99+ if isinstance (model_name , str ) and model_name not in self .pipelines :
100+ device = self ._get_device ()
101+ if pipeline_creation_data .get ("device" , None ) is None :
102+ pipeline_creation_data ["device" ] = device
103+ self .pipelines [model_name ] = pipeline (task = "automatic-speech-recognition" , ** pipeline_creation_data )
104+
105+ asr_pipeline = self .pipelines [model_name ]
106+ completion_data = await self .deserialize (prompt , aiconfig , parameters )
107+
108+ response = asr_pipeline (** completion_data )
109+
110+ # response is a list of text outputs. This can be tested by running an asr pipeline and noticing the outputs are a list of text.
111+ outputs = construct_outputs (response )
112+
113+ prompt .outputs = outputs
114+ await aiconfig .callback_manager .run_callbacks (CallbackEvent ("on_run_complete" , __name__ , {"result" : prompt .outputs }))
115+ return prompt .outputs
116+
117+ def _get_device (self ) -> str :
118+ if torch .cuda .is_available ():
119+ return "cuda"
120+ # Mps backend is not supported for all asr models. Seen when spinning up a default asr pipeline which uses facebook/wav2vec2-base-960h 55bb623
121+ return "cpu"
122+
123+ def get_output_text (self , response : dict [str , Any ]) -> str :
124+ raise NotImplementedError ("get_output_text is not implemented for HuggingFaceAutomaticSpeechRecognition" )
125+
126+
def validate_attachment_type_is_audio(attachment: "Attachment"):
    """
    Validate that `attachment` carries an audio mime type.

    Raises:
        ValueError: if the attachment has no usable mime type (missing, None, or
            empty), or if the mime type is not an `audio/*` type.
    """
    # getattr covers both a missing attribute and a present-but-falsy value.
    # Previously a mime_type of None slipped past hasattr() and crashed with an
    # AttributeError on .startswith() instead of the intended ValueError.
    mime_type = getattr(attachment, "mime_type", None)
    if not mime_type:
        raise ValueError("Attachment has no mime type. Specify the audio mimetype in the aiconfig")

    if not mime_type.startswith("audio/"):
        raise ValueError(f"Invalid attachment mimetype {attachment.mime_type}. Expected audio mimetype.")
133+
134+
def validate_and_retrieve_audio_from_attachments(prompt: Prompt) -> list[str]:
    """
    Retrieves the audio uri's from each attachment in the prompt input.

    Throws an exception if
        - attachment is not audio
        - attachment data is not a uri
        - no attachments are found
        - operation fails for any reason
    """
    # Treat a missing attribute, None, and an empty list identically. Previously a
    # None `attachments` passed the hasattr() check and crashed in len().
    attachments = getattr(prompt.input, "attachments", None)
    if not attachments:
        raise ValueError(f"No attachments found in input for prompt {prompt.name}. Please add an audio attachment to the prompt input.")

    audio_uris: list[str] = []

    for i, attachment in enumerate(attachments):
        validate_attachment_type_is_audio(attachment)

        try:
            # load the attachment data into `AttachmentInputDataWithStringValue`
            # TODO: @ankush-lastmile update AIConfig sdk schema to include this type (can also remove this try catch once done)
            _ = AttachmentInputDataWithStringValue(**attachment.data)
        except Exception as e:
            # pylint: disable=W0707
            raise ValueError(
                f"Attachment #{i} for prompt {prompt.name} is not a valid attachment for Automatic Speech Recognition model parser. "
                f"Error: {str(e)}. Please specify a kind and value for the attachment data."
            )
        # TODO: once previous todo gets resolved, modify this line to `.kind` instead of .data.get("kind"). It will be a pydantic base class.
        if not attachment.data.get("kind") == "file_uri":
            # See todo above, but for now only support uri's
            raise ValueError(f"Attachment #{i} data is not a uri. Please pass in a uri for the audio attachment in prompt {prompt.name}.")
        # TODO: once previous todo gets resolved, modify this line to `.value` instead of .data.get("value"). It will be a pydantic base class.
        audio_uris.append(attachment.data.get("value"))

    return audio_uris
172+
173+
def refine_pipeline_creation_params(model_settings: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Refines the pipeline creation params for the HF automatic speech recognition api.
    Defers unsupported params as completion params, where they can get processed in
    `refine_asr_completion_params()`. The supported keys were found by looking at
    the HF Pipelines AutomaticSpeechRecognition API:
    https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline

    Note that this is not the same as the completion params, which are passed to
    the pipeline later when it is called:
    https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline.__call__

    Returns:
        A two-element list: [pipeline_creation_params, completion_params], both
        keyed by lowercased setting names.
    """
    supported_keys = {
        "model",
        "chunk_length_s",
        "decoder",
        "device",
        "framework",
        "feature_extractor",
        # Bug fix: a missing comma previously fused these two into a single bogus
        # key "stride_length_stokenizer", misrouting both to completion params.
        "stride_length_s",
        "tokenizer",
    }

    pipeline_creation_params: Dict[str, Any] = {}
    completion_params: Dict[str, Any] = {}
    for key, value in model_settings.items():
        normalized = key.lower()
        if normalized in supported_keys:
            pipeline_creation_params[normalized] = value
        elif normalized == "kwargs" and isinstance(value, Dict):
            # Flatten a nested "kwargs" dict directly into the completion params.
            completion_params.update(value)
        else:
            completion_params[normalized] = value

    return [pipeline_creation_params, completion_params]
209+
210+
def refine_asr_completion_params(unfiltered_completion_params: Dict[str, Any]) -> Dict[str, Any]:
    """
    Filters the ASR completion params down to those the pipeline accepts when
    called, after the pipeline itself has been created via
    `refine_pipeline_creation_params`. Unsupported params are silently dropped.
    The supported keys were found by looking at the HF asr pipeline call API:
    https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline.__call__

    Note that this is not the same as the pipeline creation params, which were
    passed earlier to construct the pipeline:
    https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline

    Note: This doesn't support base pipeline params like `num_workers`
    TODO: Figure out how to find which params are supported.
    """
    supported_keys = {
        "return_timestamps",
        "generate_kwargs",
        "max_new_tokens",
    }

    # Keys are normalized to lowercase before both the membership test and storage.
    return {key.lower(): value for key, value in unfiltered_completion_params.items() if key.lower() in supported_keys}
240+
241+
def construct_outputs(response: list[Any]) -> list[Output]:
    """
    Constructs ExecuteResult outputs from the response of the HF ASR pipeline.

    Response from pipeline could contain multiple outputs and time stamps.
    """
    # A single result comes back bare rather than wrapped in a list; normalize it.
    # Found by testing variations of the asr pipeline.
    results = response if isinstance(response, list) else [response]

    outputs: list[Output] = []
    for index, result in enumerate(results):
        # Each result is expected to be a dict containing the text output and
        # timestamps if specified. Could not find docs for this.
        outputs.append(
            ExecuteResult(
                output_type="execute_result",
                data=result.get("text"),
                execution_count=index,
                # may contain timestamps and chunks, for now pass result
                metadata={"result": result} if result.get("chunks", False) else {},
            )
        )

    return outputs
268+
269+
class AttachmentInputDataWithStringValue(BaseModel):
    """
    Represents attachment input data that is stored as a string. The `kind`
    field here, together with the attachment's `mime_type`, determines how the
    string `value` is converted into the desired output format.
    """

    kind: Literal["file_uri", "base64"]
    value: str
0 commit comments