Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions clarifai_datautils/multimodal/pipeline/loaders.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import base64
import uuid

from clarifai_datautils.constants.base import DATASET_UPLOAD_TASKS

Expand Down Expand Up @@ -27,7 +28,10 @@ def __getitem__(self, index: int):
meta.pop('coordinates', None)
meta.pop('detection_class_prob', None)
image_data = meta.pop('image_base64', None)
id = meta.get('input_id', None)
try:
id = self.elements[index].element_id[:8]
except (IndexError, AttributeError, TypeError):
id = str(uuid.uuid4())[:8]
if image_data is not None:
# Ensure image_data is already bytes before encoding
image_data = base64.b64decode(image_data)
Expand All @@ -39,7 +43,8 @@ def __getitem__(self, index: int):
if self.elements[index].to_dict()['type'] == 'Table':
meta['type'] = 'table'

return MultiModalFeatures(text=text, image_bytes=image_data, metadata=meta, id=id)
return MultiModalFeatures(
text=text, image_bytes=image_data, labels=[self.pipeline_name], metadata=meta, id=id)

def __len__(self):
return len(self.elements)
Expand All @@ -64,7 +69,10 @@ def __getitem__(self, index: int):
id = self.elements[index].to_dict().get('element_id', None)
id = id[:48] if id is not None else None
return TextFeatures(
text=self.elements[index].text, metadata=self.elements[index].metadata.to_dict(), id=id)
text=self.elements[index].text,
labels=self.pipeline_name,
metadata=self.elements[index].metadata.to_dict(),
id=id)

def __len__(self):
return len(self.elements)
69 changes: 37 additions & 32 deletions clarifai_datautils/multimodal/pipeline/summarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@

from .basetransform import BaseTransform

SUMMARY_PROMPT = """You are an assistant tasked with summarizing images for retrieval. \
These summaries will be embedded and used to retrieve the raw image. \
Give a concise summary of the image that is well optimized for retrieval."""
SUMMARY_PROMPT = """You are an assistant tasked with summarizing images for retrieval.
These summaries will be embedded and used to retrieve the raw image.
Give a concise summary of the image that is well optimized for retrieval.
Also add relevant keywords that can be used for search. """


class ImageSummarizer(BaseTransform):
Expand All @@ -26,7 +27,8 @@ class ImageSummarizer(BaseTransform):
def __init__(self,
model_url: str = "https://clarifai.com/qwen/qwen-VL/models/qwen-VL-Chat",
pat: str = None,
prompt: str = SUMMARY_PROMPT):
prompt: str = SUMMARY_PROMPT,
batch_size: int = 4):
"""Initializes an ImageSummarizer object.

Args:
Expand All @@ -38,6 +40,7 @@ def __init__(self,
self.model_url = model_url
self.model = Model(url=model_url, pat=pat)
self.summary_prompt = prompt
self.batch_size = batch_size

def __call__(self, elements: List) -> List:
"""Applies the transformation.
Expand Down Expand Up @@ -72,31 +75,33 @@ def _summarize_image(self, image_elements: List[Image]) -> List[CompositeElement
Summarized image elements list.

"""
img_inputs = []
for element in image_elements:
if not isinstance(element, Image):
continue
new_input_id = "summarize_" + element.metadata.input_id
input_proto = Inputs.get_multimodal_input(
input_id=new_input_id,
image_bytes=base64.b64decode(element.metadata.image_base64),
raw_text=self.summary_prompt)
img_inputs.append(input_proto)
resp = self.model.predict(img_inputs)
del img_inputs

new_elements = []
for i, output in enumerate(resp.outputs):
summary = ""
if image_elements[i].text:
summary = image_elements[i].text
summary = summary + " \n " + output.data.text.raw
eid = image_elements[i].metadata.input_id
meta_dict = {'source_input_id': eid, 'is_original': False}
comp_element = CompositeElement(
text=summary,
metadata=ElementMetadata.from_dict(meta_dict),
element_id="summarized_" + eid)
new_elements.append(comp_element)

return new_elements
image_summary = []
try:
for i in range(0, len(image_elements), self.batch_size):
batch = image_elements[i:i + self.batch_size]

input_proto = [
Inputs.get_multimodal_input(
input_id=batch[id].metadata.input_id,
image_bytes=base64.b64decode(batch[id].metadata.image_base64),
raw_text=self.summary_prompt) for id in range(len(batch))
if isinstance(batch[id], Image)
]
resp = self.model.predict(input_proto)
for i, output in enumerate(resp.outputs):
summary = ""
if image_elements[i].text:
summary = image_elements[i].text
summary = summary + " \n " + output.data.text.raw
eid = batch[i].metadata.input_id
meta_dict = {'source_input_id': eid, 'is_original': False, 'image_summary': 'yes'}
comp_element = CompositeElement(
text=summary,
metadata=ElementMetadata.from_dict(meta_dict),
element_id="summarized_" + eid)
image_summary.append(comp_element)

except Exception as e:
raise e

return image_summary
Loading