Skip to content

Commit d5d0b34

Browse files
committed
Rename parse_dml_record to parse_sampled_record and update prompt templates for JSON output
1 parent 3a001ee commit d5d0b34

File tree

6 files changed

+100
-36
lines changed

6 files changed

+100
-36
lines changed

camel_database_agent/database/prompts.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,34 @@ class PromptTemplates:
3737
- Specify the expected format and content of comments
3838
- Emphasize professionalism and conciseness
3939
""")
40+
41+
PARSE_SAMPLED_RECORD = textwrap.dedent("""
42+
# JSON Format Request
43+
You are a specialized JSON generator. Your only function is to parse the provided data and convert it to JSON format, strictly following the format requirements.
44+
45+
## Input Data:
46+
{{section}}
47+
48+
## Instructions:
49+
1. Create a JSON array with each table as an object
50+
2. Each object must have exactly three fields:
51+
- "id": the table name
52+
- "summary": a brief description of the table
53+
- "dataset": the data in markdown format
54+
3. The entire response must be ONLY valid JSON without any additional text, explanation, or markdown code blocks
55+
56+
## Required Output Format:
57+
{
58+
"items":[{
59+
"id": "<table name>",
60+
"summary": "<table summary>",
61+
"dataset": "<markdown dataset>"
62+
}]
63+
}
64+
65+
## IMPORTANT:
66+
- Your response must contain ONLY the JSON object, nothing else
67+
- Do not include explanations, introductions, or conclusions
68+
- Do not use markdown code blocks (```) around the JSON
69+
- Do not include phrases like "Here's the JSON" or "I've created the JSON"
70+
- Do not indicate that you are providing the output in any way""")

camel_database_agent/database/schema.py

Lines changed: 33 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import logging
2+
import re
13
import textwrap
24
from typing import Generic, List, Optional, TypeVar, Union
35

@@ -6,8 +8,11 @@
68
from pydantic import BaseModel
79

810
from camel_database_agent.database.manager import DatabaseManager
11+
from camel_database_agent.database.prompts import PromptTemplates
912
from camel_database_agent.database_base import timing
1013

14+
logger = logging.getLogger(__name__)
15+
1116

1217
class DDLRecord(BaseModel):
1318
id: str
@@ -97,36 +102,30 @@ def parse_ddl_record(self, text: str) -> SchemaParseResponse:
97102
return SchemaParseResponse(data=ddl_record_response.items, usage=response.info["usage"])
98103

99104
@timing
100-
def parse_dml_record(self, text: str) -> SchemaParseResponse:
105+
def parse_sampled_record(self, text: str) -> SchemaParseResponse:
101106
"""Parsing Sampled Data"""
102-
prompt = (
103-
"Translate the following information into a JSON array format, "
104-
"with each JSON object in the array containing three "
105-
"elements: "
106-
"\"id\" for the table name, "
107-
"\"summary\" for a summary of the table, and "
108-
"\"dataset\" for the Markdown of the data.\n\n"
109-
)
110-
prompt += f"{text}\n\n"
111-
112-
# For non-OpenAI models, the following snippet needs to be added
113-
prompt += textwrap.dedent(
114-
"Output Format:\n"
115-
"{"
116-
" \"items\":"
117-
" ["
118-
" {"
119-
" \"id\": \"<table name>\","
120-
" \"summary\": \"<table summary>\","
121-
" \"dataset\": \"<markdown dataset>\""
122-
" }"
123-
" ]"
124-
"}\n\n"
125-
)
126-
prompt += "Now, directly output the JSON array without explanation."
127-
response = self.parsing_agent.step(prompt, response_format=DMLRecordResponseFormat)
128-
dml_record_response = DMLRecordResponseFormat.model_validate_json(response.msgs[0].content)
129-
return SchemaParseResponse(data=dml_record_response.items, usage=response.info["usage"])
107+
data: List[DMLRecord] = []
108+
usage: Optional[dict] = None
109+
sections = self.split_markdown_by_h2(text)
110+
for section in sections:
111+
prompt = PromptTemplates.PARSE_SAMPLED_RECORD.replace("{{section}}", section)
112+
try:
113+
self.parsing_agent.reset()
114+
response = self.parsing_agent.step(prompt, response_format=DMLRecordResponseFormat)
115+
dml_record_response = DMLRecordResponseFormat.model_validate_json(
116+
response.msgs[0].content
117+
)
118+
data.extend(dml_record_response.items)
119+
if usage is None:
120+
usage = response.info["usage"]
121+
else:
122+
usage["completion_tokens"] += response.info["usage"]["completion_tokens"]
123+
usage["prompt_tokens"] += response.info["usage"]["prompt_tokens"]
124+
usage["total_tokens"] += response.info["usage"]["total_tokens"]
125+
except Exception as e:
126+
logger.error(f"Unable to process messages: {e}")
127+
logger.error(f"Prompt: {prompt}")
128+
return SchemaParseResponse(data=data, usage=usage)
130129

131130
@timing
132131
def parse_query_record(self, text: str) -> SchemaParseResponse:
@@ -143,3 +142,8 @@ def parse_query_record(self, text: str) -> SchemaParseResponse:
143142
response.msgs[0].content
144143
)
145144
return SchemaParseResponse(data=query_record_response.items, usage=response.info["usage"])
145+
146+
def split_markdown_by_h2(self, markdown_text):
147+
sections = re.split(r'(?=^##\s+)', markdown_text, flags=re.MULTILINE)
148+
sections = [section.strip() for section in sections if section.strip()]
149+
return sections

camel_database_agent/database_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def _parse_sampled_data_to_knowledge(self, data_samples_size: int = 5) -> TokenU
236236
) as f:
237237
f.write(self.data_sql)
238238

239-
schema_parse_response: SchemaParseResponse = self.schema_parse.parse_dml_record(
239+
schema_parse_response: SchemaParseResponse = self.schema_parse.parse_sampled_record(
240240
self.data_sql
241241
)
242242

camel_database_agent/datagen/pipeline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def _prepare_prompt(self, query_samples_needed: int) -> str:
5555
prompt = prompt.replace("{{ddl_sql}}", self.ddl_sql)
5656
prompt = prompt.replace("{{data_sql}}", self.data_sql)
5757
prompt = prompt.replace("{{query_samples_size}}", str(query_samples_needed))
58+
prompt = prompt.replace("{{dialect_name}}", self.database_manager.dialect_name())
5859
return prompt
5960

6061
def _parse_response_content(self, content: str) -> List[QueryRecord]:
@@ -76,7 +77,7 @@ def _validate_query(self, query_record: QueryRecord) -> bool:
7677
self.database_manager.select(query_record.sql)
7778
return True
7879
except SQLExecutionError as e:
79-
logger.error(f"{Fore.RED}SQLExecutionError{Fore.RESET}: {e.sql} {e.error_message}")
80+
logger.debug(f"{Fore.RED}SQLExecutionError{Fore.RESET}: {e.sql} {e.error_message}")
8081
return False
8182
except Exception as e:
8283
logger.error(

camel_database_agent/datagen/prompts.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,44 @@
33

44
class PromptTemplates:
55
QUESTION_INFERENCE_PIPELINE = textwrap.dedent("""
6-
Please carefully analyze the following database information and conduct an in-depth analysis from a business perspective. What business query questions might users raise? Please fully consider some complex query scenarios, including but not limited to multi-table associations, grouping statistics, etc.
6+
# JSON Format Request
7+
8+
You are a specialized JSON generator. Your only function is to parse the provided data and convert it to JSON format, strictly following the format requirements.
79
8-
Database Schema:
10+
## Database Schema:
911
```
1012
{{ddl_sql}}
1113
```
1214
13-
Data Example:
15+
## Data Example:
1416
```sql
1517
{{data_sql}}
1618
```
1719
18-
Now, Please generate {{query_samples_size}} real user query questions along with the corresponding SQL query statements without using placeholders. Please output in JSON format.""")
20+
## Instructions:
21+
Database System: {{dialect_name}}
22+
1. Please carefully analyze the following database information and conduct an in-depth analysis from a business perspective. What business query questions might users raise? Please fully consider some complex query scenarios, including but not limited to multi-table associations, grouping statistics, etc.
23+
2. Please ensure that the SQL you write conforms to {{dialect_name}} syntax.
24+
3. Generate {{query_samples_size}} real user query questions along with the corresponding SQL query statements without using placeholders
25+
4. Create a JSON array with each table as an object
26+
5. Each object must have exactly three fields:
27+
- "id": the table name
28+
- "question": a query in natural language.
29+
- "sql": sql statements without placeholders.
30+
6. The entire response must be ONLY valid JSON without any additional text, explanation, or markdown code blocks
31+
32+
## Required Output Format:
33+
{
34+
"items":[{
35+
"id": "<table name>",
36+
"question": "<a query in natural language>",
37+
"sql": "<sql statements>"
38+
}]
39+
}
40+
41+
## IMPORTANT:
42+
- Your response must contain ONLY the JSON object, nothing else
43+
- Do not include explanations, introductions, or conclusions
44+
- Do not use markdown code blocks (```) around the JSON
45+
- Do not include phrases like "Here's the JSON" or "I've created the JSON"
46+
- Do not indicate that you are providing the output in any way.""")

tests/integration_tests/test_database_schema_parse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def test_parse_ddl_record(self) -> None:
3737
def test_parse_dml_record(self) -> None:
3838
current_dir = os.path.dirname(os.path.abspath(__file__))
3939
with open(os.path.join(current_dir, "data.sql"), "r") as f:
40-
schema_parse_response: SchemaParseResponse = self.parse.parse_dml_record(f.read())
40+
schema_parse_response: SchemaParseResponse = self.parse.parse_sampled_record(f.read())
4141
assert len(schema_parse_response.data) == 6
4242

4343
def test_parse_query_record(self) -> None:

0 commit comments

Comments
 (0)