11from langchain_community .utilities import SQLDatabase
2- # from langchain_community.agent_toolkits import create_sql_agent
3- from langchain_community .llms import Tongyi
2+ from langgraph .prebuilt import create_react_agent
43from langchain_core .prompts import ChatPromptTemplate
54from apps .chat .schemas .chat_base_schema import LLMConfig , LLMFactory
5+ from apps .datasource .models .datasource import CoreDatasource
6+ from apps .db .db import exec_sql , get_uri
67from common .core .config import settings
78import warnings
9+ from langchain .tools import Tool
10+ from functools import partial
11+ import logging
12+ from typing import AsyncGenerator
13+ import json
14+ import asyncio
815
916warnings .filterwarnings ("ignore" )
1017
@@ -31,3 +38,154 @@ def generate_sql(self, question: str) -> str:
3138 schema = self .db .get_table_info ()
3239 return chain .invoke ({"schema" : schema , "question" : question })
3340
41+
class AgentService:
    """Data-analysis agent that answers questions by running SQL on one datasource.

    The service connects to the datasource, builds an LLM through
    ``LLMFactory``, exposes a single ``execute_sql`` tool to a ReAct agent and
    keeps a prompt template that is filled with the table schema and the
    user's question.

    Attributes are set in ``__init__``:
        ds: The datasource the service was created for.
        db: ``SQLDatabase`` opened from the datasource URI.
        llm: The LLM object taken from the factory result.
        agent_executor: Agent returned by ``create_react_agent``.
        prompt: Chat prompt with ``{schema}`` and ``{question}`` placeholders.
    """

    # Seconds to pause after each processed stream chunk.
    # NOTE(review): presumably paces the stream / yields to the event loop —
    # confirm the intended value with the consumer of the stream.
    STREAM_THROTTLE_SECONDS = 0.1

    # Markers delimiting the HTML report inside the agent's final answer.
    _HTML_FENCE_OPEN = "```html"
    _FENCE_CLOSE = "```"

    def __init__(self, config: LLMConfig, ds: CoreDatasource):
        """Connect to the datasource and assemble the agent and its prompt.

        Args:
            config: LLM configuration handed to ``LLMFactory.create_llm``.
            ds: Datasource whose URI is used to open the database connection.
        """
        # Initialize database connection
        self.ds = ds
        self.db = SQLDatabase.from_uri(get_uri(ds))

        # Create LLM instance through factory
        self.llm = LLMFactory.create_llm(config).llm

        # The tool receives only the SQL text, so the connection is pre-bound.
        bound_execute_sql = partial(execute_sql_with_db, self.db)

        # Wrap as Tool object
        tools = [
            Tool(
                name="execute_sql",
                func=bound_execute_sql,
                description="""A tool for executing SQL queries.
                Input: SQL query statement (string)
                Output: Query results
                Example: "SELECT * FROM table_name LIMIT 5"
                """,
            )
        ]

        self.agent_executor = create_react_agent(self.llm, tools)

        system_prompt = """
        You are an intelligent agent capable of data analysis. When users input their data analysis requirements,
        you need to first convert the requirements into executable SQL, then execute the SQL through tools to return results,
        and finally summarize the SQL query results. When all tasks are completed, you need to generate an HTML format data analysis report.

        You can analyze requirements step by step to determine the final SQL query to generate.
        To improve SQL generation accuracy, please evaluate the accuracy of the SQL after generation,
        if there are issues, regenerate the SQL.
        When SQL execution fails, you need to correct the SQL based on the error message and try to execute again.

        ### Tools ###
        execute_sql: Can execute SQL by passing in SQL statements and return execution results
        """
        user_prompt = """
        Below is the database information I need to query:
        {schema}

        My requirement is: {question}
        """
        # Define prompt template
        self.prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            ("human", user_prompt),
        ])

    def generate_sql(self, question: str) -> str:
        """Run the agent synchronously for ``question``.

        Args:
            question: The user's data-analysis requirement.

        Returns:
            Whatever the prompt-plus-agent chain's ``invoke`` returns.
            NOTE(review): annotated as ``str``; confirm the actual type
            produced by the agent chain.
        """
        chain = self.prompt | self.agent_executor
        schema = self.db.get_table_info()
        return chain.invoke({"schema": schema, "question": question})

    async def async_generate(self, question: str) -> AsyncGenerator[str, None]:
        """Stream the agent run for ``question`` as Server-Sent Events frames.

        Each yielded string is one ``data: <json>`` frame terminated by a
        blank line. The JSON ``type`` field is one of:

        * ``tool_call``   - the agent requested a tool (``tool``, ``args``)
        * ``final``       - agent text (``content``, plus ``html`` when the
          text contains a fenced HTML report)
        * ``tool_result`` - output of a tool (``tool``, ``content``)
        * ``complete``    - emitted once after the stream ends

        Args:
            question: The user's data-analysis requirement.

        Yields:
            Encoded event frames, in the order the chain produces them.
        """
        chain = self.prompt | self.agent_executor
        schema = self.db.get_table_info()

        async for chunk in chain.astream({"schema": schema, "question": question}):
            # Only dict chunks (keyed by "agent" / "tools") are understood.
            if not isinstance(chunk, dict):
                continue

            if "agent" in chunk:
                for msg in chunk["agent"].get("messages", []):
                    for tool_call in msg.additional_kwargs.get("tool_calls") or []:
                        function = tool_call["function"]
                        yield self._sse({
                            "type": "tool_call",
                            "tool": function["name"],
                            "args": function["arguments"],
                        })

                    if msg.content:
                        yield self._sse(self._final_payload(msg.content))

            if "tools" in chunk:
                for msg in chunk["tools"].get("messages", []):
                    yield self._sse({
                        "type": "tool_result",
                        "tool": msg.name,
                        "content": msg.content,
                    })

            await asyncio.sleep(self.STREAM_THROTTLE_SECONDS)

        yield self._sse({"type": "complete"})

    @staticmethod
    def _sse(payload: dict) -> str:
        """Encode ``payload`` as one event frame: ``data: <json>`` + blank line.

        The frame must end with exactly two newlines; any other character on
        the terminating line keeps a client from dispatching the event.
        """
        return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"

    @classmethod
    def _final_payload(cls, content) -> dict:
        """Build the ``final`` event, splitting out a fenced HTML report if any.

        When ``content`` holds a complete fenced ``html`` block, the text
        before the block becomes ``content`` and the block body becomes
        ``html``; otherwise ``content`` is passed through unchanged.
        """
        # Non-string content cannot contain a fence; forward it untouched.
        if isinstance(content, str):
            fence_at = content.find(cls._HTML_FENCE_OPEN)
            if fence_at != -1:
                body_at = fence_at + len(cls._HTML_FENCE_OPEN)
                close_at = content.find(cls._FENCE_CLOSE, body_at)
                if close_at != -1:
                    return {
                        "type": "final",
                        "content": content[:fence_at].strip(),
                        "html": content[body_at:close_at].strip(),
                    }
        return {"type": "final", "content": content}
