Skip to content

Commit c0cbfb4

Browse files
committed
addressed comments
1 parent 09c977c commit c0cbfb4

File tree

2 files changed

+27
-26
lines changed

2 files changed

+27
-26
lines changed

clarifai_datautils/multimodal/pipeline/summarizer.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import base64
2-
import os
32
import random
43
from typing import List
54

@@ -16,46 +15,48 @@
1615

1716
from .basetransform import BaseTransform
1817

18+
SUMMARY_PROMPT = """You are an assistant tasked with summarizing images for retrieval. \
19+
These summaries will be embedded and used to retrieve the raw image. \
20+
Give a concise summary of the image that is well optimized for retrieval."""
21+
1922

2023
class ImageSummarizer(BaseTransform):
2124
""" Summarizes image elements. """
2225

23-
def __init__(self, model_url="https://clarifai.com/qwen/qwen-VL/models/qwen-VL-Chat"):
24-
"""Initializes an LlamaIndexWrapper object.
26+
def __init__(self,
27+
pat,
28+
model_url="https://clarifai.com/qwen/qwen-VL/models/qwen-VL-Chat",
29+
prompt=SUMMARY_PROMPT):
30+
"""Initializes an ImageSummarizer object.
2531
2632
Args:
33+
pat (str): Clarifai PAT.
2734
model_url (str): Model URL to use for summarization.
28-
35+
prompt (str): Prompt to use for summarization.
2936
"""
3037
self.model_url = model_url
31-
self.model = Model(url=model_url, pat=os.environ.get("CLARIFAI_PAT"))
32-
self.summary_prompt = """You are an assistant tasked with summarizing images for retrieval. \
33-
These summaries will be embedded and used to retrieve the raw image. \
34-
Give a concise summary of the image that is well optimized for retrieval."""
38+
self.model = Model(url=model_url, pat=pat)
39+
self.summary_prompt = prompt
3540

3641
def __call__(self, elements: List) -> List:
3742
"""Applies the transformation.
3843
3944
Args:
40-
elements (List[str]): List of text elements.
45+
elements (List): List of all elements; Image elements among them are summarized.
4146
4247
Returns:
43-
List of transformed text elements.
48+
List of transformed elements along with added summarized elements.
4449
4550
"""
4651
img_elements = []
4752
for _, element in enumerate(elements):
48-
element.metadata.update(
49-
ElementMetadata.from_dict({
50-
'is_original': True
51-
}))
53+
element.metadata.update(ElementMetadata.from_dict({'is_original': True}))
5254
if isinstance(element, Image):
5355
element.metadata.update(
54-
ElementMetadata.from_dict({
55-
'input_id': f'{random.randint(1000000, 99999999)}'
56-
}))
56+
ElementMetadata.from_dict({
57+
'input_id': f'{random.randint(1000000, 99999999)}'
58+
}))
5759
img_elements.append(element)
58-
# new_elements = Parallel(n_jobs=len(elements))(delayed(self._summarize_image)(element) for element in img_elements)
5960
new_elements = self._summarize_image(img_elements)
6061
elements.extend(new_elements)
6162
return elements
@@ -64,13 +65,13 @@ def _summarize_image(self, image_elements: List[Image]) -> List[CompositeElement
6465
"""Summarizes a list of image elements.
6566
6667
Args:
67-
image_element (Image): Image element to summarize.
68+
image_elements (List[Image]): Image elements to summarize.
6869
6970
Returns:
70-
Summarized image element.
71+
Summarized image elements list.
7172
7273
"""
73-
img_inputs = []
74+
img_inputs = []
7475
for element in image_elements:
7576
if not isinstance(element, Image):
7677
continue
@@ -81,6 +82,7 @@ def _summarize_image(self, image_elements: List[Image]) -> List[CompositeElement
8182
raw_text=self.summary_prompt)
8283
img_inputs.append(input_proto)
8384
resp = self.model.predict(img_inputs)
85+
del img_inputs
8486

8587
new_elements = []
8688
for i, output in enumerate(resp.outputs):

tests/pipelines/test_multimodal_pipelines.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
import os.path as osp
22

3-
import pytest
4-
53
PDF_FILE_PATH = osp.abspath(
64
osp.join(osp.dirname(__file__), "assets", "Multimodal_sample_file.pdf"))
75

86

9-
@pytest.mark.skip(reason="Need additional build dependencies")
7+
# @pytest.mark.skip(reason="Need additional build dependencies")
108
class TestMultimodalPipelines:
119
"""Tests for pipeline transformations."""
1210

@@ -69,18 +67,19 @@ def test_pipeline_run_loader(self,):
6967

7068
def test_pipeline_summarize(self,):
7169
"""Tests for pipeline run with summarizer"""
70+
import os
71+
7272
from clarifai_datautils.multimodal import Pipeline
7373
from clarifai_datautils.multimodal.pipeline.cleaners import Clean_extra_whitespace
7474
from clarifai_datautils.multimodal.pipeline.PDF import PDFPartitionMultimodal
7575
from clarifai_datautils.multimodal.pipeline.summarizer import ImageSummarizer
76-
import os
7776

7877
pipeline = Pipeline(
7978
name='pipeline-1',
8079
transformations=[
8180
PDFPartitionMultimodal(chunking_strategy="by_title", max_characters=1024),
8281
Clean_extra_whitespace(),
83-
ImageSummarizer()
82+
ImageSummarizer(pat=os.environ.get("CLARIFAI_PAT"))
8483
])
8584
elements = pipeline.run(files=PDF_FILE_PATH, loader=False)
8685

0 commit comments

Comments
 (0)