Skip to content

Commit 6dc0b11

Browse files
author
Ankush Pala ankush@lastmileai.dev
committed
[extensions][py][hf] 2/n ASR model parser impl
Model Parser for the Automatic Speech Recognition task on huggingface. Decisions made while implementing: - manual impl to parse input attachments - - threw exceptions on every unexpected step. Not sure if this is the direction we want to go with this. - This diff does not implement serialize() for the model parser (will implement on diff ontop) ## Testplan Created an mp3 file that says "hi". Used aiconfig to run asr on it. |<img width="922" alt="Screenshot 2024-01-09 at 7 14 47 PM" src="https://github.com/lastmile-ai/aiconfig/assets/141073967/fe68751d-e20b-41d9-9da5-cc9a32859cba"> |<img width="1461" alt="Screenshot 2024-01-09 at 5 54 33 PM" src="https://github.com/lastmile-ai/aiconfig/assets/141073967/78063a3e-2b9a-4a39-80d9-ef28a7d706cf">| | ------------- | ------------- |
1 parent 485faf6 commit 6dc0b11

File tree

2 files changed

+228
-11
lines changed

2 files changed

+228
-11
lines changed

extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .local_inference.text_summarization import HuggingFaceTextSummarizationTransformer
66
from .local_inference.text_translation import HuggingFaceTextTranslationTransformer
77
from .remote_inference_client.text_generation import HuggingFaceTextGenerationParser
8+
from .local_inference.automatic_speech_recognition import HuggingFaceAutomaticSpeechRecognitionTransformer
89

910

1011
LOCAL_INFERENCE_CLASSES = [
@@ -15,6 +16,7 @@
1516
"HuggingFaceText2SpeechTransformer",
1617
"HuggingFaceAutomaticSpeechRecognition",
1718
"HuggingFaceImage2TextTransformer",
19+
"HuggingFaceAutomaticSpeechRecognitionTransformer",
1820
]
1921
REMOTE_INFERENCE_CLASSES = ["HuggingFaceTextGenerationParser"]
2022
__ALL__ = LOCAL_INFERENCE_CLASSES + REMOTE_INFERENCE_CLASSES
Lines changed: 226 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1-
from typing import Any, Coroutine, Dict, Optional, List, TYPE_CHECKING
2-
from aiconfig import ParameterizedModelParser, InferenceOptions, AIConfig
1+
from typing import Any, Dict, Literal, Optional, List, TYPE_CHECKING
2+
from aiconfig import ParameterizedModelParser, InferenceOptions
3+
from aiconfig.callback import CallbackEvent
4+
from pydantic import BaseModel
5+
import torch
6+
from aiconfig.schema import Prompt, Output, ExecuteResult, Attachment
37

4-
from aiconfig.schema import Prompt, Output
5-
from transformers import Pipeline
8+
from transformers import pipeline, Pipeline
69

