Skip to content

Commit 4e7045b

Browse files
[AIC-py] hf image2text parser
Tested in patch #816 (screenshot below): ![pic](https://github.com/lastmile-ai/aiconfig/assets/148090348/d5cc26b3-6cb7-4331-af8a-92fd8c4e2471) Running `python extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/run_hf_example.py extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/hf_local_example.aiconfig.json` produces "red fox in the woods".
1 parent 53fbb69 commit 4e7045b

File tree

2 files changed

+141
-4
lines changed

2 files changed

+141
-4
lines changed
Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
from .local_inference.image_2_text import HuggingFaceImage2TextTransformer
from .local_inference.text_2_image import HuggingFaceText2ImageDiffusor
from .local_inference.text_2_speech import HuggingFaceText2SpeechTransformer
from .local_inference.text_generation import HuggingFaceTextGenerationTransformer
from .local_inference.text_summarization import HuggingFaceTextSummarizationTransformer
from .local_inference.text_translation import HuggingFaceTextTranslationTransformer

from .remote_inference_client.text_generation import HuggingFaceTextGenerationParser

# Class names of the locally-run model parsers exported by this package.
LOCAL_INFERENCE_CLASSES = [
    "HuggingFaceText2ImageDiffusor",
    "HuggingFaceTextGenerationTransformer",
    "HuggingFaceTextSummarizationTransformer",
    "HuggingFaceTextTranslationTransformer",
    "HuggingFaceText2SpeechTransformer",
    # NOTE(review): listed here but no corresponding import is visible in this
    # file — confirm an automatic-speech-recognition parser actually exists
    # and is importable from this package before relying on this entry.
    "HuggingFaceAutomaticSpeechRecognition",
    "HuggingFaceImage2TextTransformer",
]
# Class names of the parsers that call remote inference endpoints.
REMOTE_INFERENCE_CLASSES = ["HuggingFaceTextGenerationParser"]
# NOTE(review): Python's star-import export convention is lowercase __all__;
# uppercase __ALL__ has no special meaning to the import system. Kept as-is
# for backward compatibility with any code that reads this attribute.
__ALL__ = LOCAL_INFERENCE_CLASSES + REMOTE_INFERENCE_CLASSES
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
from typing import Any, Dict, Optional, List, TYPE_CHECKING
2+
from aiconfig import ParameterizedModelParser, InferenceOptions
3+
from aiconfig.callback import CallbackEvent
4+
import torch
5+
from aiconfig.schema import Prompt, Output, ExecuteResult, Attachment
6+
7+
from transformers import pipeline, Pipeline
8+
9+
if TYPE_CHECKING:
10+
from aiconfig import AIConfigRuntime
11+
12+
class HuggingFaceImage2TextTransformer(ParameterizedModelParser):
    """AIConfig model parser that runs HuggingFace image-to-text pipelines locally."""

    def __init__(self):
        """
        Returns:
            HuggingFaceImage2TextTransformer

        Usage:
        1. Create a new model parser object with the model ID of the model to use.
                parser = HuggingFaceImage2TextTransformer()
        2. Add the model parser to the registry.
                config.register_model_parser(parser)
        """
        super().__init__()
        # Cache of constructed pipelines, keyed by model name, so repeated
        # runs against the same model do not reload the weights.
        self.pipelines: dict[str, Pipeline] = {}

    def id(self) -> str:
        """
        Returns an identifier for the Model Parser
        """
        return "HuggingFaceImage2TextTransformer"

    async def serialize(
        self,
        prompt_name: str,
        data: Any,
        ai_config: "AIConfigRuntime",
        parameters: Optional[Dict[str, Any]] = None,
    ) -> List[Prompt]:
        """
        Defines how a prompt and model inference settings get serialized in the .aiconfig.
        Assume input in the form of input(s) being passed into an already constructed pipeline.

        Args:
            prompt_name (str): The prompt to be serialized.
            data (Any): Model-specific inference settings to be serialized.
            ai_config (AIConfigRuntime): The AIConfig Runtime.
            parameters (Dict[str, Any], optional): Model-specific parameters. Defaults to None.

        Raises:
            NotImplementedError: serialization is not supported by this parser.
        """
        raise NotImplementedError("serialize is not implemented for HuggingFaceImage2TextTransformer")

    async def deserialize(
        self,
        prompt: Prompt,
        aiconfig: "AIConfigRuntime",
        # NOTE(review): mutable default kept for signature compatibility with
        # callers/base class; the dict is never mutated in this method.
        params: Optional[Dict[str, Any]] = {},
    ) -> Dict[str, Any]:
        """
        Build the completion kwargs for the pipeline: the prompt's model
        settings plus an "inputs" key holding the image URIs extracted from
        the prompt's attachments.
        """
        await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_start", __name__, {"prompt": prompt, "params": params}))

        # Build Completion data
        completion_params = self.get_model_settings(prompt, aiconfig)

        inputs = validate_and_retrieve_image_from_attachments(prompt)
        completion_params["inputs"] = inputs

        await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_params}))
        return completion_params

    async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]) -> list[Output]:
        """
        Run the image-to-text pipeline for the prompt, store the result as the
        prompt's outputs, and return them.
        """
        await aiconfig.callback_manager.run_callbacks(
            CallbackEvent(
                "on_run_start",
                __name__,
                {"prompt": prompt, "options": options, "parameters": parameters},
            )
        )
        model_name = aiconfig.get_model_name(prompt)

        # Fix: reuse a cached pipeline when one exists. The original rebuilt
        # the pipeline on every run, which made the self.pipelines cache
        # useless and reloaded model weights each call.
        if model_name not in self.pipelines:
            self.pipelines[model_name] = pipeline(task="image-to-text", model=model_name)
        captioner = self.pipelines[model_name]

        completion_data = await self.deserialize(prompt, aiconfig, parameters)
        inputs = completion_data.pop("inputs")
        # "model" is routing metadata, not a pipeline kwarg; drop it if present.
        # Fix: pop with a default so settings without a "model" key no longer
        # raise KeyError. (A leftover debug print of completion_data was also
        # removed here.)
        completion_data.pop("model", None)
        response = captioner(inputs, **completion_data)

        output = ExecuteResult(output_type="execute_result", data=response, metadata={})

        prompt.outputs = [output]
        await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_run_complete", __name__, {"result": prompt.outputs}))
        return prompt.outputs

    def get_output_text(self, response: dict[str, Any]) -> str:
        """
        Raises:
            NotImplementedError: extracting plain text from outputs is not
            supported by this parser yet.
        """
        raise NotImplementedError("get_output_text is not implemented for HuggingFaceImage2TextTransformer")
