Skip to content

Commit 4220387

Browse files
committed
Refactor database agent methods to return token usage and improve schema parsing response structure
1 parent 11b848a commit 4220387

18 files changed

+198
-75
lines changed

camel_database_agent/cli.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,11 +151,13 @@ def main() -> None:
151151
language=args.language,
152152
data_path=data_path,
153153
)
154-
database_agent.train_knowledge(
154+
token_usage = database_agent.train_knowledge(
155155
level=TrainLevel.MEDIUM,
156156
reset_train=args.reset_train,
157157
)
158158

159+
print(f"{Fore.GREEN}")
160+
print("=" * 50)
159161
print(f"{Fore.GREEN}Database Overview")
160162
print("=" * 50)
161163
print(f"{database_agent.get_summary()}")
@@ -173,6 +175,7 @@ def main() -> None:
173175
f"{Fore.CYAN}Type {Fore.LIGHTYELLOW_EX}'help'{Fore.RESET} "
174176
f"to get more recommended questions"
175177
)
178+
print(f"{Fore.CYAN}Training completed, using {token_usage.total_tokens} tokens{Fore.RESET}")
176179
print(f"{Fore.CYAN}=" * 50)
177180

178181
session_id = str(uuid.uuid4())

camel_database_agent/database/database_schema_parse.py

Lines changed: 66 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import List, Union
1+
import textwrap
2+
from typing import List, Optional, TypeVar, Union
23

34
from camel.agents import ChatAgent
45
from camel.models import BaseModelBackend
@@ -17,7 +18,7 @@ class DDLRecord(BaseModel):
1718
class DMLRecord(BaseModel):
1819
id: str
1920
summary: str
20-
sql: str
21+
dataset: str
2122

2223

2324
class QueryRecord(BaseModel):
@@ -26,6 +27,14 @@ class QueryRecord(BaseModel):
2627
sql: str
2728

2829

30+
RecordType = TypeVar("RecordType", DDLRecord, DMLRecord, QueryRecord)
31+
32+
33+
class SchemaParseResponse(BaseModel):
34+
data: List[RecordType]
35+
usage: Optional[dict]
36+
37+
2938
class DDLRecordResponseFormat(BaseModel):
3039
items: List[DDLRecord]
3140

@@ -52,36 +61,74 @@ def __init__(
5261
)
5362

5463
@timing
55-
def parse_ddl_record(self, text: str) -> List[DDLRecord]:
64+
def parse_ddl_record(self, text: str) -> SchemaParseResponse:
5665
"""Parsing DDL SQL statements"""
5766
prompt = (
58-
"The following are some DDL script. Please read the script in its "
59-
"entirety and provide descriptions for the tables and fields to "
60-
"generate summary information and extract the SQL script for each "
61-
"table.\n\n"
67+
"Translate the following information into a JSON array format, "
68+
"with each JSON object in the array containing three "
69+
"elements: "
70+
"\"id\" for the table name, "
71+
"\"summary\" for a summary of the table, and "
72+
"\"sql\" for the SQL statement of the table creation.\n\n"
6273
)
63-
prompt += f"```sql\n{text}```\n\n"
64-
prompt += "Please output the summary information and SQL script in JSON format."
74+
if text.startswith("```sql"):
75+
prompt += f"{text}\n\n"
76+
else:
77+
prompt += f"```sql\n{text}```\n\n"
78+
79+
# 非 openai 模型要增加以下片段
80+
prompt += textwrap.dedent(
81+
"Output Format:\n"
82+
"{"
83+
" \"items\":"
84+
" ["
85+
" {"
86+
" \"id\": \"<table name>\","
87+
" \"summary\": \"<table summary>\","
88+
" \"sql\": \"<table ddl script>\""
89+
" }"
90+
" ]"
91+
"}\n\n"
92+
)
93+
prompt += "Now, directly output the JSON array without explanation."
6594
response = self.parsing_agent.step(prompt, response_format=DDLRecordResponseFormat)
6695
ddl_record_response = DDLRecordResponseFormat.model_validate_json(response.msgs[0].content)
67-
return ddl_record_response.items
96+
return SchemaParseResponse(data=ddl_record_response.items, usage=response.info["usage"])
6897

