Skip to content

Commit 876a848

Browse files
authored
Added batch predict for Imagesummarizer (#36)
* Added batch predict for Imagesummarizer * Addressed lint * Added input_id for elements
1 parent 50cecde commit 876a848

File tree

2 files changed

+48
-35
lines changed

2 files changed

+48
-35
lines changed

clarifai_datautils/multimodal/pipeline/loaders.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import base64
2+
import uuid
23

34
from clarifai_datautils.constants.base import DATASET_UPLOAD_TASKS
45

@@ -27,7 +28,10 @@ def __getitem__(self, index: int):
2728
meta.pop('coordinates', None)
2829
meta.pop('detection_class_prob', None)
2930
image_data = meta.pop('image_base64', None)
30-
id = meta.get('input_id', None)
31+
try:
32+
id = self.elements[index].element_id[:8]
33+
except (IndexError, AttributeError, TypeError):
34+
id = str(uuid.uuid4())[:8]
3135
if image_data is not None:
3236
# Ensure image_data is already bytes before encoding
3337
image_data = base64.b64decode(image_data)
@@ -39,7 +43,8 @@ def __getitem__(self, index: int):
3943
if self.elements[index].to_dict()['type'] == 'Table':
4044
meta['type'] = 'table'
4145

42-
return MultiModalFeatures(text=text, image_bytes=image_data, metadata=meta, id=id)
46+
return MultiModalFeatures(
47+
text=text, image_bytes=image_data, labels=[self.pipeline_name], metadata=meta, id=id)
4348

4449
def __len__(self):
4550
return len(self.elements)
@@ -64,7 +69,10 @@ def __getitem__(self, index: int):
6469
id = self.elements[index].to_dict().get('element_id', None)
6570
id = id[:48] if id is not None else None
6671
return TextFeatures(
67-
text=self.elements[index].text, metadata=self.elements[index].metadata.to_dict(), id=id)
72+
text=self.elements[index].text,
73+
labels=self.pipeline_name,
74+
metadata=self.elements[index].metadata.to_dict(),
75+
id=id)
6876

6977
def __len__(self):
7078
return len(self.elements)

clarifai_datautils/multimodal/pipeline/summarizer.py

Lines changed: 37 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@
1515

1616
from .basetransform import BaseTransform
1717

18-
SUMMARY_PROMPT = """You are an assistant tasked with summarizing images for retrieval. \
19-
These summaries will be embedded and used to retrieve the raw image. \
20-
Give a concise summary of the image that is well optimized for retrieval."""
18+
SUMMARY_PROMPT = """You are an assistant tasked with summarizing images for retrieval.
19+
These summaries will be embedded and used to retrieve the raw image.
20+
Give a concise summary of the image that is well optimized for retrieval.
21+
Also add relevant keywords that can be used for search. """
2122

2223

2324
class ImageSummarizer(BaseTransform):
@@ -26,7 +27,8 @@ class ImageSummarizer(BaseTransform):
2627
def __init__(self,
2728
model_url: str = "https://clarifai.com/qwen/qwen-VL/models/qwen-VL-Chat",
2829
pat: str = None,
29-
prompt: str = SUMMARY_PROMPT):
30+
prompt: str = SUMMARY_PROMPT,
31+
batch_size: int = 4):
3032
"""Initializes an ImageSummarizer object.
3133
3234
Args:
@@ -38,6 +40,7 @@ def __init__(self,
3840
self.model_url = model_url
3941
self.model = Model(url=model_url, pat=pat)
4042
self.summary_prompt = prompt
43+
self.batch_size = batch_size
4144

4245
def __call__(self, elements: List) -> List:
4346
"""Applies the transformation.
@@ -72,31 +75,33 @@ def _summarize_image(self, image_elements: List[Image]) -> List[CompositeElement
7275
Summarized image elements list.
7376
7477
"""
75-
img_inputs = []
76-
for element in image_elements:
77-
if not isinstance(element, Image):
78-
continue
79-
new_input_id = "summarize_" + element.metadata.input_id
80-
input_proto = Inputs.get_multimodal_input(
81-
input_id=new_input_id,
82-
image_bytes=base64.b64decode(element.metadata.image_base64),
83-
raw_text=self.summary_prompt)
84-
img_inputs.append(input_proto)
85-
resp = self.model.predict(img_inputs)
86-
del img_inputs
87-
88-
new_elements = []
89-
for i, output in enumerate(resp.outputs):
90-
summary = ""
91-
if image_elements[i].text:
92-
summary = image_elements[i].text
93-
summary = summary + " \n " + output.data.text.raw
94-
eid = image_elements[i].metadata.input_id
95-
meta_dict = {'source_input_id': eid, 'is_original': False}
96-
comp_element = CompositeElement(
97-
text=summary,
98-
metadata=ElementMetadata.from_dict(meta_dict),
99-
element_id="summarized_" + eid)
100-
new_elements.append(comp_element)
101-
102-
return new_elements
78+
image_summary = []
79+
try:
80+
for i in range(0, len(image_elements), self.batch_size):
81+
batch = image_elements[i:i + self.batch_size]
82+
83+
input_proto = [
84+
Inputs.get_multimodal_input(
85+
input_id=batch[id].metadata.input_id,
86+
image_bytes=base64.b64decode(batch[id].metadata.image_base64),
87+
raw_text=self.summary_prompt) for id in range(len(batch))
88+
if isinstance(batch[id], Image)
89+
]
90+
resp = self.model.predict(input_proto)
91+
for i, output in enumerate(resp.outputs):
92+
summary = ""
93+
if image_elements[i].text:
94+
summary = image_elements[i].text
95+
summary = summary + " \n " + output.data.text.raw
96+
eid = batch[i].metadata.input_id
97+
meta_dict = {'source_input_id': eid, 'is_original': False, 'image_summary': 'yes'}
98+
comp_element = CompositeElement(
99+
text=summary,
100+
metadata=ElementMetadata.from_dict(meta_dict),
101+
element_id="summarized_" + eid)
102+
image_summary.append(comp_element)
103+
104+
except Exception as e:
105+
raise e
106+
107+
return image_summary

0 commit comments

Comments
 (0)