Skip to content

Commit 27ec5ea

Browse files
author
Ankush Pala ankush@lastmileai.dev
committed
[extensions][py][hf] 2/n ASR model parser impl
Model Parser for the Automatic Speech Recognition task on Hugging Face. Decisions made while implementing: - manual implementation to parse input attachments - threw exceptions on every unexpected step (not sure if this is the direction we want to go in) - this diff does not implement serialize() for the model parser (will be implemented in a follow-up diff on top) ## Test plan Created an mp3 file that says "hi". Used aiconfig to run ASR on it. |<img width="1518" alt="Screenshot 2024-01-09 at 4 22 29 PM" src="https://github.com/lastmile-ai/aiconfig/assets/141073967/a16a6016-79ea-425e-bc8c-d0dd58e9cccb">|<img width="525" alt="Screenshot 2024-01-09 at 4 21 59 PM" src="https://github.com/lastmile-ai/aiconfig/assets/141073967/fab059b7-e018-46d9-82b6-d9efcdbe4545">| | ------------- | ------------- |
1 parent 15ef3b4 commit 27ec5ea

File tree

2 files changed

+225
-10
lines changed

2 files changed

+225
-10
lines changed
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from .local_inference.text_2_image import HuggingFaceText2ImageDiffusor
22
from .local_inference.text_generation import HuggingFaceTextGenerationTransformer
3-
from .remote_inference_client.text_generation import HuggingFaceTextGenerationClient
3+
# from .remote_inference_client.text_generation import HuggingFaceTextGenerationClient
44
from .local_inference.text_summarization import HuggingFaceTextSummarizationTransformer
55
from .local_inference.text_translation import HuggingFaceTextTranslationTransformer
6+
from .local_inference.automatic_speech_recognition import HuggingFaceAutomaticSpeechRecognition
67

78
# from .remote_inference_client.text_generation import HuggingFaceTextGenerationClient
89

@@ -11,6 +12,7 @@
1112
"HuggingFaceTextGenerationTransformer",
1213
"HuggingFaceTextSummarizationTransformer",
1314
"HuggingFaceTextTranslationTransformer",
15+
"HuggingFaceAutomaticSpeechRecognition",
1416
]
1517
REMOTE_INFERENCE_CLASSES = ["HuggingFaceTextGenerationClient"]
1618
__ALL__ = LOCAL_INFERENCE_CLASSES + REMOTE_INFERENCE_CLASSES

extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/automatic_speech_recognition.py

Lines changed: 222 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1-
from typing import Any, Coroutine, Dict, Optional, List, TYPE_CHECKING
2-
from aiconfig import ParameterizedModelParser, InferenceOptions, AIConfig
1+
from typing import Any, Dict, Literal, Optional, List, TYPE_CHECKING
2+
from aiconfig import ParameterizedModelParser, InferenceOptions
3+
from aiconfig.callback import CallbackEvent
4+
from pydantic import BaseModel
5+
import torch
6+
from aiconfig.schema import Prompt, Output, ExecuteResult, Attachment
37

4-
from aiconfig.schema import Prompt, Output
5-
from transformers import Pipeline
8+
from transformers import pipeline, Pipeline
69

710
if TYPE_CHECKING:
811
from aiconfig import AIConfigRuntime
@@ -24,7 +27,7 @@ def __init__(self):
2427
config.register_model_parser(parser)
2528
"""
2629
super().__init__()
27-
self.generators: dict[str, Pipeline] = {}
30+
self.pipelines: dict[str, Pipeline] = {}
2831

2932
def id(self) -> str:
3033
"""
@@ -38,10 +41,10 @@ async def serialize(
3841
data: Any,
3942
ai_config: "AIConfigRuntime",
4043
parameters: Optional[Dict[str, Any]] = None,
41-
**completion_params,
4244
) -> List[Prompt]:
4345
"""
4446
Defines how a prompt and model inference settings get serialized in the .aiconfig.
47+
Assume input in the form of input(s) being passed into an already constructed pipeline.
4548
4649
Args:
4750
prompt (str): The prompt to be serialized.
@@ -52,14 +55,224 @@ async def serialize(
5255
Returns:
5356
str: Serialized representation of the prompt and inference settings.
5457
"""
58+
raise NotImplementedError("serialize is not implemented for HuggingFaceAutomaticSpeechRecognition")
5559

