11import json
22import logging
3- from typing import List , Optional
3+ from typing import Optional
44
5- import httpx
65from pydantic import Field
76from smolagents .tools import Tool
87
8+ from ...datamate import DataMateClient
99from ..utils .observer import MessageObserver , ProcessType
1010from ..utils .tools_common_message import SearchResultTextMessage , ToolCategory , ToolSign
1111
12-
1312# Get logger instance
1413logger = logging .getLogger ("datamate_search_tool" )
1514
@@ -61,10 +60,10 @@ class DataMateSearchTool(Tool):
6160 tool_sign = ToolSign .DATAMATE_KNOWLEDGE_BASE .value
6261
6362 def __init__ (
64- self ,
65- server_ip : str = Field (description = "DataMate server IP or hostname" ),
66- server_port : int = Field (description = "DataMate server port" ),
67- observer : MessageObserver = Field (description = "Message observer" , default = None , exclude = True ),
63+ self ,
64+ server_ip : str = Field (description = "DataMate server IP or hostname" ),
65+ server_port : int = Field (description = "DataMate server port" ),
66+ observer : MessageObserver = Field (description = "Message observer" , default = None , exclude = True ),
6867 ):
6968 """Initialize the DataMateSearchTool.
7069
@@ -88,6 +87,9 @@ def __init__(
8887 # Build base URL: http://host:port
8988 self .server_base_url = f"http://{ self .server_ip } :{ self .server_port } " .rstrip ("/" )
9089
90+ # Initialize DataMate SDK client
91+ self .datamate_client = DataMateClient (base_url = self .server_base_url )
92+
9193 self .kb_page = 0
9294 self .kb_page_size = 20
9395 self .observer = observer
@@ -97,12 +99,12 @@ def __init__(
9799 self .running_prompt_en = "Searching the DataMate knowledge base..."
98100
99101 def forward (
100- self ,
101- query : str ,
102- top_k : int = 10 ,
103- threshold : float = 0.2 ,
104- kb_page : int = 0 ,
105- kb_page_size : int = 20 ,
102+ self ,
103+ query : str ,
104+ top_k : int = 10 ,
105+ threshold : float = 0.2 ,
106+ kb_page : int = 0 ,
107+ kb_page_size : int = 20 ,
106108 ) -> str :
107109 """Execute DataMate search.
108110
@@ -130,17 +132,37 @@ def forward(
130132 )
131133
132134 try :
133- # Step 1: Get knowledge base list
134- knowledge_base_ids = self ._get_knowledge_base_list ()
135+ # Step 1: Get knowledge base list using SDK
136+ knowledge_bases = self .datamate_client .list_knowledge_bases (
137+ page = self .kb_page ,
138+ size = self .kb_page_size
139+ )
140+
141+ # Extract knowledge base IDs
142+ knowledge_base_ids = []
143+ for kb in knowledge_bases :
144+ kb_id = kb .get ("id" )
145+ chunk_count = kb .get ("chunkCount" )
146+ if kb_id and chunk_count :
147+ knowledge_base_ids .append (str (kb_id ))
148+
135149 if not knowledge_base_ids :
136150 return json .dumps ("No knowledge base found. No relevant information found." , ensure_ascii = False )
137151
138- # Step 2: Retrieve knowledge base content
139- kb_search_results = self . _retrieve_knowledge_base_content ( query , knowledge_base_ids , top_k , threshold
140- )
152+ # Step 2: Retrieve knowledge base content using SDK
153+ kb_search_results = []
154+ for knowledge_base_id in knowledge_base_ids :
141155
142- if not kb_search_results :
143- raise Exception ("No results found! Try a less restrictive/shorter query." )
156+ kb_search = self .datamate_client .retrieve_knowledge_base (
157+ query = query ,
158+ knowledge_base_ids = [knowledge_base_id ],
159+ top_k = top_k ,
160+ threshold = threshold
161+ )
162+
163+ if not kb_search :
164+ raise Exception ("No results found! Try a less restrictive/shorter query." )
165+ kb_search_results .extend (kb_search )
144166
145167 # Format search results
146168 search_results_json = [] # Organize search results into a unified format
@@ -150,8 +172,8 @@ def forward(
150172 entity_data = single_search_result .get ("entity" , {})
151173 metadata = self ._parse_metadata (entity_data .get ("metadata" ))
152174 dataset_id = self ._extract_dataset_id (metadata .get ("absolute_directory_path" , "" ))
153- file_id = entity_data .get ("id " )
154- download_url = self ._build_file_download_url (dataset_id , file_id )
175+ file_id = metadata .get ("original_file_id " )
176+ download_url = self .datamate_client . build_file_download_url (dataset_id , file_id )
155177
156178 score_details = entity_data .get ("scoreDetails" , {}) or {}
157179 score_details .update ({
@@ -191,100 +213,6 @@ def forward(
191213 logger .error (error_msg )
192214 raise Exception (error_msg )
193215
194- def _get_knowledge_base_list (self ) -> List [str ]:
195- """Get knowledge base list from DataMate API.
196-
197- Returns:
198- List[str]: List of knowledge base IDs.
199- """
200- try :
201- url = f"{ self .server_base_url } /api/knowledge-base/list"
202- payload = {"page" : self .kb_page , "size" : self .kb_page_size }
203-
204- with httpx .Client (timeout = 30 ) as client :
205- response = client .post (url , json = payload )
206-
207- if response .status_code != 200 :
208- error_detail = (
209- response .json ().get ("detail" , "unknown error" )
210- if response .headers .get ("content-type" , "" ).startswith ("application/json" )
211- else response .text
212- )
213- raise Exception (f"Failed to get knowledge base list (status { response .status_code } ): { error_detail } " )
214-
215- result = response .json ()
216- # Extract knowledge base IDs from response
217- # Assuming the response structure contains a list of knowledge bases with 'id' field
218- data = result .get ("data" , {})
219- knowledge_bases = data .get ("content" , []) if data else []
220-
221- knowledge_base_ids = []
222- for kb in knowledge_bases :
223- kb_id = kb .get ("id" )
224- chunk_count = kb .get ("chunkCount" )
225- if kb_id and chunk_count :
226- knowledge_base_ids .append (str (kb_id ))
227-
228- logger .info (f"Retrieved { len (knowledge_base_ids )} knowledge base(s): { knowledge_base_ids } " )
229- return knowledge_base_ids
230-
231- except httpx .TimeoutException :
232- raise Exception ("Timeout while getting knowledge base list from DataMate API" )
233- except httpx .RequestError as e :
234- raise Exception (f"Request error while getting knowledge base list: { str (e )} " )
235- except Exception as e :
236- raise Exception (f"Error getting knowledge base list: { str (e )} " )
237-
238- def _retrieve_knowledge_base_content (
239- self , query : str , knowledge_base_ids : List [str ], top_k : int , threshold : float
240- ) -> List [dict ]:
241- """Retrieve knowledge base content from DataMate API.
242-
243- Args:
244- query (str): Search query.
245- knowledge_base_ids (List[str]): List of knowledge base IDs to search.
246- top_k (int): Maximum number of results to return.
247- threshold (float): Similarity threshold.
248-
249- Returns:
250- List[dict]: List of search results.
251- """
252- search_results = []
253- for knowledge_base_id in knowledge_base_ids :
254- try :
255- url = f"{ self .server_base_url } /api/knowledge-base/retrieve"
256- payload = {
257- "query" : query ,
258- "topK" : top_k ,
259- "threshold" : threshold ,
260- "knowledgeBaseIds" : [knowledge_base_id ],
261- }
262-
263- with httpx .Client (timeout = 60 ) as client :
264- response = client .post (url , json = payload )
265-
266- if response .status_code != 200 :
267- error_detail = (
268- response .json ().get ("detail" , "unknown error" )
269- if response .headers .get ("content-type" , "" ).startswith ("application/json" )
270- else response .text
271- )
272- raise Exception (
273- f"Failed to retrieve knowledge base content (status { response .status_code } ): { error_detail } " )
274-
275- result = response .json ()
276- # Extract search results from response
277- for data in result .get ("data" , {}):
278- search_results .append (data )
279- except httpx .TimeoutException :
280- raise Exception ("Timeout while retrieving knowledge base content from DataMate API" )
281- except httpx .RequestError as e :
282- raise Exception (f"Request error while retrieving knowledge base content: { str (e )} " )
283- except Exception as e :
284- raise Exception (f"Error retrieving knowledge base content: { str (e )} " )
285- logger .info (f"Retrieved { len (search_results )} search result(s)" )
286- return search_results
287-
288216 @staticmethod
289217 def _parse_metadata (metadata_raw : Optional [str ]) -> dict :
290218 """Parse metadata payload safely."""
@@ -304,10 +232,4 @@ def _extract_dataset_id(absolute_path: str) -> str:
304232 if not absolute_path :
305233 return ""
306234 segments = [segment for segment in absolute_path .strip ("/" ).split ("/" ) if segment ]
307- return segments [- 1 ] if segments else ""
308-
309- def _build_file_download_url (self , dataset_id : str , file_id : str ) -> str :
310- """Build the download URL for a dataset file."""
311- if not (self .server_ip and dataset_id and file_id ):
312- return ""
313- return f"{ self .server_ip } /api/data-management/datasets/{ dataset_id } /files/{ file_id } /download"
235+ return segments [- 1 ] if segments else ""
0 commit comments