11import base64
2- import os
32import random
43from typing import List
54
1615
1716from .basetransform import BaseTransform
1817
# Default instruction sent alongside each image when requesting a summary.
# Built from adjacent literals so the value is one single-line string.
SUMMARY_PROMPT = (
    "You are an assistant tasked with summarizing images for retrieval. "
    "These summaries will be embedded and used to retrieve the raw image. "
    "Give a concise summary of the image that is well optimized for retrieval."
)
21+
1922
2023class ImageSummarizer (BaseTransform ):
2124 """ Summarizes image elements. """
2225
def __init__(self,
             pat: str,
             model_url: str = "https://clarifai.com/qwen/qwen-VL/models/qwen-VL-Chat",
             prompt: str = SUMMARY_PROMPT):
    """Initializes an ImageSummarizer object.

    Args:
      pat (str): Clarifai PAT, forwarded to the ``Model`` constructor.
      model_url (str): Model URL to use for summarization.
      prompt (str): Prompt to use for summarization. Defaults to the
        module-level ``SUMMARY_PROMPT``.
    """
    self.model_url = model_url
    # The PAT is passed in explicitly by the caller; it is not read from the
    # environment here.
    self.model = Model(url=model_url, pat=pat)
    self.summary_prompt = prompt
3540
def __call__(self, elements: List) -> List:
    """Applies the transformation.

    Every element is tagged with ``is_original=True`` metadata. Each ``Image``
    element additionally gets a random numeric ``input_id`` (stored as a
    string) and is collected for summarization.

    Args:
      elements (List): List of all elements. The list is extended in place.

    Returns:
      The same list object, with the elements returned by
      ``_summarize_image`` appended after the original ones.
    """
    img_elements = []
    for element in elements:
        element.metadata.update(ElementMetadata.from_dict({'is_original': True}))
        if isinstance(element, Image):
            element.metadata.update(
                ElementMetadata.from_dict({
                    'input_id': f'{random.randint(1000000, 99999999)}'
                }))
            img_elements.append(element)

    # Nothing to summarize: skip the model call entirely rather than sending
    # an empty batch to the prediction endpoint.
    if not img_elements:
        return elements

    elements.extend(self._summarize_image(img_elements))
    return elements
@@ -64,13 +65,13 @@ def _summarize_image(self, image_elements: List[Image]) -> List[CompositeElement
6465 """Summarizes an image element.
6566
6667 Args:
67- image_element ( Image): Image element to summarize.
68+ image_elements (List[ Image] ): Image elements to summarize.
6869
6970 Returns:
70- Summarized image element .
71+ Summarized image elements list .
7172
7273 """
73- img_inputs = []
74+ img_inputs = []
7475 for element in image_elements :
7576 if not isinstance (element , Image ):
7677 continue
@@ -81,6 +82,7 @@ def _summarize_image(self, image_elements: List[Image]) -> List[CompositeElement
8182 raw_text = self .summary_prompt )
8283 img_inputs .append (input_proto )
8384 resp = self .model .predict (img_inputs )
85+ del img_inputs
8486
8587 new_elements = []
8688 for i , output in enumerate (resp .outputs ):
0 commit comments