Skip to content

Commit 022a5f9

Browse files
committed
Enhance database agent with schema and query logging features
- Introduced the QuestionMeta model to encapsulate the question, the generated SQL, and the prompt.
- Updated _parse_schema_to_knowledge and _parse_sampled_data_to_knowledge to save the DDL and DML records in JSON format.
- Modified question_to_sql to return a QuestionMeta instead of the raw SQL string.
- Improved error handling in the database manager: select now raises SQLExecutionError instead of returning it.
- Refactored prompt generation for question inference to use predefined templates.
1 parent 1229948 commit 022a5f9

File tree

15 files changed

+7632
-325
lines changed

15 files changed

+7632
-325
lines changed

README.md

Lines changed: 4 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ cd camel-database-agent
2929
pip install uv ruff mypy
3030
uv venv .venv --python=3.10
3131
source .venv/bin/activate
32-
uv pip install -e ".[dev,test]"
32+
uv sync --all-extras
3333
````
3434

3535
#### Music Database
@@ -141,34 +141,8 @@ Run the Spider 2.0-Lite evaluation.
141141

142142
```shell
143143
cd spider2_lite
144-
export API_KEY=sk-xx
144+
export OPENAI_API_KEY=sk-xxx
145+
export OPENAI_API_BASE_URL=https://api.openai.com/v1/
146+
export MODEL_NAME=gpt-4o-mini
145147
python spider2_run.py
146-
```
147-
148-
## Development
149-
150-
Install the development dependencies.
151-
```shell
152-
pip install uv ruff mypy
153-
uv pip install -e ".[dev]"
154-
```
155-
156-
Run code formatters
157-
```shell
158-
make format
159-
```
160-
161-
Run code linters
162-
```shell
163-
make lint
164-
```
165-
166-
Run unit tests
167-
```shell
168-
make test
169-
```
170-
171-
Create a uv.lock file from pyproject.toml
172-
```shell
173-
uv pip compile pyproject.toml -o uv.lock --resolution=highest
174148
```

