44from langchain_core .prompts import load_prompt
55
66
7- class SQLPrompt () :
7+ class SQLPrompt :
88 def __init__ (self ):
9- # os library를 확인해서 SQL_PROMPTS key에 해당하는ㅁ prompt가 있으면, 이를 교체
9+ # os library를 확인해서 SQL_PROMPTS key에 해당하는 prompt가 있으면, 이를 교체
1010 self .sql_prompts = SQL_PROMPTS
1111 self .target_db_list = list (SQL_PROMPTS .keys ())
12- self .prompt_path = ' ../prompt'
12+ self .prompt_path = " ../prompt"
1313
1414 def update_prompt_from_path (self ):
1515 if os .path .exists (self .prompt_path ):
1616 path_list = os .listdir (self .prompt_path )
1717 # yaml 파일만 가져옴
18- file_list = [file for file in path_list if file .endswith ('.yaml' )]
19- key_path_dict = {key .split ('.' )[0 ]: os .path .join (self .prompt_path , key ) for key in file_list if key .split ('.' )[0 ] in self .target_db_list }
18+ file_list = [file for file in path_list if file .endswith (".yaml" )]
19+ key_path_dict = {
20+ key .split ("." )[0 ]: os .path .join (self .prompt_path , key )
21+ for key in file_list
22+ if key .split ("." )[0 ] in self .target_db_list
23+ }
2024 # file_list에서 sql_prompts의 key에 해당하는 파일이 있는 것만 가져옴
2125 for key , path in key_path_dict .items ():
22- self .sql_prompts [key ] = load_prompt (path , encoding = ' utf-8' )
26+ self .sql_prompts [key ] = load_prompt (path , encoding = " utf-8" )
2327 else :
2428 raise FileNotFoundError (f"Prompt file not found in { self .prompt_path } " )
2529 return False
26-
27- if __name__ == '__main__' :
28- sql_prompts_class = SQLPrompt ()
29- print (sql_prompts_class .sql_prompts ['mysql' ])
30- print (sql_prompts_class .update_prompt_from_path ())
31-
32- print (sql_prompts_class .sql_prompts ['mysql' ])
33- print (sql_prompts_class .sql_prompts )
0 commit comments