Skip to content

Commit 3f0cbce

Browse files
authored
[AIC-py] hf image2text parser (#821)
[AIC-py] hf image2text parser — test patch for #816. ![pic](https://github.com/lastmile-ai/aiconfig/assets/148090348/d5cc26b3-6cb7-4331-af8a-92fd8c4e2471) Verified with: `python extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/run_hf_example.py extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/hf_local_example.aiconfig.json` -> "red fox in the woods"
2 parents bcd2921 + 448d52a commit 3f0cbce

File tree

2 files changed

+171
-4
lines changed

2 files changed

+171
-4
lines changed
Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
1+
from .local_inference.image_2_text import HuggingFaceImage2TextTransformer
12
from .local_inference.text_2_image import HuggingFaceText2ImageDiffusor
3+
from .local_inference.text_2_speech import HuggingFaceText2SpeechTransformer
24
from .local_inference.text_generation import HuggingFaceTextGenerationTransformer
3-
from .remote_inference_client.text_generation import HuggingFaceTextGenerationParser
45
from .local_inference.text_summarization import HuggingFaceTextSummarizationTransformer
56
from .local_inference.text_translation import HuggingFaceTextTranslationTransformer
6-
from .local_inference.text_2_speech import HuggingFaceText2SpeechTransformer
7-
7+
from .remote_inference_client.text_generation import HuggingFaceTextGenerationParser
88

9-
# from .remote_inference_client.text_generation import HuggingFaceTextGenerationClient
109

1110
LOCAL_INFERENCE_CLASSES = [
1211
"HuggingFaceText2ImageDiffusor",
1312
"HuggingFaceTextGenerationTransformer",
1413
"HuggingFaceTextSummarizationTransformer",
1514
"HuggingFaceTextTranslationTransformer",
1615
"HuggingFaceText2SpeechTransformer",
16+
"HuggingFaceAutomaticSpeechRecognition",
17+
"HuggingFaceImage2TextTransformer",
1718
]
1819
REMOTE_INFERENCE_CLASSES = ["HuggingFaceTextGenerationParser"]
1920
__ALL__ = LOCAL_INFERENCE_CLASSES + REMOTE_INFERENCE_CLASSES
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
from typing import Any, Dict, Optional, List, TYPE_CHECKING
2+
from aiconfig import ParameterizedModelParser, InferenceOptions
3+
from aiconfig.callback import CallbackEvent
4+
import torch
5+
from aiconfig.schema import Prompt, Output, ExecuteResult, Attachment
6+
7+
from transformers import pipeline, Pipeline
8+
9+
if TYPE_CHECKING:
10+
from aiconfig import AIConfigRuntime
11+
12+
13+
class HuggingFaceImage2TextTransformer(ParameterizedModelParser):
    """Model parser that runs HuggingFace `image-to-text` pipelines locally."""

    def __init__(self):
        """
        Returns:
            HuggingFaceImage2TextTransformer

        Usage:
        1. Create a new model parser object with the model ID of the model to use.
                parser = HuggingFaceImage2TextTransformer()
        2. Add the model parser to the registry.
                config.register_model_parser(parser)
        """
        super().__init__()
        # One pipeline per model name, built lazily in run_inference and
        # reused across runs.
        self.pipelines: dict[str, Pipeline] = {}

    def id(self) -> str:
        """
        Returns an identifier for the Model Parser
        """
        return "HuggingFaceImage2TextTransformer"

    async def serialize(
        self,
        prompt_name: str,
        data: Any,
        ai_config: "AIConfigRuntime",
        parameters: Optional[Dict[str, Any]] = None,
    ) -> List[Prompt]:
        """
        Defines how a prompt and model inference settings get serialized in the .aiconfig.
        Assume input in the form of input(s) being passed into an already constructed pipeline.

        Args:
            prompt_name (str): The name to serialize the prompt under.
            data (Any): Model-specific inference settings to be serialized.
            ai_config (AIConfigRuntime): The AIConfig Runtime.
            parameters (Dict[str, Any], optional): Model-specific parameters. Defaults to None.

        Returns:
            List[Prompt]: Serialized representation of the prompt and inference settings.

        Raises:
            ValueError: If `data` is not a dict, or has no "inputs" field.
        """
        await ai_config.callback_manager.run_callbacks(
            CallbackEvent(
                "on_serialize_start",
                __name__,
                {
                    "prompt_name": prompt_name,
                    "data": data,
                    "parameters": parameters,
                },
            )
        )

        prompts = []

        if not isinstance(data, dict):
            raise ValueError("Invalid data type. Expected dict when serializing prompt data to aiconfig.")
        if data.get("inputs", None) is None:
            raise ValueError("Invalid data when serializing prompt to aiconfig. Input data must contain an inputs field.")

        # The pipeline input(s) become a single attachment on the prompt input.
        prompt = Prompt(
            **{
                "name": prompt_name,
                "input": {"attachments": [{"data": data["inputs"]}]},
                "metadata": None,
                "outputs": None,
            }
        )

        prompts.append(prompt)

        await ai_config.callback_manager.run_callbacks(CallbackEvent("on_serialize_complete", __name__, {"result": prompts}))
        return prompts

    async def deserialize(
        self,
        prompt: Prompt,
        aiconfig: "AIConfigRuntime",
        params: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """
        Build the completion-data dict for a pipeline call from a serialized prompt.

        Args:
            prompt (Prompt): The prompt whose settings and attachments to read.
            aiconfig (AIConfigRuntime): The AIConfig Runtime.
            params (Dict[str, Any], optional): Runtime parameters; only forwarded
                to the callback event here.

        Returns:
            Dict[str, Any]: Model settings plus an "inputs" key holding image URIs.
        """
        # BUG FIX: the original declared a mutable default argument
        # (`params={}`); normalize None to a fresh dict instead.
        if params is None:
            params = {}
        await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_start", __name__, {"prompt": prompt, "params": params}))

        # Build Completion data
        completion_params = self.get_model_settings(prompt, aiconfig)

        inputs = validate_and_retrieve_image_from_attachments(prompt)

        completion_params["inputs"] = inputs

        await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_params}))
        return completion_params

    async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]) -> list[Output]:
        """
        Run the image-to-text pipeline for `prompt`, store the outputs on the
        prompt, and return them.

        Args:
            prompt (Prompt): The prompt to run.
            aiconfig (AIConfigRuntime): The AIConfig Runtime.
            options (InferenceOptions): Inference options (forwarded to callbacks).
            parameters (Dict[str, Any]): Runtime parameters passed to deserialize.

        Returns:
            list[Output]: The execute_result outputs produced by the pipeline.
        """
        await aiconfig.callback_manager.run_callbacks(
            CallbackEvent(
                "on_run_start",
                __name__,
                {"prompt": prompt, "options": options, "parameters": parameters},
            )
        )
        model_name = aiconfig.get_model_name(prompt)

        # BUG FIX: construct the pipeline only on first use. The original
        # rebuilt it (re-loading model weights) on every run, which made the
        # self.pipelines cache created in __init__ pointless.
        if model_name not in self.pipelines:
            self.pipelines[model_name] = pipeline(task="image-to-text", model=model_name)

        captioner = self.pipelines[model_name]
        completion_data = await self.deserialize(prompt, aiconfig, parameters)
        inputs = completion_data.pop("inputs")
        # Drop the "model" setting so it is not forwarded as a pipeline kwarg;
        # popping with a default avoids the original's KeyError when absent.
        completion_data.pop("model", None)
        response = captioner(inputs, **completion_data)

        output = ExecuteResult(output_type="execute_result", data=response, metadata={})

        prompt.outputs = [output]
        await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_run_complete", __name__, {"result": prompt.outputs}))
        return prompt.outputs

    def get_output_text(self, response: dict[str, Any]) -> str:
        # Not supported for this parser; callers must read outputs directly.
        raise NotImplementedError("get_output_text is not implemented for HuggingFaceImage2TextTransformer")
