Skip to content

Commit d328bd0

Browse files
author
Ankush Pala ankush@lastmileai.dev
committed
2/n model parser impl
## Testplan Created an mp3 file that says "hi". Used aiconfig to run asr on it. |<img width="596" alt="Screenshot 2024-01-05 at 2 39 17 PM" src="https://github.com/lastmile-ai/aiconfig/assets/141073967/42c7ddbe-20ca-4828-b609-725b88a08939">|<img width="900" alt="Screenshot 2024-01-05 at 2 41 04 PM" src="https://github.com/lastmile-ai/aiconfig/assets/141073967/74b21333-d347-4270-bbd6-efd318785172">| | ------------- | ------------- |
1 parent 15ef3b4 commit d328bd0

File tree

2 files changed

+42
-9
lines changed

2 files changed

+42
-9
lines changed
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from .local_inference.text_2_image import HuggingFaceText2ImageDiffusor
22
from .local_inference.text_generation import HuggingFaceTextGenerationTransformer
3-
from .remote_inference_client.text_generation import HuggingFaceTextGenerationClient
3+
# from .remote_inference_client.text_generation import HuggingFaceTextGenerationClient
44
from .local_inference.text_summarization import HuggingFaceTextSummarizationTransformer
55
from .local_inference.text_translation import HuggingFaceTextTranslationTransformer
6+
from .local_inference.automatic_speech_recognition import HuggingFaceAutomaticSpeechRecognition
67

78
# from .remote_inference_client.text_generation import HuggingFaceTextGenerationClient
89

@@ -11,6 +12,7 @@
1112
"HuggingFaceTextGenerationTransformer",
1213
"HuggingFaceTextSummarizationTransformer",
1314
"HuggingFaceTextTranslationTransformer",
15+
"HuggingFaceAutomaticSpeechRecognition",
1416
]
1517
REMOTE_INFERENCE_CLASSES = ["HuggingFaceTextGenerationClient"]
1618
__ALL__ = LOCAL_INFERENCE_CLASSES + REMOTE_INFERENCE_CLASSES

extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/automatic_speech_recognition.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
from typing import Any, Coroutine, Dict, Optional, List, TYPE_CHECKING
1+
from typing import Any, Dict, Optional, List, TYPE_CHECKING
22
from aiconfig import ParameterizedModelParser, InferenceOptions, AIConfig
3-
4-
from aiconfig.schema import Prompt, Output
5-
from transformers import Pipeline
3+
import torch
4+
from aiconfig.schema import Prompt, Output, ExecuteResult
5+
from transformers import pipeline, Pipeline
66

77
if TYPE_CHECKING:
88
from aiconfig import AIConfigRuntime
@@ -24,7 +24,7 @@ def __init__(self):
2424
config.register_model_parser(parser)
2525
"""
2626
super().__init__()
27-
self.generators: dict[str, Pipeline] = {}
27+
self.pipelines: dict[str, Pipeline] = {}
2828

2929
def id(self) -> str:
3030
"""
@@ -56,10 +56,41 @@ async def serialize(
5656
async def deserialize(
    self,
    prompt: Prompt,
    aiconfig: "AIConfigRuntime",
    params: Optional[Dict[str, Any]] = {},
) -> Dict[str, Any]:
    """Translate an aiconfig Prompt into keyword arguments for the ASR pipeline.

    Args:
        prompt: Prompt whose input carries the audio data as an attachment.
        aiconfig: The AIConfigRuntime the prompt belongs to; used to read
            the prompt's model settings.
        params: Unused here; kept for interface compatibility with the
            parameterized model-parser contract.

    Returns:
        Dict of completion parameters with the audio payload under "inputs",
        ready to be splatted into the transformers pipeline call.

    Raises:
        ValueError: If the prompt input has no attachments to transcribe.
    """
    # Start from the model settings declared on the prompt/config.
    completion_params = self.get_model_settings(prompt, aiconfig)

    # The audio lives on the prompt's attachments. Fail loudly with a clear
    # message instead of the opaque IndexError the bare `attachments[0]`
    # access produced when no attachment was supplied.
    attachments = getattr(prompt.input, "attachments", None)
    if not attachments:
        raise ValueError(
            "Automatic speech recognition requires an audio attachment on the prompt input, but none was found."
        )

    # NOTE(review): only the first attachment is transcribed — confirm that
    # multi-attachment prompts are out of scope for this parser.
    completion_params["inputs"] = attachments[0].data
    return completion_params
6369

6470
async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]) -> list[Output]:
    """Run automatic speech recognition for the given prompt and store the result.

    Args:
        prompt: Prompt whose attached audio should be transcribed.
        aiconfig: Runtime used to resolve the model name and settings.
        options: Inference options (currently unused by this parser).
        parameters: Parameters forwarded to `deserialize`.

    Returns:
        The list of outputs written onto `prompt.outputs` (a single
        ExecuteResult wrapping the raw pipeline response).

    Raises:
        ValueError: If the resolved model name is not a string.
    """
    model_name = aiconfig.get_model_name(prompt)

    # Previously a non-string model name skipped pipeline construction and
    # then failed below with a bare KeyError on `self.pipelines[model_name]`.
    # Surface the real problem instead.
    if not isinstance(model_name, str):
        raise ValueError(
            f"Expected the model name for prompt to be a string, got: {type(model_name)}"
        )

    # Lazily build and cache one pipeline per model so repeated runs do not
    # reload model weights.
    if model_name not in self.pipelines:
        self.pipelines[model_name] = pipeline(
            task="automatic-speech-recognition",
            model=model_name,
            device=self._get_device(),
        )
    asr_pipeline = self.pipelines[model_name]

    completion_data = await self.deserialize(prompt, aiconfig, parameters)
    response = asr_pipeline(**completion_data)

    output = ExecuteResult(output_type="execute_result", data=response, metadata={})

    prompt.outputs = [output]
    return prompt.outputs
87+
88+
def _get_device(self) -> str:
89+
if torch.cuda.is_available():
90+
return "cuda"
91+
# Mps backend is not supported for all asr models.
92+
# This is currently a torch library limitation. Test this by creating a pipeline with mps backend.
93+
return "cpu"
94+
95+
def get_output_text(self, response: dict[str, Any]) -> str:
    """Return the transcribed text from an ASR pipeline response.

    The bare `return` in the original stub always yielded None, violating
    the declared `-> str` return type. Hugging Face ASR pipelines return a
    dict with the transcription under the "text" key.

    Args:
        response: Raw response dict from the transformers ASR pipeline.

    Returns:
        The transcription string, or "" when no text is present.
    """
    if isinstance(response, dict):
        text = response.get("text", "")
        return text if isinstance(text, str) else str(text)
    return ""

0 commit comments

Comments
 (0)