102+
103+
def validate_attachment_type_is_image(attachment: Attachment):
104+
if not hasattr(attachment, "mime_type"):
105+
raise ValueError(f"Attachment has no mime type. Specify the image mimetype in the aiconfig")
106+
107+
if not attachment.mime_type.startswith("image/"):
108+
raise ValueError(f"Invalid attachment mimetype {attachment.mime_type}. Expected image mimetype.")
def validate_and_retrieve_image_from_attachments(prompt: Prompt) -> list[str]:
    """
    Retrieves the image uri's from each attachment in the prompt input.

    Throws an exception if
        - attachment is not image
        - attachment data is not a uri
        - no attachments are found
        - operation fails for any reason

    Raises:
        ValueError: per the conditions above.
    """
    # Fix: guard with truthiness instead of hasattr + len(); the original
    # raised TypeError (len(None)) when the attachments attribute existed
    # but was None.
    attachments = getattr(prompt.input, "attachments", None)
    if not attachments:
        raise ValueError(f"No attachments found in input for prompt {prompt.name}. Please add an image attachment to the prompt input.")

    image_uris: list[str] = []

    for i, attachment in enumerate(attachments):
        validate_attachment_type_is_image(attachment)

        if not isinstance(attachment.data, str):
            # See todo above, but for now only support uri's
            raise ValueError(f"Attachment #{i} data is not a uri. Please specify a uri for the image attachment in prompt {prompt.name}.")

        image_uris.append(attachment.data)

    return image_uris

0 commit comments

Comments
 (0)