131+
132+
133+
def validate_attachment_type_is_image(attachment: "Attachment"):
    """
    Raise ValueError unless `attachment` declares an image/* mime type.

    Args:
        attachment (Attachment): The attachment to validate.

    Raises:
        ValueError: If the mime type is missing/None or not an image type.
    """
    # getattr covers both a missing attribute and an explicit None value;
    # the original hasattr() guard let mime_type=None fall through to an
    # AttributeError on .startswith() below.
    mime_type = getattr(attachment, "mime_type", None)
    if not mime_type:
        raise ValueError("Attachment has no mime type. Specify the image mimetype in the aiconfig")

    if not mime_type.startswith("image/"):
        raise ValueError(f"Invalid attachment mimetype {mime_type}. Expected image mimetype.")


def validate_and_retrieve_image_from_attachments(prompt: "Prompt") -> list[str]:
    """
    Retrieves the image uri's from each attachment in the prompt input.

    Throws an exception if
        - attachment is not image
        - attachment data is not a uri
        - no attachments are found
        - operation fails for any reason
    """
    # A falsy check also covers attachments=None, which the original
    # len() call would have crashed on with a TypeError.
    attachments = getattr(prompt.input, "attachments", None)
    if not attachments:
        raise ValueError(f"No attachments found in input for prompt {prompt.name}. Please add an image attachment to the prompt input.")

    image_uris: list[str] = []

    for i, attachment in enumerate(attachments):
        validate_attachment_type_is_image(attachment)

        if not isinstance(attachment.data, str):
            # Only uri strings are supported for now.
            raise ValueError(f"Attachment #{i} data is not a uri. Please specify a uri for the image attachment in prompt {prompt.name}.")

        image_uris.append(attachment.data)

    return image_uris

0 commit comments

Comments
 (0)