6998
@timing
70-
def parse_dml_record(self, text: str) -> List[DMLRecord]:
99+
def parse_dml_record(self, text: str) -> SchemaParseResponse:
71100
"""Parsing DML SQL statements"""
72101
prompt = (
73-
"The following are some DML statements from which you need "
74-
"to extract table names, field names, and generate summary "
75-
"information, as well as extract each SQL statement.\n\n"
102+
"Translate the following information into a JSON array format, "
103+
"with each JSON object in the array containing three "
104+
"elements: "
105+
"\"id\" for the table name, "
106+
"\"summary\" for a summary of the table, and "
107+
"\"dataset\" for the Markdown of the data.\n\n"
76108
)
77-
prompt += f"```sql\n{text}```\n"
78-
prompt += "Please output the summary information and SQL script in JSON format."
109+
prompt += f"{text}\n\n"
110+
111+
# 非 openai 模型要增加以下片段
112+
prompt += textwrap.dedent(
113+
"Output Format:\n"
114+
"{"
115+
" \"items\":"
116+
" ["
117+
" {"
118+
" \"id\": \"<table name>\","
119+
" \"summary\": \"<table summary>\","
120+
" \"dataset\": \"<markdown dataset>\""
121+
" }"
122+
" ]"
123+
"}\n\n"
124+
)
125+
prompt += "Now, directly output the JSON array without explanation."
79126
response = self.parsing_agent.step(prompt, response_format=DMLRecordResponseFormat)
80127
dml_record_response = DMLRecordResponseFormat.model_validate_json(response.msgs[0].content)
81-
return dml_record_response.items
128+
return SchemaParseResponse(data=dml_record_response.items, usage=response.info["usage"])
82129

83130
@timing
84-
def parse_query_record(self, text: str) -> List[QueryRecord]:
131+
def parse_query_record(self, text: str) -> SchemaParseResponse:
85132
"""Parsing Query SQL statements"""
86133
prompt = (
87134
"The following is an analysis of user query requirements, "
@@ -94,4 +141,4 @@ def parse_query_record(self, text: str) -> List[QueryRecord]:
94141
query_record_response = QueryRecordResponseFormat.model_validate_json(
95142
response.msgs[0].content
96143
)
97-
return query_record_response.items
144+
return SchemaParseResponse(data=query_record_response.items, usage=response.info["usage"])

camel_database_agent/database/dialect/database_schema_dialect.py

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from camel.agents import ChatAgent
66
from camel.models import BaseModelBackend
7+
from tabulate import tabulate
78

89
from camel_database_agent.database.database_manager import DatabaseManager
910
from camel_database_agent.database_prompt import POLISH_SCHEMA_OUTPUT_EXAMPLE
@@ -81,38 +82,40 @@ def get_sampled_data(self, data_samples_size: int = 5) -> str:
8182
Must be implemented by all dialect subclasses.
8283
"""
8384
metadata = self.database_manager.get_metadata()
84-
sample_data_sql = []
85+
sample_data = []
8586

8687
for table_name in metadata.tables:
87-
table = metadata.tables[table_name]
88-
column_names = [column.name for column in table.columns]
88+
# table = metadata.tables[table_name]
89+
# column_names = [column.name for column in table.columns]
8990

9091
sample_query = f"SELECT * FROM {table_name} LIMIT {data_samples_size}"
9192
try:
9293
rows = self.database_manager.select(sample_query)
93-
for row in rows:
94-
columns = []
95-
values = []
96-
97-
for col_name in column_names:
98-
if col_name in row and row[col_name] is not None:
99-
columns.append(col_name)
100-
if isinstance(row[col_name], str):
101-
values.append("'" + row[col_name].replace("'", "''") + "'")
102-
elif isinstance(row[col_name], (int, float)):
103-
values.append(str(row[col_name]))
104-
else:
105-
values.append(f"'{row[col_name]!s}'")
106-
107-
if columns and values:
108-
columns_stmt = ', '.join(columns)
109-
values_stmt = ', '.join(values)
110-
insert_stmt = (
111-
f"INSERT INTO {table_name} ({columns_stmt}) VALUES ({values_stmt});"
112-
)
113-
sample_data_sql.append(insert_stmt)
94+
dataset = tabulate(tabular_data=rows, headers='keys', tablefmt='psql')
95+
sample_data.append(f"## {table_name}\n\n{dataset}")
96+
# for row in rows:
97+
# columns = []
98+
# values = []
99+
#
100+
# for col_name in column_names:
101+
# if col_name in row and row[col_name] is not None:
102+
# columns.append(col_name)
103+
# if isinstance(row[col_name], str):
104+
# values.append("'" + row[col_name].replace("'", "''") + "'")
105+
# elif isinstance(row[col_name], (int, float)):
106+
# values.append(str(row[col_name]))
107+
# else:
108+
# values.append(f"'{row[col_name]!s}'")
109+
#
110+
# if columns and values:
111+
# columns_stmt = ', '.join(columns)
112+
# values_stmt = ', '.join(values)
113+
# insert_stmt = (
114+
# f"INSERT INTO {table_name} ({columns_stmt}) VALUES ({values_stmt});"
115+
# )
116+
# sample_data_sql.append(insert_stmt)
114117

115118
except Exception as e:
116119
logger.warning(f"Error sampling data from table {table_name}: {e}")
117120

118-
return "\n".join(sample_data_sql)
121+
return "\n\n".join(sample_data)

0 commit comments

Comments
 (0)