710
if TYPE_CHECKING:
811
from aiconfig import AIConfigRuntime
@@ -11,7 +14,7 @@
1114
"""
1215

1316

14-
class HuggingFaceAutomaticSpeechRecognition(ParameterizedModelParser):
17+
class HuggingFaceAutomaticSpeechRecognitionTransformer(ParameterizedModelParser):
1518
def __init__(self):
1619
"""
1720
Returns:
@@ -24,24 +27,24 @@ def __init__(self):
2427
config.register_model_parser(parser)
2528
"""
2629
super().__init__()
27-
self.generators: dict[str, Pipeline] = {}
30+
self.pipelines: dict[str, Pipeline] = {}
2831

2932
def id(self) -> str:
3033
"""
3134
Returns an identifier for the Model Parser
3235
"""
33-
return "HuggingFaceAutomaticSpeechRecognition"
36+
return "HuggingFaceAutomaticSpeechRecognitionTransformer"
3437

3538
async def serialize(
3639
self,
3740
prompt_name: str,
3841
data: Any,
3942
ai_config: "AIConfigRuntime",
4043
parameters: Optional[Dict[str, Any]] = None,
41-
**completion_params,
4244
) -> List[Prompt]:
4345
"""
4446
Defines how a prompt and model inference settings get serialized in the .aiconfig.
47+
Assume input in the form of input(s) being passed into an already constructed pipeline.
4548
4649
Args:
4750
prompt (str): The prompt to be serialized.
@@ -52,14 +55,226 @@ async def serialize(
5255
Returns:
5356
str: Serialized representation of the prompt and inference settings.
5457
"""
58+
raise NotImplementedError("serialize is not implemented for HuggingFaceAutomaticSpeechRecognition")
5559

5660
async def deserialize(
    self,
    prompt: Prompt,
    aiconfig: "AIConfigRuntime",
    params: Optional[Dict[str, Any]] = {},
) -> Dict[str, Any]:
    """
    Turn the serialized prompt into the kwargs dict for an ASR pipeline call.

    Returns:
        Dict[str, Any]: filtered completion params plus an "inputs" entry
        holding the audio uri(s) pulled from the prompt attachments.
    """
    await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_start", __name__, {"prompt": prompt, "params": params}))

    # Split the model settings: pipeline-construction params are discarded here;
    # only call-time (completion) params survive into the returned dict.
    settings = self.get_model_settings(prompt, aiconfig)
    _creation_params, raw_completion_params = refine_pipeline_creation_params(settings)
    completion_data = refine_asr_completion_params(raw_completion_params)

    # The ASR pipeline accepts bytes, file paths, or raw sampled-audio dicts,
    # singly or in batches. Only uri inputs (one or many) are supported here.
    # TODO: Support or figure out if other input types are needed (base64,
    # bytes), as well as the sampled-audio dict. See api docs for more info:
    # - https://github.com/huggingface/transformers/blob/v4.36.1/src/transformers/pipelines/automatic_speech_recognition.py#L313-L317
    # - https://huggingface.co/docs/transformers/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline
    completion_data["inputs"] = validate_and_retrieve_audio_from_attachments(prompt)

    await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_data}))
    return completion_data
6385

6486
async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]) -> list[Output]:
    """
    Run the ASR pipeline over the prompt's audio attachments and store the
    resulting transcription outputs on the prompt.
    """
    await aiconfig.callback_manager.run_callbacks(
        CallbackEvent(
            "on_run_start",
            __name__,
            {"prompt": prompt, "options": options, "parameters": parameters},
        )
    )

    creation_params, _unused_completion = refine_pipeline_creation_params(self.get_model_settings(prompt, aiconfig))
    model_name = aiconfig.get_model_name(prompt)

    # Lazily build one pipeline per model name and cache it for reuse.
    if isinstance(model_name, str) and model_name not in self.pipelines:
        if creation_params.get("device", None) is None:
            creation_params["device"] = self._get_device()
        self.pipelines[model_name] = pipeline(task="automatic-speech-recognition", **creation_params)

    asr = self.pipelines[model_name]
    call_kwargs = await self.deserialize(prompt, aiconfig, parameters)

    response = asr(**call_kwargs)

    # response is a list of text outputs. This can be tested by running an asr
    # pipeline and noticing the outputs are a list of text.
    prompt.outputs = construct_outputs(response)

    await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_run_complete", __name__, {"result": prompt.outputs}))
    return prompt.outputs
116+
117+
def _get_device(self) -> str:
118+
if torch.cuda.is_available():
119+
return "cuda"
120+
# Mps backend is not supported for all asr models. Seen when spinning up a default asr pipeline which uses facebook/wav2vec2-base-960h 55bb623
121+
return "cpu"
122+
123+
def get_output_text(
    self,
    prompt: Prompt,
    aiconfig: "AIConfigRuntime",
    output: Optional[Output] = None,
) -> str:
    """
    Return the text of `output` (or, when omitted, the prompt's latest output)
    if it is an execute_result whose data is a string; otherwise return "".
    """
    resolved = output if output is not None else aiconfig.get_latest_output(prompt)
    if resolved is None:
        return ""

    # TODO (rossdanlm): Handle multiple outputs in list
    # https://github.com/lastmile-ai/aiconfig/issues/467
    if resolved.output_type == "execute_result" and isinstance(resolved.data, str):
        return resolved.data
    return ""
142+
143+
144+
def validate_attachment_type_is_audio(attachment: "Attachment"):
    """
    Ensure the given prompt attachment carries an audio mime type.

    Raises:
        ValueError: if the attachment has no usable mime type, or the mime
            type is not an ``audio/*`` type.
    """
    # getattr with a None default also covers a mime_type attribute that is
    # present but unset, which previously crashed with an AttributeError on
    # .startswith instead of raising the intended ValueError.
    mime_type = getattr(attachment, "mime_type", None)
    if mime_type is None:
        raise ValueError("Attachment has no mime type. Specify the audio mimetype in the aiconfig")

    if not mime_type.startswith("audio/"):
        raise ValueError(f"Invalid attachment mimetype {attachment.mime_type}. Expected audio mimetype.")
150+
151+
152+
def validate_and_retrieve_audio_from_attachments(prompt: Prompt) -> list[str]:
    """
    Collect the audio uri of every attachment on the prompt input.

    Raises ValueError when:
    - no attachments are present on the prompt input
    - an attachment is not audio
    - an attachment's data is not a uri string
    """
    if not hasattr(prompt.input, "attachments") or len(prompt.input.attachments) == 0:
        raise ValueError(f"No attachments found in input for prompt {prompt.name}. Please add an audio attachment to the prompt input.")

    uris: list[str] = []
    for index, attachment in enumerate(prompt.input.attachments):
        validate_attachment_type_is_audio(attachment)

        data = attachment.data
        if not isinstance(data, str):
            # Only uri inputs are supported for now (see the TODO in deserialize).
            raise ValueError(f"Attachment #{index} data is not a uri. Please specify a uri for the audio attachment in prompt {prompt.name}.")
        uris.append(data)

    return uris
178+
179+
180+
def refine_pipeline_creation_params(model_settings: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Split model settings into pipeline-creation params and deferred completion params.

    Keys supported by the transformers `pipeline()` factory for the
    automatic-speech-recognition task are kept as creation params; everything
    else is deferred as a completion param, to be filtered later by
    `refine_asr_completion_params()`. The supported keys were found by looking
    at the HF Pipelines AutomaticSpeechRecognition API:
    https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline

    Note that this is not the same as the completion params, which are passed to
    the pipeline later at call time:
    https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline.__call__

    TODO: Distinguish pipeline creation and refine completion
    https://github.com/lastmile-ai/aiconfig/issues/825
    https://github.com/lastmile-ai/aiconfig/issues/824

    Returns:
        [pipeline_creation_params, completion_params]
    """
    supported_keys = {
        "model",
        "chunk_length_s",
        "decoder",
        "device",
        "framework",
        "feature_extractor",
        # BUGFIX: these two entries were previously one implicitly-concatenated
        # string ("stride_length_stokenizer"), so neither key was ever
        # recognized as a pipeline-creation param.
        "stride_length_s",
        "tokenizer",
    }

    pipeline_creation_params: Dict[str, Any] = {}
    completion_params: Dict[str, Any] = {}
    for key in model_settings:
        normalized = key.lower()
        if normalized in supported_keys:
            pipeline_creation_params[normalized] = model_settings[key]
        elif normalized == "kwargs" and isinstance(model_settings[key], dict):
            # Flatten an explicit kwargs dict into the completion params.
            completion_params.update(model_settings[key])
        else:
            completion_params[normalized] = model_settings[key]

    return [pipeline_creation_params, completion_params]
217+
218+
219+
def refine_asr_completion_params(unfiltered_completion_params: Dict[str, Any]) -> Dict[str, Any]:
    """
    Keep only the params supported by the ASR pipeline's __call__ after a
    pipeline has been created via `refine_pipeline_creation_params`; drop
    everything else. The supported keys were found by looking at the HF asr
    API for asr pipelines:
    https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline

    Note: This doesn't support base pipeline params like `num_workers`
    TODO: Figure out how to find which params are supported.

    TODO: Distinguish pipeline creation and refine completion
    https://github.com/lastmile-ai/aiconfig/issues/825
    https://github.com/lastmile-ai/aiconfig/issues/824
    """
    supported_keys = {
        # inputs
        "return_timestamps",
        "generate_kwargs",
        "max_new_tokens",
    }

    # Keys are lowercased both for the membership test and in the result,
    # mirroring the creation-param filtering.
    return {key.lower(): value for key, value in unfiltered_completion_params.items() if key.lower() in supported_keys}
252+
253+
254+
def construct_outputs(response: list[Any]) -> list[Output]:
    """
    Construct aiconfig Outputs from the response of the HF ASR pipeline.

    The pipeline may return a single result or a list of them; each result is
    expected to be a dict containing the text output and, when timestamps were
    requested, a "chunks" key. Could not find docs for the response schema.
    """
    if not isinstance(response, list):
        # response contains a single output. Found by testing variations of the asr pipeline
        response = [response]

    outputs: list[Output] = []
    for i, result in enumerate(response):
        # BUGFIX: test isinstance BEFORE membership — the previous order ran
        # `"text" in result` (and `result.get(...)` below) on non-dict results,
        # which raised or did an unintended substring test.
        is_dict = isinstance(result, dict)
        text_output = result.get("text") if is_dict and "text" in result else result
        # may contain timestamps and chunks, for now pass the whole result through
        metadata = {"result": result} if is_dict and result.get("chunks", False) else {}
        output = ExecuteResult(
            **{
                "output_type": "execute_result",
                "data": text_output,
                "execution_count": i,
                "metadata": metadata,
            }
        )
        outputs.append(output)

    return outputs

0 commit comments

Comments
 (0)