Skip to content

Commit 96c5fd9

Browse files
nonegomchanguk
authored andcommitted
feat: SQL Prompt class and update prompt from file function
1 parent a204bbb commit 96c5fd9

File tree

2 files changed

+54
-0
lines changed

2 files changed

+54
-0
lines changed

llm_utils/prompts_class.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from langchain.chains.sql_database.prompt import SQL_PROMPTS
2+
import os
3+
4+
from langchain_core.prompts import load_prompt
5+
6+
7+
class SQLPrompt():
8+
def __init__(self):
9+
# os library를 확인해서 SQL_PROMPTS key에 해당하는ㅁ prompt가 있으면, 이를 교체
10+
self.sql_prompts = SQL_PROMPTS
11+
self.target_db_list = list(SQL_PROMPTS.keys())
12+
self.prompt_path = '../prompt'
13+
14+
def update_prompt_from_path(self):
15+
if os.path.exists(self.prompt_path):
16+
path_list = os.listdir(self.prompt_path)
17+
# yaml 파일만 가져옴
18+
file_list = [file for file in path_list if file.endswith('.yaml')]
19+
key_path_dict = {key.split('.')[0]: os.path.join(self.prompt_path, key) for key in file_list if key.split('.')[0] in self.target_db_list}
20+
# file_list에서 sql_prompts의 key에 해당하는 파일이 있는 것만 가져옴
21+
for key, path in key_path_dict.items():
22+
self.sql_prompts[key] = load_prompt(path, encoding='utf-8')
23+
else:
24+
raise FileNotFoundError(f"Prompt file not found in {self.prompt_path}")
25+
return False
26+
27+
if __name__ == '__main__':
28+
sql_prompts_class = SQLPrompt()
29+
print(sql_prompts_class.sql_prompts['mysql'])
30+
print(sql_prompts_class.update_prompt_from_path())
31+
32+
print(sql_prompts_class.sql_prompts['mysql'])
33+
print(sql_prompts_class.sql_prompts)

prompt/mysql.yaml

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
_type: prompt
2+
template: |
3+
커스텀 MySQL 프롬프트입니다.
4+
You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.
5+
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
6+
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
7+
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
8+
Pay attention to use CURDATE() function to get the current date, if the question involves "today".
9+
10+
Use the following format:
11+
12+
Question: Question here
13+
SQLQuery: SQL Query to run
14+
SQLResult: Result of the SQLQuery
15+
Answer: Final answer here
16+
17+
Only use the following tables:
18+
{table_info}
19+
20+
Question: {input}
21+
input_variables: ["input", "table_info", "top_k"]

0 commit comments

Comments
 (0)