camel_database_agent/database/database_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def __init__(self, db_url: str, read_only_model: bool = True):
6161
@with_session
6262
def select(
6363
self, session: Session, sql: str, bind_pd: bool = False
64-
) -> Union[List[dict], pd.DataFrame, SQLExecutionError]:
64+
) -> Union[List[dict], pd.DataFrame]:
6565
"""Execute Query SQL"""
6666
self._check_sql(sql)
6767
try:
@@ -74,7 +74,7 @@ def select(
7474
rows = [dict(zip(column_names, row)) for row in result]
7575
return rows
7676
except OperationalError as e:
77-
return SQLExecutionError(sql, str(e))
77+
raise SQLExecutionError(sql, str(e))
7878

7979
@with_session
8080
def execute(

camel_database_agent/database/database_schema_parse.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,12 @@ def __init__(
5555
def parse_ddl_record(self, text: str) -> List[DDLRecord]:
5656
"""Parsing DDL SQL statements"""
5757
prompt = (
58-
"Here are some DDL statements from which you need to "
59-
"refer to table names, field names, data types, default "
60-
"values, etc., to generate summary information and extract "
61-
"the SQL statements for each table.\n\n"
58+
"The following are some DDL script. Please read the script in its "
59+
"entirety and provide descriptions for the tables and fields to "
60+
"generate summary information and extract the SQL script for each "
61+
"table.\n\n"
6262
)
63-
prompt += f"```sql\n{text}```\n"
63+
prompt += f"```sql\n{text}```\n\n"
6464
prompt += "Please output the summary information and SQL script in JSON format."
6565
response = self.parsing_agent.step(prompt, response_format=DDLRecordResponseFormat)
6666
ddl_record_response = DDLRecordResponseFormat.model_validate_json(response.msgs[0].content)

camel_database_agent/database/dialect/database_schema_dialect.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from camel.models import BaseModelBackend
77

88
from camel_database_agent.database.database_manager import DatabaseManager
9+
from camel_database_agent.database_prompt import POLISH_SCHEMA_OUTPUT_EXAMPLE
910

1011
logger = logging.getLogger(__name__)
1112

@@ -50,16 +51,9 @@ def get_dialect(
5051

5152
def get_polished_schema(self, language: str = "English") -> str:
5253
if self.schema_polish_agent:
53-
prompt = (
54-
f"Please optimize the SQL schema of the database in {language}, "
55-
f"ensuring it includes table name comments, field comments, "
56-
f"foreign key explanations, etc., to make it more readable.\n\n"
57-
)
58-
prompt += f"```sql\n{self.get_schema()}```\n\n"
59-
prompt += (
60-
"Now, please directly output the optimized SQL Schema. "
61-
"Do not explain the process and optimization ideas."
62-
)
54+
prompt = POLISH_SCHEMA_OUTPUT_EXAMPLE.replace(
55+
"{{ddl_sql}}", self.get_schema()
56+
).replace("{{language}}", language)
6357
response = self.schema_polish_agent.step(prompt)
6458
return response.msgs[0].content
6559
else:

camel_database_agent/database_agent.py

Lines changed: 93 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
import logging
23
import os
34
import random
@@ -26,13 +27,15 @@
2627
HumanMessage,
2728
MessageLog,
2829
MessageLogToEmpty,
30+
SQLExecutionError,
2931
TrainLevel,
3032
messages_log,
3133
strip_sql_code_block,
3234
timing,
3335
)
3436
from camel_database_agent.database_prompt import (
3537
DATABASE_SUMMARY_OUTPUT_EXAMPLE,
38+
QUESTION_CONVERT_SQL,
3639
)
3740
from camel_database_agent.datagen.sql_query_inference_pipeline import (
3841
DataQueryInferencePipeline,
@@ -45,6 +48,12 @@
4548
logger = logging.getLogger(__name__)
4649

4750

51+
class QuestionMeta(BaseModel):
52+
question: str
53+
sql: str
54+
prompt: str
55+
56+
4857
class DatabaseAgentResponse(BaseModel):
4958
ask: str
5059
dataset: Optional[Any] = None
@@ -167,20 +176,41 @@ def __init__(
167176
self.data_sql = f.read()
168177

169178
@timing
170-
def _parse_schema_to_knowledge(self, polish: bool = True) -> None:
179+
def _parse_schema_to_knowledge(self, polish: bool = False) -> None:
171180
"""Generate schema data to knowledge"""
172181
self.ddl_sql = (
173182
self.dialect.get_polished_schema(self.language)
174183
if polish
175184
else self.dialect.get_schema()
176185
)
186+
# Save the schema to a file
187+
with open(
188+
os.path.join(self.knowledge_path, "ddl_origin.sql"),
189+
"w",
190+
encoding="utf-8",
191+
) as f:
192+
f.write(self.dialect.get_schema())
193+
194+
# Save the schema used for knowledge parsing to a file (polished only when polish=True)
177195
with open(
178196
os.path.join(self.knowledge_path, "ddl_sql.sql"),
179197
"w",
180198
encoding="utf-8",
181199
) as f:
182200
f.write(self.ddl_sql)
201+
183202
ddl_records: List[DDLRecord] = self.schema_parse.parse_ddl_record(self.ddl_sql)
203+
with open(
204+
os.path.join(self.knowledge_path, "ddl_records.json"),
205+
"w",
206+
encoding="utf-8",
207+
) as f:
208+
f.write(
209+
json.dumps(
210+
[record.model_dump() for record in ddl_records], ensure_ascii=False, indent=4
211+
)
212+
)
213+
184214
self.database_knowledge_backend.add(ddl_records)
185215

186216
@timing
@@ -194,6 +224,18 @@ def _parse_sampled_data_to_knowledge(self, data_samples_size: int = 5) -> None:
194224
) as f:
195225
f.write(self.data_sql)
196226
dml_records: List[DMLRecord] = self.schema_parse.parse_dml_record(self.data_sql)
227+
228+
with open(
229+
os.path.join(self.knowledge_path, "data_records.json"),
230+
"w",
231+
encoding="utf-8",
232+
) as f:
233+
f.write(
234+
json.dumps(
235+
[record.model_dump() for record in dml_records], ensure_ascii=False, indent=4
236+
)
237+
)
238+
197239
self.database_knowledge_backend.add(dml_records)
198240

199241
@timing
@@ -210,6 +252,15 @@ def _parse_query_to_knowledge(self, query_samples_size: int = 20) -> None:
210252
query_records: List[QueryRecord] = []
211253
while len(query_records) < query_samples_size:
212254
query_records.extend(pipeline.generate(query_samples_size=query_samples_size))
255+
256+
with open(
257+
os.path.join(self.knowledge_path, "question_sql.txt"),
258+
"w",
259+
encoding="utf-8",
260+
) as f:
261+
for query_record in query_records:
262+
f.write(f"QUESTION: {query_record.question}\nSQL: {query_record.sql}\n\n")
263+
213264
self.database_knowledge_backend.add(query_records)
214265
else:
215266
raise ValueError("ddl_sql and data_sql must be provided")
@@ -292,6 +343,10 @@ def train_knowledge(
292343

293344
if reset_train and os.path.exists(self.knowledge_path):
294345
self.database_knowledge_backend.clear()
346+
self.ddl_sql = None
347+
self.data_sql = None
348+
self.database_summary = ""
349+
self.recommendation_question = ""
295350
logger.info("Reset knowledge...")
296351

297352
if (
@@ -319,41 +374,36 @@ def train_knowledge(
319374
self.generate_database_summary(query_samples_size=query_samples_size)
320375

321376
@timing
322-
def question_to_sql(self, question: str, dialect_name: str) -> str:
377+
def question_to_sql(self, question: str, dialect_name: str) -> QuestionMeta:
323378
"""Question to SQL"""
324-
prompt = (
325-
f"The following is the table structure in the database and "
326-
f"some common query SQL statements. Please convert the user's "
327-
f"question into an SQL query statement. Note to comply "
328-
f"with {dialect_name} syntax. Do not explain, "
329-
f"just provide the SQL directly.\n\n"
330-
)
331-
prompt += "## Table Schema\n"
379+
prompt = QUESTION_CONVERT_SQL.replace("{{dialect_name}}", dialect_name)
380+
332381
ddl_records: List[DDLRecord] = self.database_knowledge_backend.query_ddl(question)
333-
prompt += "```sql\n"
334-
for ddl_record in ddl_records:
335-
prompt += f"{ddl_record.sql}\n"
336-
prompt += "```\n\n"
382+
prompt = prompt.replace(
383+
"{{table_schema}}", "\n".join([record.sql for record in ddl_records])
384+
)
337385

338-
prompt += "## Data Example\n"
339-
prompt += "```sql\n"
340386
data_records: List[DMLRecord] = self.database_knowledge_backend.query_data(question)
341-
for data_record in data_records:
342-
prompt += f"```{data_record.sql}\n"
343-
prompt += "```\n\n"
387+
prompt = prompt.replace(
388+
"{{sample_data}}", "\n".join([record.sql for record in data_records])
389+
)
344390

345-
# some few shot
346391
query_records: List[QueryRecord] = self.database_knowledge_backend.query_query(question)
347-
for query_record in query_records:
348-
prompt += f"Question: {query_record.question}\n"
349-
prompt += f"SQL: {query_record.sql}\n\n"
392+
prompt = prompt.replace(
393+
"{{qa_pairs}}",
394+
"\n".join(
395+
[f"QUESTION: {record.question}\nSQL: {record.sql}\n\n" for record in query_records]
396+
),
397+
)
350398

351-
prompt += f"Question: {question}\n"
352-
prompt += "SQL: "
353-
logger.debug(Fore.GREEN + "PROMPT:", prompt)
399+
prompt = prompt.replace("{{question}}", question)
400+
logger.debug(Fore.GREEN + "PROMPT:" + prompt)
354401
self.agent.reset()
355402
response = self.agent.step(prompt)
356-
return strip_sql_code_block(response.msgs[0].content)
403+
404+
return QuestionMeta(
405+
question=question, sql=strip_sql_code_block(response.msgs[0].content), prompt=prompt
406+
)
357407

358408
@messages_log
359409
def ask(
@@ -366,27 +416,37 @@ def ask(
366416
if not message_log:
367417
message_log = MessageLogToEmpty()
368418
message_log.messages_writer(HumanMessage(session_id=session_id, content=question))
369-
sql = self.question_to_sql(
419+
question_meta = self.question_to_sql(
370420
question=question,
371421
dialect_name=self.database_manager.dialect_name(),
372422
)
373-
message_log.messages_writer(AssistantMessage(session_id=session_id, content=sql))
374423
try:
375-
dataset = self.database_manager.select(sql=sql, bind_pd=bind_pd)
424+
message_log.messages_writer(
425+
AssistantMessage(session_id=session_id, content=question_meta.sql)
426+
)
427+
dataset = self.database_manager.select(sql=question_meta.sql, bind_pd=bind_pd)
376428
message_log.messages_writer(
377429
AssistantMessage(
378430
session_id=session_id,
379431
content=tabulate(dataset, headers="keys", tablefmt="psql"),
380432
)
381433
)
382-
return DatabaseAgentResponse(ask=question, dataset=dataset, sql=sql)
434+
return DatabaseAgentResponse(ask=question, dataset=dataset, sql=question_meta.sql)
435+
except SQLExecutionError as e:
436+
message_log.messages_writer(AssistantMessage(session_id=session_id, content=str(e)))
437+
return DatabaseAgentResponse(
438+
ask=question,
439+
dataset=None,
440+
sql=e.sql,
441+
success=False,
442+
error=e.error_message,
443+
)
383444
except Exception as e:
384-
logger.error(e)
385445
message_log.messages_writer(AssistantMessage(session_id=session_id, content=str(e)))
386446
return DatabaseAgentResponse(
387447
ask=question,
388448
dataset=None,
389-
sql=sql,
449+
sql=question_meta.sql,
390450
success=False,
391451
error=str(e),
392452
)

camel_database_agent/database_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,12 +102,12 @@ def timing_wrapper(*args: Any, **kwargs: Any) -> Any:
102102
try:
103103
spinner_thread.start()
104104
result = func(*args, **kwargs)
105-
end_time = time.perf_counter()
106-
total_time = end_time - start_time
107105
finally:
108106
# sys.stdout.write('\r' + ' ' * 100 + '\r')
109107
stop_spinner.set()
110108
spinner_thread.join()
109+
end_time = time.perf_counter()
110+
total_time = end_time - start_time
111111
logger.info(f"\r{info} Took {Fore.GREEN}{total_time:.4f} seconds{Fore.RESET}")
112112
return result
113113

0 commit comments

Comments
 (0)