|
2 | 2 | import logging |
3 | 3 | from io import BytesIO |
4 | 4 |
|
5 | | -from jinja2 import Template |
| 5 | +from jinja2 import Template, StrictUndefined |
6 | 6 | from pydantic import Field |
7 | 7 | from smolagents.tools import Tool |
8 | 8 |
|
9 | 9 | from ..models.openai_vlm import OpenAIVLModel |
10 | 10 | from ..utils.observer import MessageObserver, ProcessType |
| 11 | +from ..utils.prompt_template_utils import get_prompt_template |
11 | 12 | from ..utils.tools_common_message import ToolCategory, ToolSign |
12 | 13 | from ... import MinIOStorageClient |
13 | 14 | from ...multi_modal.load_save_object import LoadSaveObjectManager |
@@ -50,21 +51,16 @@ def __init__( |
50 | 51 | super().__init__() |
51 | 52 | self.observer = observer |
52 | 53 | self.vlm_model = vlm_model |
53 | | - # Use provided storage_client or create a default one |
54 | | - # if storage_client is None: |
55 | | - # storage_client = create_storage_client_from_config() |
56 | 54 | self.storage_client = storage_client |
57 | 55 | self.system_prompt_template = system_prompt_template |
58 | | - |
59 | | - |
60 | 56 | # Create LoadSaveObjectManager with the storage client |
61 | 57 | self.mm = LoadSaveObjectManager(storage_client=self.storage_client) |
62 | 58 |
|
63 | 59 | # Dynamically apply the load_object decorator to forward method |
64 | 60 | self.forward = self.mm.load_object(input_names=["image_url"])(self._forward_impl) |
65 | 61 |
|
66 | | - self.running_prompt_zh = "正在分析图片文字..." |
67 | | - self.running_prompt_en = "Analyzing image text..." |
| 62 | + self.running_prompt_zh = "正在理解图片..." |
| 63 | + self.running_prompt_en = "Understanding image..." |
68 | 64 |
|
69 | 65 | def _forward_impl(self, image_url: bytes, query: str) -> str: |
70 | 66 | """ |
@@ -92,15 +88,20 @@ def _forward_impl(self, image_url: bytes, query: str) -> str: |
92 | 88 | card_content = [{"icon": "image", "text": "Processing image..."}] |
93 | 89 | self.observer.add_message("", ProcessType.CARD, json.dumps(card_content, ensure_ascii=False)) |
94 | 90 |
|
95 | | - # # Load messages based on language |
96 | | - # messages = get_file_processing_messages_template(language) |
| 91 | + # Load prompts from yaml file |
| 92 | + prompts = get_prompt_template(template_type='understand_image',language = self.observer.lang) |
97 | 93 |
|
98 | 94 | try: |
99 | | - text = self.vlm_model.analyze_image( |
| 95 | + |
| 96 | + response = self.vlm_model.analyze_image( |
100 | 97 | image_input=image_stream, |
101 | | - system_prompt=self.system_prompt_template.render({'query': query})).content |
102 | | - return text |
103 | | - # return messages["IMAGE_CONTENT_SUCCESS"].format(filename=filename, content=text) |
| 98 | + system_prompt=Template(prompts['system_prompt'],undefined=StrictUndefined).render({'query': query})) |
104 | 99 | except Exception as e: |
105 | | - raise e |
106 | | - |
| 100 | + raise Exception(f"Error understanding image: {str(e)}") |
| 101 | + text = response.content |
| 102 | + # Record the detailed content of this search |
| 103 | + search_results_data = {'text':text} |
| 104 | + if self.observer: |
| 105 | + search_results_data = json.dumps(search_results_data, ensure_ascii=False) |
| 106 | + self.observer.add_message("", ProcessType.SEARCH_CONTENT, search_results_data) |
| 107 | + return json.dumps(search_results_data, ensure_ascii=False) |
0 commit comments