Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ coverage.xml
.hypothesis/
.pytest_cache/

testing/*
!testing/test.ipynb

# Translations
*.mo
*.pot
Expand Down
8 changes: 6 additions & 2 deletions clarifai_datautils/multimodal/pipeline/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __getitem__(self, index: int):
meta.pop('coordinates', None)
meta.pop('detection_class_prob', None)
image_data = meta.pop('image_base64', None)
id = meta.get('input_id', None)
Comment on lines 29 to +30

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QQ: why are we adding this new field ID and will it be used?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is to identify the corresponding summary that was generated for the images in the PDF!

if image_data is not None:
# Ensure image_data is already bytes before encoding
image_data = base64.b64decode(image_data)
Expand All @@ -39,7 +40,7 @@ def __getitem__(self, index: int):
meta['type'] = 'table'

return MultiModalFeatures(
text=text, image_bytes=image_data, labels=[self.pipeline_name], metadata=meta)
text=text, image_bytes=image_data, labels=[self.pipeline_name], metadata=meta, id=id)

def __len__(self):
return len(self.elements)
Expand All @@ -61,10 +62,13 @@ def task(self):
return DATASET_UPLOAD_TASKS.TEXT_CLASSIFICATION #TODO: Better dataset name in SDK

def __getitem__(self, index: int):
id = self.elements[index].to_dict().get('element_id', None)
id = id[:48] if id is not None else None
return TextFeatures(
text=self.elements[index].text,
labels=self.pipeline_name,
metadata=self.elements[index].metadata.to_dict())
metadata=self.elements[index].metadata.to_dict(),
id=id)

def __len__(self):
return len(self.elements)
96 changes: 96 additions & 0 deletions clarifai_datautils/multimodal/pipeline/summarizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import base64
import os
import random
from typing import List

try:
from unstructured.documents.elements import CompositeElement, ElementMetadata, Image
except ImportError:
raise ImportError(
"Could not import unstructured package. "
"Please install it with `pip install 'unstructured[pdf] @ git+https://github.com/clarifai/unstructured.git@support_clarifai_model'`."
)

from clarifai.client.input import Inputs
from clarifai.client.model import Model

from .basetransform import BaseTransform


class ImageSummarizer(BaseTransform):
""" Summarizes image elements. """

def __init__(self, model_url="https://clarifai.com/qwen/qwen-VL/models/qwen-VL-Chat"):
"""Initializes an LlamaIndexWrapper object.
Args:
model_url (str): Model URL to use for summarization.
"""
self.model_url = model_url
self.model = Model(url=model_url, pat=os.environ.get("CLARIFAI_PAT"))
self.summary_prompt = """You are an assistant tasked with summarizing images for retrieval. \
These summaries will be embedded and used to retrieve the raw image. \
Give a concise summary of the image that is well optimized for retrieval."""

def __call__(self, elements: List) -> List:
"""Applies the transformation.
Args:
elements (List[str]): List of text elements.
Returns:
List of transformed text elements.
"""
img_elements = []
for _, element in enumerate(elements):
element.metadata.update(
ElementMetadata.from_dict({
'is_original': True,
'input_id': f'{random.randint(1000000, 99999999)}'
}))
if isinstance(element, Image):
img_elements.append(element)
# new_elements = Parallel(n_jobs=len(elements))(delayed(self._summarize_image)(element) for element in img_elements)
new_elements = self._summarize_image(elements)
elements.extend(new_elements)
return elements

def _summarize_image(self, image_elements: List[Image]) -> List[CompositeElement]:
"""Summarizes an image element.
Args:
image_element (Image): Image element to summarize.
Returns:
Summarized image element.
"""
img_inputs = []
for element in image_elements:
if not isinstance(element, Image):
continue
new_input_id = "summarize_" + element.metadata.input_id
input_proto = Inputs.get_multimodal_input(
input_id=new_input_id,
image_bytes=base64.b64decode(element.metadata.image_base64),
raw_text=self.summary_prompt)
img_inputs.append(input_proto)
resp = self.model.predict(img_inputs)

new_elements = []
for i, element in enumerate(resp.outputs):
summary = ""
if image_elements[i].text:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are we sure that the index of image in image_elements is same as in resp.outputs?

Copy link
Contributor

@sanjaychelliah sanjaychelliah Nov 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe image elements will not have text, so why this check here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I observed that some image elements had text too... it can be seen in the output of 9th cell in this notebook

summary = image_elements[i].text
summary = summary + " \n " + element.data.text.raw
eid = image_elements[i].metadata.input_id
meta_dict = {'source_input_id': eid, 'is_original': False}
comp_element = CompositeElement(
text=summary,
metadata=ElementMetadata.from_dict(meta_dict),
element_id="summarized_" + eid)
new_elements.append(comp_element)

return new_elements
391 changes: 391 additions & 0 deletions testing/test.ipynb

Large diffs are not rendered by default.

29 changes: 29 additions & 0 deletions tests/pipelines/test_multimodal_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,32 @@ def test_pipeline_run_loader(self,):
assert elements.__class__.__name__ == 'MultiModalLoader'
assert len(elements) == 14
assert elements.elements[0].metadata.to_dict()['filename'] == 'Multimodal_sample_file.pdf'

def test_pipeline_summarize(self,):
"""Tests for pipeline run with summarizer"""
from clarifai_datautils.multimodal import Pipeline
from clarifai_datautils.multimodal.pipeline.cleaners import Clean_extra_whitespace
from clarifai_datautils.multimodal.pipeline.PDF import PDFPartitionMultimodal
from clarifai_datautils.multimodal.pipeline.summarizer import ImageSummarizer

pipeline = Pipeline(
name='pipeline-1',
transformations=[
PDFPartitionMultimodal(chunking_strategy="by_title", max_characters=1024),
Clean_extra_whitespace(),
ImageSummarizer()
])
elements = pipeline.run(files=PDF_FILE_PATH, loader=False)
assert len(elements) == 15
assert isinstance(elements, list)
assert elements[0].metadata.to_dict()['filename'] == 'Multimodal_sample_file.pdf'
assert elements[0].metadata.to_dict()['page_number'] == 1
assert elements[0].metadata.to_dict()['email_address'] == ['[email protected]']
assert elements[6].__class__.__name__ == 'Table'
assert elements[-2].__class__.__name__ == 'Image'
assert elements[-2].metadata.is_original is True
assert elements[-2].metadata.input_id is not None
id = elements[-2].metadata.input_id
assert elements[-1].__class__.__name__ == 'CompositeElement'
assert elements[-1].metadata.is_original is False
assert elements[-1].metadata.source_input_id == 'summarized_' + id
Loading