1+ import json
12import logging
23import os
34import random
2627 HumanMessage ,
2728 MessageLog ,
2829 MessageLogToEmpty ,
30+ SQLExecutionError ,
2931 TrainLevel ,
3032 messages_log ,
3133 strip_sql_code_block ,
3234 timing ,
3335)
3436from camel_database_agent .database_prompt import (
3537 DATABASE_SUMMARY_OUTPUT_EXAMPLE ,
38+ QUESTION_CONVERT_SQL ,
3639)
3740from camel_database_agent .datagen .sql_query_inference_pipeline import (
3841 DataQueryInferencePipeline ,
4548logger = logging .getLogger (__name__ )
4649
4750
class QuestionMeta(BaseModel):
    """Outcome of converting a natural-language question to SQL.

    Bundles the generated SQL together with the exact prompt that
    produced it, so callers can audit or debug the conversion.
    """

    # Natural-language question as supplied by the user.
    question: str
    # SQL statement the agent generated for the question.
    sql: str
    # Full prompt text that was sent to the agent (kept for debugging/auditing).
    prompt: str
55+
56+
4857class DatabaseAgentResponse (BaseModel ):
4958 ask : str
5059 dataset : Optional [Any ] = None
@@ -167,20 +176,41 @@ def __init__(
167176 self .data_sql = f .read ()
168177
@timing
def _parse_schema_to_knowledge(self, polish: bool = False) -> None:
    """Parse the database schema into DDL knowledge records.

    Fetches the raw schema from the dialect (optionally replacing it with
    an LLM-polished version), archives the raw schema, the effective
    schema, and the parsed DDL records under ``self.knowledge_path`` for
    later inspection, then feeds the records into the knowledge backend.

    Args:
        polish: When True, use ``get_polished_schema`` output as the
            effective DDL; otherwise use the raw schema as-is.
    """
    # Fetch the raw schema exactly once: it is always archived, and when
    # polish is False it also serves as the effective DDL (the original
    # code called get_schema() twice in that path).
    raw_schema = self.dialect.get_schema()
    self.ddl_sql = (
        self.dialect.get_polished_schema(self.language) if polish else raw_schema
    )

    # Archive the unmodified schema for debugging/reproducibility.
    with open(
        os.path.join(self.knowledge_path, "ddl_origin.sql"),
        "w",
        encoding="utf-8",
    ) as f:
        f.write(raw_schema)

    # Persist the effective (possibly polished) schema.
    with open(
        os.path.join(self.knowledge_path, "ddl_sql.sql"),
        "w",
        encoding="utf-8",
    ) as f:
        f.write(self.ddl_sql)

    ddl_records: List[DDLRecord] = self.schema_parse.parse_ddl_record(self.ddl_sql)
    # Dump the parsed records as JSON so the knowledge-base contents can be
    # reviewed without querying the backend store.
    with open(
        os.path.join(self.knowledge_path, "ddl_records.json"),
        "w",
        encoding="utf-8",
    ) as f:
        f.write(
            json.dumps(
                [record.model_dump() for record in ddl_records],
                ensure_ascii=False,
                indent=4,
            )
        )

    self.database_knowledge_backend.add(ddl_records)
185215
186216 @timing
@@ -194,6 +224,18 @@ def _parse_sampled_data_to_knowledge(self, data_samples_size: int = 5) -> None:
194224 ) as f :
195225 f .write (self .data_sql )
196226 dml_records : List [DMLRecord ] = self .schema_parse .parse_dml_record (self .data_sql )
227+
228+ with open (
229+ os .path .join (self .knowledge_path , "data_records.json" ),
230+ "w" ,
231+ encoding = "utf-8" ,
232+ ) as f :
233+ f .write (
234+ json .dumps (
235+ [record .model_dump () for record in dml_records ], ensure_ascii = False , indent = 4
236+ )
237+ )
238+
197239 self .database_knowledge_backend .add (dml_records )
198240
199241 @timing
@@ -210,6 +252,15 @@ def _parse_query_to_knowledge(self, query_samples_size: int = 20) -> None:
210252 query_records : List [QueryRecord ] = []
211253 while len (query_records ) < query_samples_size :
212254 query_records .extend (pipeline .generate (query_samples_size = query_samples_size ))
255+
256+ with open (
257+ os .path .join (self .knowledge_path , "question_sql.txt" ),
258+ "w" ,
259+ encoding = "utf-8" ,
260+ ) as f :
261+ for query_record in query_records :
262+ f .write (f"QUESTION: { query_record .question } \n SQL: { query_record .sql } \n \n " )
263+
213264 self .database_knowledge_backend .add (query_records )
214265 else :
215266 raise ValueError ("ddl_sql and data_sql must be provided" )
@@ -292,6 +343,10 @@ def train_knowledge(
292343
293344 if reset_train and os .path .exists (self .knowledge_path ):
294345 self .database_knowledge_backend .clear ()
346+ self .ddl_sql = None
347+ self .data_sql = None
348+ self .database_summary = ""
349+ self .recommendation_question = ""
295350 logger .info ("Reset knowledge..." )
296351
297352 if (
@@ -319,41 +374,36 @@ def train_knowledge(
319374 self .generate_database_summary (query_samples_size = query_samples_size )
320375
@timing
def question_to_sql(self, question: str, dialect_name: str) -> QuestionMeta:
    """Convert a natural-language question into a SQL statement.

    Retrieves question-relevant schema, sample data, and few-shot Q/A
    pairs from the knowledge backend, fills them into the
    ``QUESTION_CONVERT_SQL`` template, and asks the agent to produce the
    SQL. Returns a :class:`QuestionMeta` carrying the question, the
    generated SQL, and the exact prompt used.
    """
    # Pull the retrieval context for this question up front.
    ddl_records: List[DDLRecord] = self.database_knowledge_backend.query_ddl(question)
    data_records: List[DMLRecord] = self.database_knowledge_backend.query_data(question)
    query_records: List[QueryRecord] = self.database_knowledge_backend.query_query(question)

    # Template substitutions, applied in the same order as the original
    # chain of str.replace calls.
    substitutions = [
        ("{{dialect_name}}", dialect_name),
        ("{{table_schema}}", "\n".join([record.sql for record in ddl_records])),
        ("{{sample_data}}", "\n".join([record.sql for record in data_records])),
        (
            "{{qa_pairs}}",
            "\n".join(
                [
                    f"QUESTION: {record.question}\nSQL: {record.sql}\n\n"
                    for record in query_records
                ]
            ),
        ),
        ("{{question}}", question),
    ]
    prompt = QUESTION_CONVERT_SQL
    for placeholder, value in substitutions:
        prompt = prompt.replace(placeholder, value)

    logger.debug(Fore.GREEN + "PROMPT:" + prompt)
    self.agent.reset()
    response = self.agent.step(prompt)

    return QuestionMeta(
        question=question,
        sql=strip_sql_code_block(response.msgs[0].content),
        prompt=prompt,
    )
357407
358408 @messages_log
359409 def ask (
@@ -366,27 +416,37 @@ def ask(
366416 if not message_log :
367417 message_log = MessageLogToEmpty ()
368418 message_log .messages_writer (HumanMessage (session_id = session_id , content = question ))
369- sql = self .question_to_sql (
419+ question_meta = self .question_to_sql (
370420 question = question ,
371421 dialect_name = self .database_manager .dialect_name (),
372422 )
373- message_log .messages_writer (AssistantMessage (session_id = session_id , content = sql ))
374423 try :
375- dataset = self .database_manager .select (sql = sql , bind_pd = bind_pd )
424+ message_log .messages_writer (
425+ AssistantMessage (session_id = session_id , content = question_meta .sql )
426+ )
427+ dataset = self .database_manager .select (sql = question_meta .sql , bind_pd = bind_pd )
376428 message_log .messages_writer (
377429 AssistantMessage (
378430 session_id = session_id ,
379431 content = tabulate (dataset , headers = "keys" , tablefmt = "psql" ),
380432 )
381433 )
382- return DatabaseAgentResponse (ask = question , dataset = dataset , sql = sql )
434+ return DatabaseAgentResponse (ask = question , dataset = dataset , sql = question_meta .sql )
435+ except SQLExecutionError as e :
436+ message_log .messages_writer (AssistantMessage (session_id = session_id , content = str (e )))
437+ return DatabaseAgentResponse (
438+ ask = question ,
439+ dataset = None ,
440+ sql = e .sql ,
441+ success = False ,
442+ error = e .error_message ,
443+ )
383444 except Exception as e :
384- logger .error (e )
385445 message_log .messages_writer (AssistantMessage (session_id = session_id , content = str (e )))
386446 return DatabaseAgentResponse (
387447 ask = question ,
388448 dataset = None ,
389- sql = sql ,
449+ sql = question_meta . sql ,
390450 success = False ,
391451 error = str (e ),
392452 )
0 commit comments