Skip to content

Commit be133cd

Browse files
committed
✨ image to text tool
1 parent 906ca05 commit be133cd

File tree

8 files changed

+71
-26
lines changed

8 files changed

+71
-26
lines changed

backend/agents/create_agent_info.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ async def create_tool_config_list(agent_id, tenant_id, user_id):
240240
"vdb_core": get_vector_db_core(),
241241
"embedding_model": get_embedding_model(tenant_id=tenant_id),
242242
}
243-
elif tool_config.class_name == "ImageUnderstandingTool":
243+
elif tool_config.class_name == "AnalyzeImageTool":
244244
tool_config.metadata = {
245245
"vlm_model": get_vlm_model(tenant_id=tenant_id),
246246
"storage_client": minio_client,

backend/services/tool_configuration_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -618,7 +618,7 @@ def _validate_local_tool(
618618
'embedding_model': embedding_model,
619619
}
620620
tool_instance = tool_class(**params)
621-
elif tool_name == "image_understanding":
621+
elif tool_name == "analyze_image":
622622
if not tenant_id or not user_id:
623623
raise ToolExecutionException(f"Tenant ID and User ID are required for {tool_name} validation")
624624
image_to_text_model = get_vlm_model(tenant_id=tenant_id)

sdk/nexent/core/agents/nexent_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def create_local_tool(self, tool_config: ToolConfig):
7171
vdb_core=tool_config.metadata.get("vdb_core", []),
7272
embedding_model=tool_config.metadata.get("embedding_model", []),
7373
**params)
74-
elif class_name == "ImageUnderstandingTool":
74+
elif class_name == "AnalyzeImageTool":
7575
tools_obj = tool_class(observer=self.observer,
7676
vlm_model=tool_config.metadata.get("vlm_model", []),
7777
storage_client=tool_config.metadata.get("storage_client", []),
File renamed without changes.
File renamed without changes.

sdk/nexent/core/tools/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from .move_item_tool import MoveItemTool
1313
from .list_directory_tool import ListDirectoryTool
1414
from .terminal_tool import TerminalTool
15-
from .image_understanding_tool import ImageUnderstandingTool
15+
from .analyze_image_tool import AnalyzeImageTool
1616

1717
__all__ = [
1818
"ExaSearchTool",
@@ -29,5 +29,5 @@
2929
"MoveItemTool",
3030
"ListDirectoryTool",
3131
"TerminalTool",
32-
"ImageUnderstandingTool"
32+
"AnalyzeImageTool"
3333
]

sdk/nexent/core/tools/image_understanding_tool.py renamed to sdk/nexent/core/tools/analyze_image_tool.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,17 @@
1313
from ... import MinIOStorageClient
1414
from ...multi_modal.load_save_object import LoadSaveObjectManager
1515

16-
logger = logging.getLogger("image_understanding_tool")
16+
logger = logging.getLogger("analyze_image_tool")
1717

1818

19-
class ImageUnderstandingTool(Tool):
20-
"""Tool for extracting text from images stored in S3-compatible storage."""
19+
class AnalyzeImageTool(Tool):
20+
"""Tool for understanding and analyzing image"""
2121

22-
name = "image_understanding"
22+
name = "analyze_image"
2323
description = (
24-
"Understand an image stored in S3-compatible storage or HTTP and return the text content inside the image. "
25-
"Provide the object location via an s3:// URL or http:// URL or https:// URL."
24+
"This tool uses a visual language model to understand images based on your query and then returns a description of the image."
25+
"It's used to understand and analyze images stored in S3 buckets, via HTTP and HTTPS."
26+
"Use this tool when you want to retrieve information contained in an image and provide the image's URL and your query."
2627
)
2728
inputs = {
2829
"image_url": {
@@ -45,32 +46,29 @@ def __init__(
4546
observer: MessageObserver = Field(description="Message observer", default=None, exclude=True),
4647
vlm_model: OpenAIVLModel = Field(description="The VLM model to use", default=None, exclude=True),
4748
storage_client: MinIOStorageClient = Field(description="Storage client to use", default=None, exclude=True),
48-
# todo 这么写对不对
49-
system_prompt_template: Template = Field(description="System prompt template to use", default=None, exclude=True),
5049
):
5150
super().__init__()
5251
self.observer = observer
5352
self.vlm_model = vlm_model
5453
self.storage_client = storage_client
55-
self.system_prompt_template = system_prompt_template
5654
# Create LoadSaveObjectManager with the storage client
5755
self.mm = LoadSaveObjectManager(storage_client=self.storage_client)
5856

5957
# Dynamically apply the load_object decorator to forward method
6058
self.forward = self.mm.load_object(input_names=["image_url"])(self._forward_impl)
6159

62-
self.running_prompt_zh = "正在理解图片..."
63-
self.running_prompt_en = "Understanding image..."
60+
self.running_prompt_zh = "正在分析图片..."
61+
self.running_prompt_en = "Analyzing image..."
6462

6563
def _forward_impl(self, image_url: bytes, query: str) -> str:
6664
"""
67-
Analyze the image specified by the S3 URL and return recognized text.
65+
Analyze images of S3 URL, HTTP URL, or HTTPS URL and return the identified text.
6866
6967
Note: This method is wrapped by load_object decorator which downloads
70-
the image from S3 URL and passes bytes to this method.
68+
the image from S3 URL, HTTP URL, or HTTPS URL and passes bytes to this method.
7169
7270
Args:
73-
image_url: Image bytes (converted from S3 URL by decorator).
71+
image_url: Image bytes (converted from S3 URL, HTTP URL, or HTTPS URL by decorator).
7472
7573
Returns:
7674
JSON string containing the recognized text.
@@ -85,23 +83,21 @@ def _forward_impl(self, image_url: bytes, query: str) -> str:
8583
if self.observer:
8684
running_prompt = self.running_prompt_zh if self.observer.lang == "zh" else self.running_prompt_en
8785
self.observer.add_message("", ProcessType.TOOL, running_prompt)
88-
card_content = [{"icon": "image", "text": "Processing image..."}]
86+
card_content = [{"icon": "image", "text": "Analyzing image..."}]
8987
self.observer.add_message("", ProcessType.CARD, json.dumps(card_content, ensure_ascii=False))
9088

9189
# Load prompts from yaml file
92-
prompts = get_prompt_template(template_type='understand_image',language = self.observer.lang)
90+
prompts = get_prompt_template(template_type='analyze_image', language=self.observer.lang)
9391

9492
try:
9593

9694
response = self.vlm_model.analyze_image(
9795
image_input=image_stream,
98-
system_prompt=Template(prompts['system_prompt'],undefined=StrictUndefined).render({'query': query}))
96+
system_prompt=Template(prompts['system_prompt'], undefined=StrictUndefined).render({'query': query}))
9997
except Exception as e:
10098
raise Exception(f"Error understanding image: {str(e)}")
10199
text = response.content
102100
# Record the detailed content of this search
103-
search_results_data = {'text':text}
104-
if self.observer:
105-
search_results_data = json.dumps(search_results_data, ensure_ascii=False)
106-
self.observer.add_message("", ProcessType.SEARCH_CONTENT, search_results_data)
101+
# TODO: decide what structure this tool should return
102+
search_results_data = {'text': text}
107103
return json.dumps(search_results_data, ensure_ascii=False)
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import logging
2+
import os
3+
from typing import Dict, Any
4+
5+
import yaml
6+
7+
from consts.const import LANGUAGE
8+
9+
logger = logging.getLogger("prompt_template_utils")
10+
11+
# Define template path mapping
12+
template_paths = {
13+
'analyze_image': {
14+
LANGUAGE["ZH"]: 'core/prompts/analyze_image.yaml',
15+
LANGUAGE["EN"]: 'core/prompts/analyze_image_en.yaml'
16+
}
17+
}
18+
19+
def get_prompt_template(template_type: str, language: str = LANGUAGE["ZH"], **kwargs) -> Dict[str, Any]:
    """
    Load the YAML prompt template registered for a template type and language.

    Args:
        template_type: Key into the module-level ``template_paths`` mapping
            (currently only 'analyze_image' is registered).
        language: Language code selecting which variant of the template to
            load; must be one of the languages registered for ``template_type``.
        **kwargs: Accepted for call-site compatibility. They are only logged
            and do not influence which template is loaded.

    Returns:
        dict: Content of the template file as parsed by ``yaml.safe_load``.

    Raises:
        ValueError: If ``template_type`` or ``language`` is not registered in
            ``template_paths``.
        OSError: If the resolved template file cannot be opened.
    """
    # Lazy %-style arguments: the message is only built if INFO is enabled.
    logger.info(
        "Getting prompt template for type: %s, language: %s, kwargs: %s",
        template_type, language, kwargs)

    if template_type not in template_paths:
        raise ValueError(f"Unsupported template type: {template_type}")

    language_paths = template_paths[template_type]
    if language not in language_paths:
        # Report an unknown language the same way as an unknown template type
        # instead of leaking a bare KeyError that carries no context.
        raise ValueError(
            f"Unsupported language '{language}' for template type: {template_type}")

    template_path = language_paths[language]

    # Registered paths are written relative to the parent of the core
    # directory ('core/...'). Strip only that leading prefix so a later
    # 'core/' segment inside the path is left intact.
    core_prefix = 'core/'
    if template_path.startswith(core_prefix):
        template_path = template_path[len(core_prefix):]

    # This module lives one level below core (in utils); go up one directory
    # to reach core and resolve the template against it.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    core_dir = os.path.dirname(current_dir)
    absolute_template_path = os.path.join(core_dir, template_path)

    # Read and return template content
    with open(absolute_template_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)

0 commit comments

Comments
 (0)