1515
1616from .basetransform import BaseTransform
1717
18- SUMMARY_PROMPT = """You are an assistant tasked with summarizing images for retrieval. \
19- These summaries will be embedded and used to retrieve the raw image. \
20- Give a concise summary of the image that is well optimized for retrieval."""
18+ SUMMARY_PROMPT = """You are an assistant tasked with summarizing images for retrieval.
19+ These summaries will be embedded and used to retrieve the raw image.
20+ Give a concise summary of the image that is well optimized for retrieval.
21+ Also add relevant keywords that can be used for search. """
2122
2223
2324class ImageSummarizer (BaseTransform ):
@@ -26,7 +27,8 @@ class ImageSummarizer(BaseTransform):
2627 def __init__ (self ,
2728 model_url : str = "https://clarifai.com/qwen/qwen-VL/models/qwen-VL-Chat" ,
2829 pat : str = None ,
29- prompt : str = SUMMARY_PROMPT ):
30+ prompt : str = SUMMARY_PROMPT ,
31+ batch_size : int = 4 ):
3032 """Initializes an ImageSummarizer object.
3133
3234 Args:
@@ -38,6 +40,7 @@ def __init__(self,
3840 self .model_url = model_url
3941 self .model = Model (url = model_url , pat = pat )
4042 self .summary_prompt = prompt
43+ self .batch_size = batch_size
4144
4245 def __call__ (self , elements : List ) -> List :
4346 """Applies the transformation.
@@ -72,31 +75,33 @@ def _summarize_image(self, image_elements: List[Image]) -> List[CompositeElement
7275 Summarized image elements list.
7376
7477 """
75- img_inputs = []
76- for element in image_elements :
77- if not isinstance (element , Image ):
78- continue
79- new_input_id = "summarize_" + element .metadata .input_id
80- input_proto = Inputs .get_multimodal_input (
81- input_id = new_input_id ,
82- image_bytes = base64 .b64decode (element .metadata .image_base64 ),
83- raw_text = self .summary_prompt )
84- img_inputs .append (input_proto )
85- resp = self .model .predict (img_inputs )
86- del img_inputs
87-
88- new_elements = []
89- for i , output in enumerate (resp .outputs ):
90- summary = ""
91- if image_elements [i ].text :
92- summary = image_elements [i ].text
93- summary = summary + " \n " + output .data .text .raw
94- eid = image_elements [i ].metadata .input_id
95- meta_dict = {'source_input_id' : eid , 'is_original' : False }
96- comp_element = CompositeElement (
97- text = summary ,
98- metadata = ElementMetadata .from_dict (meta_dict ),
99- element_id = "summarized_" + eid )
100- new_elements .append (comp_element )
101-
102- return new_elements
78+ image_summary = []
79+ try :
80+ for i in range (0 , len (image_elements ), self .batch_size ):
81+ batch = image_elements [i :i + self .batch_size ]
82+
83+ input_proto = [
84+ Inputs .get_multimodal_input (
85+ input_id = batch [id ].metadata .input_id ,
86+ image_bytes = base64 .b64decode (batch [id ].metadata .image_base64 ),
87+ raw_text = self .summary_prompt ) for id in range (len (batch ))
88+ if isinstance (batch [id ], Image )
89+ ]
90+ resp = self .model .predict (input_proto )
91+ for i , output in enumerate (resp .outputs ):
92+ summary = ""
93+ if image_elements [i ].text :
94+ summary = image_elements [i ].text
95+ summary = summary + " \n " + output .data .text .raw
96+ eid = batch [i ].metadata .input_id
97+ meta_dict = {'source_input_id' : eid , 'is_original' : False , 'image_summary' : 'yes' }
98+ comp_element = CompositeElement (
99+ text = summary ,
100+ metadata = ElementMetadata .from_dict (meta_dict ),
101+ element_id = "summarized_" + eid )
102+ image_summary .append (comp_element )
103+
104+ except Exception as e :
105+ raise e
106+
107+ return image_summary
0 commit comments