Skip to content

Commit f2ee469

Browse files
author
Ankush Pala ankush@lastmileai.dev
committed
[extensions][py][hf] 2/n ASR model parser impl
Model Parser for the Automatic Speech Recognition task on huggingface. Decisions made while implementing: - manual impl to parse input attachments - - threw exceptions on every unexpected step. Not sure if this is the direction we want to go with this. - This diff does not implement serialize() for the model parser (will implement on diff ontop) ## Testplan Created an mp3 file that says "hi". Used aiconfig to run asr on it. |<img width="591" alt="Screenshot 2024-01-09 at 5 53 48 PM" src="https://github.com/lastmile-ai/aiconfig/assets/141073967/8f6a8339-e581-4886-8f19-732e0292e4ee"> |<img width="1461" alt="Screenshot 2024-01-09 at 5 54 33 PM" src="https://github.com/lastmile-ai/aiconfig/assets/141073967/78063a3e-2b9a-4a39-80d9-ef28a7d706cf">| | ------------- | ------------- |
1 parent 08d9be8 commit f2ee469

File tree

2 files changed

+203
-10
lines changed

2 files changed

+203
-10
lines changed

extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from .local_inference.text_summarization import HuggingFaceTextSummarizationTransformer
55
from .local_inference.text_translation import HuggingFaceTextTranslationTransformer
66
from .local_inference.text_2_speech import HuggingFaceText2SpeechTransformer
7-
7+
from .local_inference.automatic_speech_recognition import HuggingFaceAutomaticSpeechRecognition
88

99
# from .remote_inference_client.text_generation import HuggingFaceTextGenerationClient
1010

@@ -14,6 +14,7 @@
1414
"HuggingFaceTextSummarizationTransformer",
1515
"HuggingFaceTextTranslationTransformer",
1616
"HuggingFaceText2SpeechTransformer",
17+
"HuggingFaceAutomaticSpeechRecognition",
1718
]
1819
REMOTE_INFERENCE_CLASSES = ["HuggingFaceTextGenerationParser"]
1920
__ALL__ = LOCAL_INFERENCE_CLASSES + REMOTE_INFERENCE_CLASSES

extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/automatic_speech_recognition.py

Lines changed: 201 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1-
from typing import Any, Coroutine, Dict, Optional, List, TYPE_CHECKING
2-
from aiconfig import ParameterizedModelParser, InferenceOptions, AIConfig
1+
from typing import Any, Dict, Literal, Optional, List, TYPE_CHECKING
2+
from aiconfig import ParameterizedModelParser, InferenceOptions
3+
from aiconfig.callback import CallbackEvent
4+
from pydantic import BaseModel
5+
import torch
6+
from aiconfig.schema import Prompt, Output, ExecuteResult, Attachment
37

4-
from aiconfig.schema import Prompt, Output
5-
from transformers import Pipeline
8+
from transformers import pipeline, Pipeline
69

710
if TYPE_CHECKING:
811
from aiconfig import AIConfigRuntime
@@ -24,7 +27,7 @@ def __init__(self):
2427
config.register_model_parser(parser)
2528
"""
2629
super().__init__()
27-
self.generators: dict[str, Pipeline] = {}
30+
self.pipelines: dict[str, Pipeline] = {}
2831

2932
def id(self) -> str:
3033
"""
@@ -38,10 +41,10 @@ async def serialize(
3841
data: Any,
3942
ai_config: "AIConfigRuntime",
4043
parameters: Optional[Dict[str, Any]] = None,
41-
**completion_params,
4244
) -> List[Prompt]:
4345
"""
4446
Defines how a prompt and model inference settings get serialized in the .aiconfig.
47+
Assume input in the form of input(s) being passed into an already constructed pipeline.
4548
4649
Args:
4750
prompt (str): The prompt to be serialized.
@@ -52,14 +55,203 @@ async def serialize(
5255
Returns:
5356
str: Serialized representation of the prompt and inference settings.
5457
"""
58+
raise NotImplementedError("serialize is not implemented for HuggingFaceAutomaticSpeechRecognition")
5559

