[AIC-py] hf image2text parser #821
Changes from all commits
@@ -1,19 +1,20 @@
```python
from .local_inference.image_2_text import HuggingFaceImage2TextTransformer
from .local_inference.text_2_image import HuggingFaceText2ImageDiffusor
from .local_inference.text_2_speech import HuggingFaceText2SpeechTransformer
from .local_inference.text_generation import HuggingFaceTextGenerationTransformer
from .local_inference.text_summarization import HuggingFaceTextSummarizationTransformer
from .local_inference.text_translation import HuggingFaceTextTranslationTransformer

from .remote_inference_client.text_generation import HuggingFaceTextGenerationParser

# from .remote_inference_client.text_generation import HuggingFaceTextGenerationClient

LOCAL_INFERENCE_CLASSES = [
    "HuggingFaceText2ImageDiffusor",
    "HuggingFaceTextGenerationTransformer",
    "HuggingFaceTextSummarizationTransformer",
    "HuggingFaceTextTranslationTransformer",
    "HuggingFaceText2SpeechTransformer",
    "HuggingFaceAutomaticSpeechRecognition",
```
Contributor: cc @Ankush-lastmile you may have merge conflicts with your other PR?

Contributor: nit: Can we also do these in alphabetical order?

Contributor: Fixed in #862
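For reference, the alphabetized list the nit asks for would look roughly like this; this is purely illustrative (all names are taken from the diff above), and the actual reordering landed in #862:

```python
LOCAL_INFERENCE_CLASSES = [
    "HuggingFaceAutomaticSpeechRecognition",
    "HuggingFaceImage2TextTransformer",
    "HuggingFaceText2ImageDiffusor",
    "HuggingFaceText2SpeechTransformer",
    "HuggingFaceTextGenerationTransformer",
    "HuggingFaceTextSummarizationTransformer",
    "HuggingFaceTextTranslationTransformer",
]
```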
```python
    "HuggingFaceImage2TextTransformer",
]
REMOTE_INFERENCE_CLASSES = ["HuggingFaceTextGenerationParser"]
__ALL__ = LOCAL_INFERENCE_CLASSES + REMOTE_INFERENCE_CLASSES
```
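For reviewers trying the new parser locally, usage follows the pattern described in the class docstring further down. The sketch below is only an approximation: the import path and the aiconfig file name are assumptions, not part of this PR, and the exact registration call may differ slightly from the docstring.

```python
from aiconfig import AIConfigRuntime

# Module path assumed for illustration; the diff only shows relative imports
# inside the extension package.
from aiconfig_extension_hugging_face.local_inference.image_2_text import HuggingFaceImage2TextTransformer

# Create the parser and register it, as the class docstring suggests.
parser = HuggingFaceImage2TextTransformer()

config = AIConfigRuntime.load("my_config.aiconfig.json")  # placeholder file name
config.register_model_parser(parser)
```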
@@ -0,0 +1,166 @@
```python
from typing import Any, Dict, Optional, List, TYPE_CHECKING
from aiconfig import ParameterizedModelParser, InferenceOptions
from aiconfig.callback import CallbackEvent
import torch
from aiconfig.schema import Prompt, Output, ExecuteResult, Attachment

from transformers import pipeline, Pipeline

if TYPE_CHECKING:
    from aiconfig import AIConfigRuntime


class HuggingFaceImage2TextTransformer(ParameterizedModelParser):
    def __init__(self):
        """
        Returns:
            HuggingFaceImage2TextTransformer

        Usage:
        1. Create a new model parser object with the model ID of the model to use.
                parser = HuggingFaceImage2TextTransformer()
        2. Add the model parser to the registry.
                config.register_model_parser(parser)
        """
        super().__init__()
        self.pipelines: dict[str, Pipeline] = {}

    def id(self) -> str:
        """
        Returns an identifier for the Model Parser
        """
        return "HuggingFaceImage2TextTransformer"

    async def serialize(
        self,
        prompt_name: str,
        data: Any,
        ai_config: "AIConfigRuntime",
        parameters: Optional[Dict[str, Any]] = None,
    ) -> List[Prompt]:
        """
        Defines how a prompt and model inference settings get serialized in the .aiconfig.
        Assume input in the form of input(s) being passed into an already constructed pipeline.

        Args:
            prompt (str): The prompt to be serialized.
            data (Any): Model-specific inference settings to be serialized.
            ai_config (AIConfigRuntime): The AIConfig Runtime.
            parameters (Dict[str, Any], optional): Model-specific parameters. Defaults to None.

        Returns:
            str: Serialized representation of the prompt and inference settings.
        """
```
Contributor: Can you add a TODO linking to #822 to fix later (and add automated testing)? I'll do this later.

Contributor (author): what's broken?

Contributor: We're not using the correct model_id, so I need to pass this in so we can re-create the prompt.

Contributor: Added TODO comment in code in #862
```python
        await ai_config.callback_manager.run_callbacks(
            CallbackEvent(
                "on_serialize_start",
                __name__,
                {
                    "prompt_name": prompt_name,
                    "data": data,
                    "parameters": parameters,
                },
            )
        )

        prompts = []

        if not isinstance(data, dict):
            raise ValueError("Invalid data type. Expected dict when serializing prompt data to aiconfig.")
        if data.get("inputs", None) is None:
            raise ValueError("Invalid data when serializing prompt to aiconfig. Input data must contain an inputs field.")

        prompt = Prompt(
            **{
                "name": prompt_name,
                "input": {"attachments": [{"data": data["inputs"]}]},
                "metadata": None,
                "outputs": None,
            }
        )

        prompts.append(prompt)

        await ai_config.callback_manager.run_callbacks(CallbackEvent("on_serialize_complete", __name__, {"result": prompts}))
        return prompts

    async def deserialize(
        self,
        prompt: Prompt,
        aiconfig: "AIConfigRuntime",
        params: Optional[Dict[str, Any]] = {},
    ) -> Dict[str, Any]:
        await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_start", __name__, {"prompt": prompt, "params": params}))

        # Build Completion data
        completion_params = self.get_model_settings(prompt, aiconfig)

        inputs = validate_and_retrieve_image_from_attachments(prompt)

        completion_params["inputs"] = inputs
```
Comment on lines +98 to +100

Contributor: cc @Ankush-lastmile when this lands, can you link to the Attachment format/standardizing inputs issue we mentioned? Jonathan, you don't need to do any work, just making sure Ankush is aware.
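For context on the attachment format being discussed: serialize() above produces a prompt entry of roughly the shape below. The values are illustrative, and note that serialize() only writes "data"; the validation helpers further down additionally expect a "mime_type" on each attachment.

```python
# Approximate shape of the prompt that serialize() builds (illustrative values).
prompt_entry = {
    "name": "prompt_1",
    "input": {
        "attachments": [
            {"data": "https://example.com/cat.png"},  # mime_type not set by serialize()
        ]
    },
    "metadata": None,
    "outputs": None,
}
```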
```python
        await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_params}))
        return completion_params

    async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]) -> list[Output]:
        await aiconfig.callback_manager.run_callbacks(
            CallbackEvent(
                "on_run_start",
                __name__,
                {"prompt": prompt, "options": options, "parameters": parameters},
            )
        )
        model_name = aiconfig.get_model_name(prompt)

        self.pipelines[model_name] = pipeline(task="image-to-text", model=model_name)

        captioner = self.pipelines[model_name]
        completion_data = await self.deserialize(prompt, aiconfig, parameters)
        inputs = completion_data.pop("inputs")
        model = completion_data.pop("model")
```
Contributor: This is never used again, why would we have it in completion data in the first place? If it's never used, pls prefix with

Contributor (author): it has to be removed from completion data. Something is definitely off here, I just don't know exactly what. cc @Ankush-lastmile
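A minimal sketch of what the suggestion seems to amount to: the pop is still needed (per the author's reply, the key must be removed so it isn't forwarded to the pipeline call), but the unused binding can be dropped instead of assigned to a name.

```python
inputs = completion_data.pop("inputs")
# The model name is stored with the prompt settings but is not needed here,
# since the pipeline above was already constructed with it; remove it so it is
# not passed as an unexpected keyword argument to the pipeline call.
completion_data.pop("model", None)
```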
```python
        response = captioner(inputs, **completion_data)
```
Contributor: Does pipeline only support inputs as URI, or does it also work with base64 encoded? If not, pls make task that we need to convert from base64 --> image URI first

Contributor: Fixed in #856

rossdanlm marked this conversation as resolved.
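Not part of this diff (the actual handling landed in #856), but as a rough sketch of the base64 conversion being discussed: the transformers image-to-text pipeline accepts URLs, local paths, or PIL images, so base64 attachment data could be decoded into a PIL image before being passed in. The function name here is hypothetical.

```python
import base64
import io

from PIL import Image


def image_from_base64(b64_data: str) -> Image.Image:
    """Decode a base64-encoded image into a PIL Image.

    The image-to-text pipeline accepts PIL images alongside URLs and file paths,
    so this avoids having to write the decoded bytes out to a URI first.
    """
    return Image.open(io.BytesIO(base64.b64decode(b64_data)))
```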
```python
        output = ExecuteResult(output_type="execute_result", data=response, metadata={})
```
Contributor: Oh sweet, so response is just purely text? Nice! Also let's add "execution_count=0".

Contributor: Fixed in #855
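The requested change is small; a sketch of the output construction with the suggested field (the actual change landed in #855):

```python
output = ExecuteResult(
    output_type="execute_result",
    execution_count=0,  # requested above, mirroring how the other parsers build outputs
    data=response,
    metadata={},
)
```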
```python
        prompt.outputs = [output]
        await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_run_complete", __name__, {"result": prompt.outputs}))
        return prompt.outputs

    def get_output_text(self, response: dict[str, Any]) -> str:
        raise NotImplementedError("get_output_text is not implemented for HuggingFaceImage2TextTransformer")
```
Contributor: Pls update to match others like the

Contributor: Fixed in #855
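This is not what ultimately landed in #855, only a sketch of the kind of implementation the comment asks for. It assumes the method signature used by the other local-inference parsers, that the runtime exposes a get_latest_output helper, and that the pipeline returns the usual list of {"generated_text": ...} dicts; all of those are assumptions here.

```python
    def get_output_text(self, prompt: Prompt, aiconfig: "AIConfigRuntime", output: Optional[Output] = None) -> str:
        # Hypothetical shape, modeled on the other parsers in this extension.
        if output is None:
            output = aiconfig.get_latest_output(prompt)  # assumed helper on the runtime
        if output is None or output.output_type != "execute_result":
            return ""
        data = output.data
        # image-to-text pipelines typically return [{"generated_text": "..."}]
        if isinstance(data, list) and data and isinstance(data[0], dict):
            return data[0].get("generated_text", "")
        return data if isinstance(data, str) else ""
```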
```python
def validate_attachment_type_is_image(attachment: Attachment):
    if not hasattr(attachment, "mime_type"):
        raise ValueError(f"Attachment has no mime type. Specify the image mimetype in the aiconfig")
```
Contributor: Nit: add the word "Please" before "Specify".

Contributor: Updated in #862
```python
    if not attachment.mime_type.startswith("image/"):
        raise ValueError(f"Invalid attachment mimetype {attachment.mime_type}. Expected image mimetype.")


def validate_and_retrieve_image_from_attachments(prompt: Prompt) -> list[str]:
    """
    Retrieves the image uri's from each attachment in the prompt input.

    Throws an exception if
    - attachment is not image
    - attachment data is not a uri
    - no attachments are found
    - operation fails for any reason
    """

    if not hasattr(prompt.input, "attachments") or len(prompt.input.attachments) == 0:
        raise ValueError(f"No attachments found in input for prompt {prompt.name}. Please add an image attachment to the prompt input.")

    image_uris: list[str] = []

    for i, attachment in enumerate(prompt.input.attachments):
        validate_attachment_type_is_image(attachment)

        if not isinstance(attachment.data, str):
            # See todo above, but for now only support uri's
            raise ValueError(f"Attachment #{i} data is not a uri. Please specify a uri for the image attachment in prompt {prompt.name}.")
```
Comment on lines +161 to +162

Contributor: Doesn't have to be this diff, but please add support for base64 as well. This is important since, if we want to chain prompts, some of our models output in base64 format (ex: text_2_image). At the very least, create an issue to track.

Contributor: Fixed in #856
```python
        image_uris.append(attachment.data)

    return image_uris
```