-
Notifications
You must be signed in to change notification settings - Fork 0
[DEVX-828] Added image summarization in multimodal pipeline #31
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
eb8d6e9
2979c62
af46344
d52f5a6
fe0d75e
f7dd88a
09c977c
c0cbfb4
9203227
38ccc91
e3a9f13
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -51,6 +51,9 @@ coverage.xml | |
| .hypothesis/ | ||
| .pytest_cache/ | ||
|
|
||
| testing/* | ||
| !testing/test.ipynb | ||
|
|
||
| # Translations | ||
| *.mo | ||
| *.pot | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,96 @@ | ||
| import base64 | ||
| import os | ||
| import random | ||
| from typing import List | ||
|
|
||
| try: | ||
| from unstructured.documents.elements import CompositeElement, ElementMetadata, Image | ||
| except ImportError: | ||
| raise ImportError( | ||
| "Could not import unstructured package. " | ||
| "Please install it with `pip install 'unstructured[pdf] @ git+https://github.com/clarifai/unstructured.git@support_clarifai_model'`." | ||
| ) | ||
|
|
||
| from clarifai.client.input import Inputs | ||
| from clarifai.client.model import Model | ||
|
|
||
| from .basetransform import BaseTransform | ||
|
|
||
|
|
||
class ImageSummarizer(BaseTransform):
    """Summarizes image elements.

    Every element passed through the transform is tagged as original and
    given an ``input_id``. For each ``Image`` element, a multimodal model is
    asked for a retrieval-oriented summary, and the summary is appended to the
    element list as a ``CompositeElement`` that points back at its source
    image through ``source_input_id``.
    """

    def __init__(self, model_url="https://clarifai.com/qwen/qwen-VL/models/qwen-VL-Chat"):
        """Initializes an ImageSummarizer object.

        Args:
            model_url (str): Model URL to use for summarization.
        """
        self.model_url = model_url
        # The personal access token is read from the environment; it is passed
        # through as None when CLARIFAI_PAT is not set.
        self.model = Model(url=model_url, pat=os.environ.get("CLARIFAI_PAT"))
        # Built by implicit concatenation so that source indentation can never
        # leak into the prompt text sent to the model.
        self.summary_prompt = (
            "You are an assistant tasked with summarizing images for retrieval. "
            "These summaries will be embedded and used to retrieve the raw image. "
            "Give a concise summary of the image that is well optimized for retrieval.")

    def __call__(self, elements: List) -> List:
        """Applies the transformation.

        Args:
            elements (List): Partitioned document elements. The list and the
                metadata of its elements are modified in place.

        Returns:
            The same list, with one summary ``CompositeElement`` appended for
            each ``Image`` element that was summarized.
        """
        img_elements = []
        for element in elements:
            # Tag every incoming element so that it can be told apart from the
            # generated summaries (which carry is_original=False).
            element.metadata.update(
                ElementMetadata.from_dict({
                    'is_original': True,
                    'input_id': f'{random.randint(1000000, 99999999)}'
                }))
            if isinstance(element, Image):
                img_elements.append(element)

        # Nothing to summarize: skip the model call entirely.
        if not img_elements:
            return elements

        # Only the image elements are handed over, so that the model outputs
        # line up one-to-one with the elements they were generated from.
        elements.extend(self._summarize_image(img_elements))
        return elements

    def _summarize_image(self, image_elements: List[Image]) -> List[CompositeElement]:
        """Summarizes image elements with a single batched model call.

        Args:
            image_elements (List[Image]): Elements to summarize. Anything that
                is not an ``Image`` is ignored.

        Returns:
            One ``CompositeElement`` per summarized image. Its text is the
            image's own text (if any) followed by the generated summary, its
            ``source_input_id`` metadata is the image's ``input_id`` and its
            element id is that ``input_id`` prefixed with ``summarized_``.
        """
        # Filter first and keep the filtered list: the responses are paired
        # with exactly the elements that were sent, never with skipped ones.
        images = [element for element in image_elements if isinstance(element, Image)]
        if not images:
            return []

        img_inputs = [
            Inputs.get_multimodal_input(
                input_id="summarize_" + element.metadata.input_id,
                image_bytes=base64.b64decode(element.metadata.image_base64),
                raw_text=self.summary_prompt) for element in images
        ]
        resp = self.model.predict(img_inputs)

        new_elements = []
        # NOTE(review): pairing is positional and assumes the model returns
        # outputs in the same order as the submitted inputs - confirm against
        # the Clarifai predict API.
        for source, output in zip(images, resp.outputs):
            # Some image elements were observed to carry text of their own;
            # when present it is kept in front of the generated summary.
            summary = source.text or ""
            summary = summary + " \n " + output.data.text.raw
            eid = source.metadata.input_id
            meta_dict = {'source_input_id': eid, 'is_original': False}
            new_elements.append(
                CompositeElement(
                    text=summary,
                    metadata=ElementMetadata.from_dict(meta_dict),
                    element_id="summarized_" + eid))

        return new_elements
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -66,3 +66,32 @@ def test_pipeline_run_loader(self,): | |
| assert elements.__class__.__name__ == 'MultiModalLoader' | ||
| assert len(elements) == 14 | ||
| assert elements.elements[0].metadata.to_dict()['filename'] == 'Multimodal_sample_file.pdf' | ||
|
|
||
def test_pipeline_summarize(self):
    """Tests for pipeline run with summarizer.

    Runs PDF partitioning, whitespace cleaning and image summarization, then
    checks that the image element is tagged as original and that the appended
    summary element points back at it.
    """
    from clarifai_datautils.multimodal import Pipeline
    from clarifai_datautils.multimodal.pipeline.cleaners import Clean_extra_whitespace
    from clarifai_datautils.multimodal.pipeline.PDF import PDFPartitionMultimodal
    from clarifai_datautils.multimodal.pipeline.summarizer import ImageSummarizer

    pipeline = Pipeline(
        name='pipeline-1',
        transformations=[
            PDFPartitionMultimodal(chunking_strategy="by_title", max_characters=1024),
            Clean_extra_whitespace(),
            ImageSummarizer()
        ])
    elements = pipeline.run(files=PDF_FILE_PATH, loader=False)

    assert isinstance(elements, list)
    assert len(elements) == 15

    first_meta = elements[0].metadata.to_dict()
    assert first_meta['filename'] == 'Multimodal_sample_file.pdf'
    assert first_meta['page_number'] == 1
    # NOTE(review): the expected address below looks like an obfuscation
    # placeholder rather than a real value - confirm against the fixture.
    assert first_meta['email_address'] == ['[email protected]']
    assert elements[6].__class__.__name__ == 'Table'

    # The original image element is second to last; its summary is appended
    # after it.
    image_element = elements[-2]
    assert image_element.__class__.__name__ == 'Image'
    assert image_element.metadata.is_original is True
    source_id = image_element.metadata.input_id
    assert source_id is not None

    summary_element = elements[-1]
    assert summary_element.__class__.__name__ == 'CompositeElement'
    assert summary_element.metadata.is_original is False
    # source_input_id holds the image's own input_id; the 'summarized_' prefix
    # is applied to the summary's element id, not to this metadata field.
    assert summary_element.metadata.source_input_id == source_id
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
QQ: why are we adding this new field ID and will it be used?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, this is to identify the corresponding summary that was generated for the images in the PDF!