5660
async def deserialize(
5761
self,
5862
prompt: Prompt,
59-
aiconfig: "AIConfig",
63+
aiconfig: "AIConfigRuntime",
6064
params: Optional[Dict[str, Any]] = {},
6165
) -> Dict[str, Any]:
62-
pass
66+
await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_start", __name__, {"prompt": prompt, "params": params}))
67+
68+
# Build Completion data
69+
model_settings = self.get_model_settings(prompt, aiconfig)
70+
[_pipeline_creation_params, unfiltered_completion_params] = refine_pipeline_creation_params(model_settings)
71+
completion_data = refine_asr_completion_params(unfiltered_completion_params)
72+
73+
# ASR Pipeline supports input types of bytes, file path, and a dict containing raw sampled audio. Also supports multiple input
74+
# For now, support multiple or single uri's as input
75+
# TODO: Support or figure out if other input types are needed (base64, bytes), as well as the sampled audio dict
76+
# See api docs for more info:
77+
# - https://github.com/huggingface/transformers/blob/v4.36.1/src/transformers/pipelines/automatic_speech_recognition.py#L313-L317
78+
# - https://huggingface.co/docs/transformers/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline
79+
inputs = validate_and_retrieve_audio_from_attachments(prompt)
80+
81+
completion_data["inputs"] = inputs
82+
83+
await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_data}))
84+
return completion_data
6385