5660
async def deserialize(
    self,
    prompt: Prompt,
    aiconfig: "AIConfigRuntime",
    params: Optional[Dict[str, Any]] = {},  # NOTE(review): mutable default kept for interface compat; params is unused here
) -> Dict[str, Any]:
    """
    Translate the prompt's model settings and audio attachments into the
    keyword arguments expected by the ASR pipeline's __call__.
    """
    start_event = CallbackEvent("on_deserialize_start", __name__, {"prompt": prompt, "params": params})
    await aiconfig.callback_manager.run_callbacks(start_event)

    # Split the settings: pipeline-creation params are discarded here, and the
    # remainder is filtered down to the supported completion params.
    settings = self.get_model_settings(prompt, aiconfig)
    _creation_params, raw_completion_params = refine_pipeline_creation_params(settings)
    completion_data = refine_asr_completion_params(raw_completion_params)

    # The HF ASR pipeline accepts bytes, file paths/uris, or a dict of raw
    # sampled audio, singly or in batches. Only uri inputs (single or
    # multiple) are supported for now.
    # TODO: support base64/bytes and the sampled-audio dict form. See:
    # - https://github.com/huggingface/transformers/blob/v4.36.1/src/transformers/pipelines/automatic_speech_recognition.py#L313-L317
    # - https://huggingface.co/docs/transformers/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline
    completion_data["inputs"] = validate_and_retrieve_audio_from_attachments(prompt)

    done_event = CallbackEvent("on_deserialize_complete", __name__, {"output": completion_data})
    await aiconfig.callback_manager.run_callbacks(done_event)
    return completion_data
6385

6486
async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]) -> list[Output]:
    """
    Run automatic speech recognition for the given prompt and store the
    transcription outputs on the prompt.

    One pipeline per model name is built lazily and cached on the parser.
    """
    start_event = CallbackEvent(
        "on_run_start",
        __name__,
        {"prompt": prompt, "options": options, "parameters": parameters},
    )
    await aiconfig.callback_manager.run_callbacks(start_event)

    settings = self.get_model_settings(prompt, aiconfig)
    creation_kwargs, _ = refine_pipeline_creation_params(settings)
    model_name = aiconfig.get_model_name(prompt)

    # Build and cache the pipeline on first use, defaulting the device when
    # the config didn't pin one.
    if isinstance(model_name, str) and model_name not in self.pipelines:
        if creation_kwargs.get("device", None) is None:
            creation_kwargs["device"] = self._get_device()
        self.pipelines[model_name] = pipeline(task="automatic-speech-recognition", **creation_kwargs)

    asr = self.pipelines[model_name]
    completion_data = await self.deserialize(prompt, aiconfig, parameters)

    response = asr(**completion_data)

    # The pipeline yields a list of text outputs (observed by running an ASR
    # pipeline and inspecting the result).
    prompt.outputs = construct_outputs(response)

    await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_run_complete", __name__, {"result": prompt.outputs}))
    return prompt.outputs
116+
117+
def _get_device(self) -> str:
    """
    Pick the torch device string for a new pipeline: "cuda" when available,
    otherwise "cpu".

    The mps backend is deliberately not used: it is not supported for all ASR
    models (seen when spinning up a default ASR pipeline, which uses
    facebook/wav2vec2-base-960h 55bb623).
    """
    return "cuda" if torch.cuda.is_available() else "cpu"
122+
123+
def get_output_text(self, response: dict[str, Any]) -> str:
    """
    Not supported by this parser.

    Raises:
        NotImplementedError: always. Transcribed text is attached to
            `prompt.outputs` during `run_inference` instead.
    """
    raise NotImplementedError("get_output_text is not implemented for HuggingFaceAutomaticSpeechRecognition")
125+
126+
127+
def validate_attachment_type_is_audio(attachment: Attachment):
    """
    Ensure the attachment declares an audio mime type.

    Raises:
        ValueError: if the attachment has no mime type, or the mime type is
            not an audio/* type.
    """
    if not hasattr(attachment, "mime_type"):
        # Plain literal: the previous f-string prefix was pointless (no
        # placeholders — lint F541).
        raise ValueError("Attachment has no mime type. Specify the audio mimetype in the aiconfig")

    if not attachment.mime_type.startswith("audio/"):
        raise ValueError(f"Invalid attachment mimetype {attachment.mime_type}. Expected audio mimetype.")
