Skip to content

Commit 9f5cd3d

Browse files
authored
HF transformers: Small fixes/nits (#862)
HF transformers: Small fixes/nits. Small fixes based on review comments from Sarmad and me on these diffs: - #854 - #855 - #821 Main things I did - rename `refine_chat_completion_params` --> `refine_completion_params` - edit `get_output_text` to not check for `OutputDataWithValue` - sorted the init file to be alphabetical - fixed some typos/print statements - made some error messages a bit more intuitive by including the prompt name - sorted some imports - fixed old class name `HuggingFaceAutomaticSpeechRecognition` --> `HuggingFaceAutomaticSpeechRecognitionTransformer` ## Test Plan These are all small nits and shouldn't change functionality
2 parents a036254 + 0e9c8cd commit 9f5cd3d

File tree

9 files changed

+72
-79
lines changed

9 files changed

+72
-79
lines changed
Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,21 @@
1+
from .local_inference.automatic_speech_recognition import HuggingFaceAutomaticSpeechRecognitionTransformer
12
from .local_inference.image_2_text import HuggingFaceImage2TextTransformer
23
from .local_inference.text_2_image import HuggingFaceText2ImageDiffusor
34
from .local_inference.text_2_speech import HuggingFaceText2SpeechTransformer
45
from .local_inference.text_generation import HuggingFaceTextGenerationTransformer
56
from .local_inference.text_summarization import HuggingFaceTextSummarizationTransformer
67
from .local_inference.text_translation import HuggingFaceTextTranslationTransformer
78
from .remote_inference_client.text_generation import HuggingFaceTextGenerationParser
8-
from .local_inference.automatic_speech_recognition import HuggingFaceAutomaticSpeechRecognitionTransformer
99

1010

1111
LOCAL_INFERENCE_CLASSES = [
12+
"HuggingFaceAutomaticSpeechRecognitionTransformer",
13+
"HuggingFaceImage2TextTransformer",
1214
"HuggingFaceText2ImageDiffusor",
15+
"HuggingFaceText2SpeechTransformer",
1316
"HuggingFaceTextGenerationTransformer",
1417
"HuggingFaceTextSummarizationTransformer",
1518
"HuggingFaceTextTranslationTransformer",
16-
"HuggingFaceText2SpeechTransformer",
17-
"HuggingFaceAutomaticSpeechRecognition",
18-
"HuggingFaceImage2TextTransformer",
19-
"HuggingFaceAutomaticSpeechRecognitionTransformer",
2019
]
2120
REMOTE_INFERENCE_CLASSES = ["HuggingFaceTextGenerationParser"]
2221
__ALL__ = LOCAL_INFERENCE_CLASSES + REMOTE_INFERENCE_CLASSES

extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/automatic_speech_recognition.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,29 @@
1-
from typing import Any, Dict, Literal, Optional, List, TYPE_CHECKING
1+
from typing import Any, Dict, Optional, List, TYPE_CHECKING
2+
3+
import torch
4+
from transformers import pipeline, Pipeline
25
from aiconfig import ParameterizedModelParser, InferenceOptions
36
from aiconfig.callback import CallbackEvent
4-
from pydantic import BaseModel
5-
import torch
67
from aiconfig.schema import Prompt, Output, ExecuteResult, Attachment
78

8-
from transformers import pipeline, Pipeline
99

1010
if TYPE_CHECKING:
1111
from aiconfig import AIConfigRuntime
12-
"""
13-
Model Parser for HuggingFace ASR (Automatic Speech Recognition) models.
14-
"""
1512

1613

1714
class HuggingFaceAutomaticSpeechRecognitionTransformer(ParameterizedModelParser):
15+
"""
16+
Model Parser for HuggingFace ASR (Automatic Speech Recognition) models.
17+
"""
18+
1819
def __init__(self):
1920
"""
2021
Returns:
21-
HuggingFaceAutomaticSpeechRecognition
22+
HuggingFaceAutomaticSpeechRecognitionTransformer
2223
2324
Usage:
2425
1. Create a new model parser object with the model ID of the model to use.
25-
parser = HuggingFaceAutomaticSpeechRecognition()
26+
parser = HuggingFaceAutomaticSpeechRecognitionTransformer()
2627
2. Add the model parser to the registry.
2728
config.register_model_parser(parser)
2829
"""
@@ -55,7 +56,8 @@ async def serialize(
5556
Returns:
5657
str: Serialized representation of the prompt and inference settings.
5758
"""
58-
raise NotImplementedError("serialize is not implemented for HuggingFaceAutomaticSpeechRecognition")
59+
# TODO: See https://github.com/lastmile-ai/aiconfig/issues/822
60+
raise NotImplementedError("serialize is not implemented for HuggingFaceAutomaticSpeechRecognitionTransformer")
5961

6062
async def deserialize(
6163
self,

extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/image_2_text.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
Attachment,
1515
ExecuteResult,
1616
Output,
17-
OutputDataWithValue,
1817
Prompt,
1918
)
2019

@@ -140,7 +139,6 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio
140139
outputs.append(output)
141140

142141
prompt.outputs = outputs
143-
print(f"{prompt.outputs=}")
144142
await aiconfig.callback_manager.run_callbacks(
145143
CallbackEvent(
146144
"on_run_complete",
@@ -168,12 +166,9 @@ def get_output_text(
168166
output_data = output.data
169167
if isinstance(output_data, str):
170168
return output_data
171-
if isinstance(output_data, OutputDataWithValue):
172-
if isinstance(output_data.value, str):
173-
return output_data.value
174-
# HuggingFace Text summarization does not support function
175-
# calls so shouldn't get here, but just being safe
176-
return json.dumps(output_data.value, indent=2)
169+
# HuggingFace image to text outputs should only ever be string
170+
# format so shouldn't get here, but just being safe
171+
return json.dumps(output_data, indent=2)
177172
return ""
178173

179174

@@ -213,12 +208,19 @@ def construct_regular_output(result: Dict[str, str], execution_count: int) -> Ou
213208
return output
214209

215210

216-
def validate_attachment_type_is_image(attachment: Attachment):
211+
def validate_attachment_type_is_image(
212+
prompt_name: str,
213+
attachment: Attachment,
214+
) -> None:
215+
"""
216+
Simple helper function to verify that the mimetype is set to a valid
217+
image format. Raises ValueError if there's an issue.
218+
"""
217219
if not hasattr(attachment, "mime_type"):
218-
raise ValueError(f"Attachment has no mime type. Specify the image mimetype in the aiconfig")
220+
raise ValueError(f"Attachment has no mime type for prompt '{prompt_name}'. Please specify the image mimetype in the AIConfig")
219221

220222
if not attachment.mime_type.startswith("image/"):
221-
raise ValueError(f"Invalid attachment mimetype {attachment.mime_type}. Expected image mimetype.")
223+
raise ValueError(f"Invalid attachment mimetype {attachment.mime_type} for prompt '{prompt_name}'. Please use a mimetype that starts with 'image/'.")
222224

223225

224226
def validate_and_retrieve_images_from_attachments(prompt: Prompt) -> list[Union[str, Image]]:
@@ -233,17 +235,17 @@ def validate_and_retrieve_images_from_attachments(prompt: Prompt) -> list[Union[
233235
"""
234236

235237
if not hasattr(prompt.input, "attachments") or len(prompt.input.attachments) == 0:
236-
raise ValueError(f"No attachments found in input for prompt {prompt.name}. Please add an image attachment to the prompt input.")
238+
raise ValueError(f"No attachments found in input for prompt '{prompt.name}'. Please add an image attachment to the prompt input.")
237239

238240
images: list[Union[str, Image]] = []
239241

240242
for i, attachment in enumerate(prompt.input.attachments):
241-
validate_attachment_type_is_image(attachment)
243+
validate_attachment_type_is_image(prompt.name, attachment)
242244

243245
input_data = attachment.data
244246
if not isinstance(input_data, str):
245-
# See todo above, but for now only support uri's
246-
raise ValueError(f"Attachment #{i} data is not a uri. Please specify a uri for the image attachment in prompt {prompt.name}.")
247+
# See todo above, but for now only support uris and base64
248+
raise ValueError(f"Attachment #{i} data is not a uri or base64 string. Please specify a uri or base64 encoded string for the image attachment in prompt '{prompt.name}'.")
247249

248250
# Really basic heurestic to check if the data is a base64 encoded str
249251
# vs. uri. This will be fixed once we have standardized inputs

extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_2_image.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import copy
33
import io
44
import itertools
5+
import json
56
import torch
67
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
78
from diffusers import AutoPipelineForText2Image
@@ -351,16 +352,21 @@ def get_output_text(
351352
# TODO (rossdanlm): Handle multiple outputs in list
352353
# https://github.com/lastmile-ai/aiconfig/issues/467
353354
if output.output_type == "execute_result":
354-
if isinstance(output.data, OutputDataWithStringValue):
355-
return output.data.value
356-
elif isinstance(output.data, str):
357-
return output.data
355+
output_data = output.data
356+
if isinstance(output_data, OutputDataWithStringValue):
357+
return output_data.value
358+
# HuggingFace text to image outputs should only ever be in
359+
# outputDataWithStringValue format so shouldn't get here, but
360+
# just being safe
361+
if isinstance(output_data, str):
362+
return output_data
363+
return json.dumps(output_data, indent=2)
358364
return ""
359365

360366
def _get_device(self) -> str:
361367
if torch.cuda.is_available():
362368
return "cuda"
363-
elif torch.backends.mps.is_available():
369+
if torch.backends.mps.is_available():
364370
return "mps"
365371
return "cpu"
366372

extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_2_speech.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import base64
22
import copy
33
import io
4+
import json
45
import numpy as np
5-
import torch
66
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
77
from transformers import Pipeline, pipeline
88
from scipy.io.wavfile import write as write_wav
@@ -12,7 +12,6 @@
1212
from aiconfig.schema import (
1313
ExecuteResult,
1414
Output,
15-
OutputDataWithValue,
1615
Prompt,
1716
PromptMetadata,
1817
)
@@ -25,7 +24,7 @@
2524

2625
# Step 1: define Helpers
2726
def refine_pipeline_creation_params(model_settings: Dict[str, Any]) -> List[Dict[str, Any]]:
28-
# There are from the transformers Github repo:
27+
# These are from the transformers Github repo:
2928
# https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L2534
3029
supported_keys = {
3130
"torch_dtype",
@@ -228,8 +227,9 @@ def get_output_text(
228227
# TODO (rossdanlm): Handle multiple outputs in list
229228
# https://github.com/lastmile-ai/aiconfig/issues/467
230229
if output.output_type == "execute_result":
231-
if isinstance(output.data, OutputDataWithValue):
232-
return output.data.value
233-
elif isinstance(output.data, str):
230+
if isinstance(output.data, str):
234231
return output.data
232+
# HuggingFace text to speech outputs should only ever be string
233+
# format so shouldn't get here, but just being safe
234+
return json.dumps(output.data, indent=2)
235235
return ""

extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_generation.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from aiconfig.schema import (
1515
ExecuteResult,
1616
Output,
17-
OutputDataWithValue,
1817
Prompt,
1918
PromptMetadata,
2019
)
@@ -26,7 +25,7 @@
2625

2726

2827
# Step 1: define Helpers
29-
def refine_chat_completion_params(model_settings: Dict[str, Any]) -> Dict[str, Any]:
28+
def refine_completion_params(model_settings: Dict[str, Any]) -> Dict[str, Any]:
3029
"""
3130
Refines the completion params for the HF text generation api. Removes any unsupported params.
3231
The supported keys were found by looking at the HF text generation api. `huggingface_hub.InferenceClient.text_generation()`
@@ -216,7 +215,7 @@ async def deserialize(
216215
"""
217216
# Build Completion data
218217
model_settings = self.get_model_settings(prompt, aiconfig)
219-
completion_data = refine_chat_completion_params(model_settings)
218+
completion_data = refine_completion_params(model_settings)
220219

221220
#Add resolved prompt
222221
resolved_prompt = resolve_prompt(prompt, params, aiconfig)
@@ -296,10 +295,8 @@ def get_output_text(
296295
output_data = output.data
297296
if isinstance(output_data, str):
298297
return output_data
299-
if isinstance(output_data, OutputDataWithValue):
300-
if isinstance(output_data.value, str):
301-
return output_data.value
302-
# HuggingFace Text generation does not support function
303-
# calls so shouldn't get here, but just being safe
304-
return json.dumps(output_data.value, indent=2)
298+
# HuggingFace text generation outputs should only ever be in
299+
# string format so shouldn't get here, but
300+
# just being safe
301+
return json.dumps(output_data, indent=2)
305302
return ""

extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_summarization.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from aiconfig.schema import (
1515
ExecuteResult,
1616
Output,
17-
OutputDataWithValue,
1817
Prompt,
1918
PromptMetadata,
2019
)
@@ -26,7 +25,7 @@
2625

2726

2827
# Step 1: define Helpers
29-
def refine_chat_completion_params(model_settings: Dict[str, Any]) -> Dict[str, Any]:
28+
def refine_completion_params(model_settings: Dict[str, Any]) -> Dict[str, Any]:
3029
"""
3130
Refines the completion params for the HF text summarization api. Removes any unsupported params.
3231
The supported keys were found by looking at the HF text summarization api. `huggingface_hub.InferenceClient.text_summarization()`
@@ -221,7 +220,7 @@ async def deserialize(
221220
"""
222221
# Build Completion data
223222
model_settings = self.get_model_settings(prompt, aiconfig)
224-
completion_data = refine_chat_completion_params(model_settings)
223+
completion_data = refine_completion_params(model_settings)
225224

226225
# Add resolved prompt
227226
resolved_prompt = resolve_prompt(prompt, params, aiconfig)
@@ -301,10 +300,7 @@ def get_output_text(
301300
output_data = output.data
302301
if isinstance(output_data, str):
303302
return output_data
304-
if isinstance(output_data, OutputDataWithValue):
305-
if isinstance(output_data.value, str):
306-
return output_data.value
307-
# HuggingFace Text summarization does not support function
308-
# calls so shouldn't get here, but just being safe
309-
return json.dumps(output_data.value, indent=2)
303+
# HuggingFace text summarization outputs should only ever be in
304+
# string format so shouldn't get here, but just being safe
305+
return json.dumps(output_data, indent=2)
310306
return ""

extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_translation.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from aiconfig.schema import (
1515
ExecuteResult,
1616
Output,
17-
OutputDataWithValue,
1817
Prompt,
1918
PromptMetadata,
2019
)
@@ -26,7 +25,7 @@
2625

2726

2827
# Step 1: define Helpers
29-
def refine_chat_completion_params(model_settings: Dict[str, Any]) -> Dict[str, Any]:
28+
def refine_completion_params(model_settings: Dict[str, Any]) -> Dict[str, Any]:
3029
"""
3130
Refines the completion params for the HF text translation api. Removes any unsupported params.
3231
The supported keys were found by looking at the HF text translation api. `huggingface_hub.InferenceClient.text_translation()`
@@ -223,7 +222,7 @@ async def deserialize(
223222
"""
224223
# Build Completion data
225224
model_settings = self.get_model_settings(prompt, aiconfig)
226-
completion_data = refine_chat_completion_params(model_settings)
225+
completion_data = refine_completion_params(model_settings)
227226

228227
# Add resolved prompt
229228
resolved_prompt = resolve_prompt(prompt, params, aiconfig)
@@ -304,10 +303,7 @@ def get_output_text(
304303
output_data = output.data
305304
if isinstance(output_data, str):
306305
return output_data
307-
if isinstance(output_data, OutputDataWithValue):
308-
if isinstance(output_data.value, str):
309-
return output_data.value
310-
# HuggingFace Text translation does not support function
311-
# calls so shouldn't get here, but just being safe
312-
return json.dumps(output_data.value, indent=2)
306+
# HuggingFace text translation outputs should only ever be in
307+
# string format so shouldn't get here, but just being safe
308+
return json.dumps(output_data, indent=2)
313309
return ""

extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/remote_inference_client/text_generation.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from aiconfig.schema import (
1616
ExecuteResult,
1717
Output,
18-
OutputDataWithValue,
1918
Prompt,
2019
PromptMetadata,
2120
)
@@ -29,9 +28,7 @@
2928

3029

3130
# Step 1: define Helpers
32-
33-
34-
def refine_chat_completion_params(model_settings: dict[Any, Any]) -> dict[str, Any]:
31+
def refine_completion_params(model_settings: dict[Any, Any]) -> dict[str, Any]:
3532
"""
3633
Refines the completion params for the HF text generation api. Removes any unsupported params.
3734
The supported keys were found by looking at the HF text generation api. `huggingface_hub.InferenceClient.text_generation()`
@@ -243,7 +240,7 @@ async def deserialize(
243240
# Build Completion data
244241
model_settings = self.get_model_settings(prompt, aiconfig)
245242

246-
completion_data = refine_chat_completion_params(model_settings)
243+
completion_data = refine_completion_params(model_settings)
247244

248245
completion_data["prompt"] = resolved_prompt
249246

@@ -318,10 +315,8 @@ def get_output_text(
318315
output_data = output.data
319316
if isinstance(output_data, str):
320317
return output_data
321-
if isinstance(output_data, OutputDataWithValue):
322-
if isinstance(output_data.value, str):
323-
return output_data.value
324-
# HuggingFace Text generation does not support function
325-
# calls so shouldn't get here, but just being safe
326-
return json.dumps(output_data.value, indent=2)
318+
319+
# HuggingFace text generation outputs should only ever be string
320+
# format so shouldn't get here, but just being safe
321+
return json.dumps(output_data, indent=2)
327322
return ""

0 commit comments

Comments
 (0)