Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 37 additions & 32 deletions clarifai_datautils/multimodal/pipeline/summarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@

from .basetransform import BaseTransform

SUMMARY_PROMPT = """You are an assistant tasked with summarizing images for retrieval. \
These summaries will be embedded and used to retrieve the raw image. \
Give a concise summary of the image that is well optimized for retrieval."""
SUMMARY_PROMPT = """You are an assistant tasked with summarizing images for retrieval.
These summaries will be embedded and used to retrieve the raw image.
Give a concise summary of the image that is well optimized for retrieval.
Also add relevant keywords that can be used for search. """


class ImageSummarizer(BaseTransform):
Expand All @@ -26,7 +27,8 @@ class ImageSummarizer(BaseTransform):
def __init__(self,
model_url: str = "https://clarifai.com/qwen/qwen-VL/models/qwen-VL-Chat",
pat: str = None,
prompt: str = SUMMARY_PROMPT):
prompt: str = SUMMARY_PROMPT,
batch_size: int = 4):
"""Initializes an ImageSummarizer object.

Args:
Expand All @@ -38,6 +40,7 @@ def __init__(self,
self.model_url = model_url
self.model = Model(url=model_url, pat=pat)
self.summary_prompt = prompt
self.batch_size = batch_size

def __call__(self, elements: List) -> List:
"""Applies the transformation.
Expand Down Expand Up @@ -72,31 +75,33 @@ def _summarize_image(self, image_elements: List[Image]) -> List[CompositeElement
Summarized image elements list.

"""
img_inputs = []
for element in image_elements:
if not isinstance(element, Image):
continue
new_input_id = "summarize_" + element.metadata.input_id
input_proto = Inputs.get_multimodal_input(
input_id=new_input_id,
image_bytes=base64.b64decode(element.metadata.image_base64),
raw_text=self.summary_prompt)
img_inputs.append(input_proto)
resp = self.model.predict(img_inputs)
del img_inputs

new_elements = []
for i, output in enumerate(resp.outputs):
summary = ""
if image_elements[i].text:
summary = image_elements[i].text
summary = summary + " \n " + output.data.text.raw
eid = image_elements[i].metadata.input_id
meta_dict = {'source_input_id': eid, 'is_original': False}
comp_element = CompositeElement(
text=summary,
metadata=ElementMetadata.from_dict(meta_dict),
element_id="summarized_" + eid)
new_elements.append(comp_element)

return new_elements
image_summary = []
try:
for i in range(0, len(image_elements), self.batch_size):
batch = image_elements[i:i + self.batch_size]

input_proto = [
Inputs.get_multimodal_input(
input_id=batch[id].metadata.input_id,
image_bytes=base64.b64decode(batch[id].metadata.image_base64),
raw_text=self.summary_prompt) for id in range(len(batch))
if isinstance(batch[id], Image)
]
resp = self.model.predict(input_proto)
for i, output in enumerate(resp.outputs):
summary = ""
if image_elements[i].text:
summary = image_elements[i].text
summary = summary + " \n " + output.data.text.raw
eid = batch[i].metadata.input_id
meta_dict = {'source_input_id': eid, 'is_original': False, 'image_summary': 'yes'}
comp_element = CompositeElement(
text=summary,
metadata=ElementMetadata.from_dict(meta_dict),
element_id="summarized_" + eid)
image_summary.append(comp_element)

except Exception as e:
raise e

return image_summary
Loading