154+
def execute_sql(ds: CoreDatasource, sql: str) -> str:
    """Execute a SQL query against a datasource via ``exec_sql``.

    Args:
        ds: Data source instance; its ``id`` is included in the log line.
        sql: SQL query statement.

    Returns:
        The value returned by ``exec_sql(ds, sql)``.
        NOTE(review): annotated as ``str``; confirm what ``exec_sql`` returns.
    """
    # Report through logging (as the sibling helper does) instead of stdout.
    logging.info("Executing SQL on ds_id %s: %s", ds.id, sql)
    return exec_sql(ds, sql)
167+
def execute_sql_with_db(db: SQLDatabase, sql: str) -> str:
    """Execute a SQL query using a ``SQLDatabase`` and return text output.

    Args:
        db: SQLDatabase instance; only its ``run`` method is used.
        sql: SQL query statement.

    Returns:
        str: ``str()`` of the query result, or a fixed notice when the
        result is empty/falsy.

    Raises:
        RuntimeError: If ``db.run`` raises. The original exception is
            attached as ``__cause__`` so the underlying failure stays
            visible to callers and in tracebacks.
    """
    try:
        # Keep the guarded region to the one call that can fail.
        result = db.run(sql)
    except Exception as exc:
        error_msg = f"SQL execution failed: {exc}"
        # Logs at ERROR level and includes the traceback.
        logging.exception(error_msg)
        raise RuntimeError(error_msg) from exc

    if not result:
        return "Query executed successfully but returned no results."

    # Format results
    return str(result)
0 commit comments