6486
async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]) -> list[Output]:
65-
pass
87+
await aiconfig.callback_manager.run_callbacks(
88+
CallbackEvent(
89+
"on_run_start",
90+
__name__,
91+
{"prompt": prompt, "options": options, "parameters": parameters},
92+
)
93+
)
94+
95+
model_settings = self.get_model_settings(prompt, aiconfig)
96+
[pipeline_creation_data, _] = refine_pipeline_creation_params(model_settings)
97+
model_name = aiconfig.get_model_name(prompt)
98+
99+
if isinstance(model_name, str) and model_name not in self.pipelines:
100+
device = self._get_device()
101+
if pipeline_creation_data.get("device", None) is None:
102+
pipeline_creation_data["device"] = device
103+
self.pipelines[model_name] = pipeline(task="automatic-speech-recognition", **pipeline_creation_data)
104+
105+
asr_pipeline = self.pipelines[model_name]
106+
completion_data = await self.deserialize(prompt, aiconfig, parameters)
107+
108+
response = asr_pipeline(**completion_data)
109+
110+
# response is a list of text outputs. This can be tested by running an asr pipeline and noticing the outputs are a list of text.
111+
outputs = construct_outputs(response)
112+
113+
prompt.outputs = outputs
114+
await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_run_complete", __name__, {"result": prompt.outputs}))
115+
return prompt.outputs
116+
117+
def _get_device(self) -> str:
118+
if torch.cuda.is_available():
119+
return "cuda"
120+
# Mps backend is not supported for all asr models. Seen when spinning up a default asr pipeline which uses facebook/wav2vec2-base-960h 55bb623
121+
return "cpu"
122+
123+
def get_output_text(self, response: dict[str, Any]) -> str:
124+
raise NotImplementedError("get_output_text is not implemented for HuggingFaceAutomaticSpeechRecognition")
125+
126+
127+
def validate_attachment_type_is_audio(attachment: Attachment):
128+
if not hasattr(attachment, "mime_type"):
129+
raise ValueError(f"Attachment has no mime type. Specify the audio mimetype in the aiconfig")
130+
131+
if not attachment.mime_type.startswith("audio/"):
132+
raise ValueError(f"Invalid attachment mimetype {attachment.mime_type}. Expected audio mimetype.")
133+
134+
135+
def validate_and_retrieve_audio_from_attachments(prompt: Prompt) -> list[str]:
136+
"""
137+
Retrieves the audio uri's from each attachment in the prompt input.
138+
139+
Throws an exception if
140+
- attachment is not audio
141+
- attachment data is not a uri
142+
- no attachments are found
143+
- operation fails for any reason
144+
"""
145+
146+
if not hasattr(prompt.input, "attachments") or len(prompt.input.attachments) == 0:
147+
raise ValueError(f"No attachments found in input for prompt {prompt.name}. Please add an audio attachment to the prompt input.")
148+
149+
audio_uris: list[str] = []
150+
151+
for i, attachment in enumerate(prompt.input.attachments):
152+
validate_attachment_type_is_audio(attachment)
153+
154+
if not isinstance(attachment.data, str):
155+
# See todo above, but for now only support uri's
156+
raise ValueError(f"Attachment #{i} data is not a uri. Please specify a uri for the audio attachment in prompt {prompt.name}.")
157+
158+
audio_uris.append(attachment.data)
159+
160+
161+
return audio_uris
162+
163+
164+
def refine_pipeline_creation_params(model_settings: Dict[str, Any]) -> List[Dict[str, Any]]:
165+
"""
166+
Refines the pipeline creation params for the HF text2Image generation api.
167+
Defers unsupported params as completion params, where they can get processed in
168+
`refine_image_completion_params()`. The supported keys were found by looking at
169+
the HF Pipelines AutomaticSpeechRecognition API:
170+
https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline
171+
172+
Note that this is not the same as the image completion params, which are passed to
173+
the pipeline later to generate the image:
174+
https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline.__call__
175+
"""
176+
177+
supported_keys = {
178+
"model",
179+
"chunk_length_s",
180+
"decoder",
181+
"device",
182+
"framework",
183+
"feature_extractor",
184+
"stride_length_s" "tokenizer",
185+
}
186+
187+
pipeline_creation_params: Dict[str, Any] = {}
188+
completion_params: Dict[str, Any] = {}
189+
for key in model_settings:
190+
if key.lower() in supported_keys:
191+
pipeline_creation_params[key.lower()] = model_settings[key]
192+
else:
193+
if key.lower() == "kwargs" and isinstance(model_settings[key], Dict):
194+
completion_params.update(model_settings[key])
195+
else:
196+
completion_params[key.lower()] = model_settings[key]
197+
198+
return [pipeline_creation_params, completion_params]
199+
200+
201+
def refine_asr_completion_params(unfiltered_completion_params: Dict[str, Any]) -> Dict[str, Any]:
202+
"""
203+
Refines the ASR params for the HF asr generation api after a
204+
pipeline has been created via `refine_pipeline_creation_params`. Removes any
205+
unsupported params. The supported keys were found by looking at the HF asr
206+
API for asr pipelines:
207+
https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline
208+
209+
Note that this is not the same as the pipeline completion params, which were passed
210+
earlier to generate the pipeline:
211+
https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline.__call__
212+
213+
Note: This doesn't support base pipeline params like `num_workers`
214+
TODO: Figure out how to find which params are supported.
215+
"""
216+
217+
supported_keys = {
218+
# inputs
219+
"return_timestamps",
220+
"generate_kwargs",
221+
"max_new_tokens",
222+
}
223+
224+
completion_params: Dict[str, Any] = {}
225+
for key in unfiltered_completion_params:
226+
if key.lower() in supported_keys:
227+
completion_params[key.lower()] = unfiltered_completion_params[key]
228+
229+
return completion_params
230+
231+
232+
def construct_outputs(response: list[Any]) -> list[Output]:
233+
"""
234+
Constructs an output from the response of the HF ASR pipeline.
235+
236+
Response from pipeline could contain multiple outputs and time stamps
237+
"""
238+
outputs: list[Output] = []
239+
240+
if not isinstance(response, list):
241+
# response contains a single output. Found by testing variations of the asr pipeline
242+
response = [response]
243+
244+
for i, result in enumerate(response):
245+
# response is expected to be a dict containing the text output and timestamps if specified. Could not find docs for this.
246+
result: dict[str, Any]
247+
output = ExecuteResult(
248+
**{
249+
"output_type": "execute_result",
250+
"data": result.get("text"),
251+
"execution_count": i,
252+
"metadata": {"result": result} if result.get("chunks", False) else {}, # may contain timestamps and chunks, for now pass result
253+
}
254+
)
255+
outputs.append(output)
256+
257+
return outputs

0 commit comments

Comments
 (0)