133+
134+
135+
def validate_and_retrieve_audio_from_attachments(prompt: Prompt) -> list[str]:
    """
    Collect the audio uri held by each attachment on the prompt input.

    Raises ValueError when:
    - no attachments are present
    - an attachment is not audio
    - an attachment's data is not a file uri
    - the attachment data cannot be parsed for any reason
    """

    if not hasattr(prompt.input, "attachments") or len(prompt.input.attachments) == 0:
        raise ValueError(f"No attachments found in input for prompt {prompt.name}. Please add an audio attachment to the prompt input.")

    uris: list[str] = []

    for i, attachment in enumerate(prompt.input.attachments):
        validate_attachment_type_is_audio(attachment)

        # Parse into `AttachmentInputDataWithStringValue` purely to validate
        # the shape of the attachment data.
        # TODO: @ankush-lastmile update AIConfig sdk schema to include this
        # type (the try/except can be removed once that lands)
        try:
            _ = AttachmentInputDataWithStringValue(**attachment.data)
        except Exception as e:
            # pylint: disable=W0707
            raise ValueError(
                f"Attachment #{i} for prompt {prompt.name} is not a valid attachment for Automatic Speech Recognition model parser. "
                f"Error: {str(e)}. Please specify a kind and value for the attachment data."
            )

        # Once the schema TODO above is resolved, these raw dict lookups
        # become `.kind` / `.value` on a pydantic model.
        if attachment.data.get("kind") != "file_uri":
            # Only uri attachment data is supported for now (see TODO above).
            raise ValueError(f"Attachment #{i} data is not a uri. Please pass in a uri for the audio attachment in prompt {prompt.name}.")
        uris.append(attachment.data.get("value"))

    return uris
172+
173+
174+
def refine_pipeline_creation_params(model_settings: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Split model settings for the HF automatic speech recognition api into
    pipeline-creation params and everything else.

    Returns a two-item list `[pipeline_creation_params, completion_params]`.
    Unsupported keys are deferred as completion params, where they get
    processed in `refine_asr_completion_params()`. The supported keys were
    found by looking at the HF Pipelines AutomaticSpeechRecognition API:
    https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline

    Note that these are not the same as the completion params, which are
    passed to the pipeline later to transcribe the audio:
    https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline.__call__
    """

    supported_keys = {
        "model",
        "chunk_length_s",
        "decoder",
        "device",
        "framework",
        "feature_extractor",
        # BUG FIX: a missing comma previously fused the next two entries into
        # the single string "stride_length_stokenizer" (implicit string
        # concatenation), so neither key was ever recognized.
        "stride_length_s",
        "tokenizer",
    }

    pipeline_creation_params: Dict[str, Any] = {}
    completion_params: Dict[str, Any] = {}
    for key in model_settings:
        normalized = key.lower()
        if normalized in supported_keys:
            pipeline_creation_params[normalized] = model_settings[key]
        elif normalized == "kwargs" and isinstance(model_settings[key], Dict):
            # A nested kwargs dict is flattened into the completion params.
            completion_params.update(model_settings[key])
        else:
            completion_params[normalized] = model_settings[key]

    return [pipeline_creation_params, completion_params]
209+
210+
211+
def refine_asr_completion_params(unfiltered_completion_params: Dict[str, Any]) -> Dict[str, Any]:
    """
    Drop every completion param the HF ASR pipeline call does not accept.

    Run after a pipeline has been created via
    `refine_pipeline_creation_params`. The supported keys were found by
    looking at the HF API for ASR pipelines:
    https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline

    Note that these are not the same as the pipeline-creation params, which
    were passed earlier to build the pipeline:
    https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline.__call__

    Note: This doesn't support base pipeline params like `num_workers`
    TODO: Figure out how to find which params are supported.
    """

    supported_keys = {
        # inputs
        "return_timestamps",
        "generate_kwargs",
        "max_new_tokens",
    }

    return {
        key.lower(): value
        for key, value in unfiltered_completion_params.items()
        if key.lower() in supported_keys
    }
240+
241+
242+
def construct_outputs(response: list[Any]) -> list[Output]:
    """
    Build ExecuteResult outputs from the response of the HF ASR pipeline.

    The response may contain multiple results, and each result may carry
    timestamps in addition to the transcribed text.
    """
    # A single-output response is not wrapped in a list by the pipeline
    # (found by testing variations of the ASR pipeline) — normalize it.
    results = response if isinstance(response, list) else [response]

    outputs: list[Output] = []
    for count, item in enumerate(results):
        # Each result is expected to be a dict with the text output and, if
        # requested, timestamps. Could not find docs for this shape.
        item: dict[str, Any]
        outputs.append(
            ExecuteResult(
                output_type="execute_result",
                data=item.get("text"),
                execution_count=count,
                # may contain timestamps and chunks; for now pass the whole result
                metadata={"result": item} if item.get("chunks", False) else {},
            )
        )

    return outputs
268+
269+
270+
class AttachmentInputDataWithStringValue(BaseModel):
    """
    This represents the input data that is stored as a string, but we use
    both the `kind` field here and the `mime_type` to convert
    the string into the output format we want.
    """

    # How `value` should be interpreted: a file uri or base64-encoded data.
    kind: Literal["file_uri", "base64"]
    # The attachment payload itself, held as a string.
    value: str

0 commit comments

Comments
 (0)