diff --git a/interface/lang2sql.py b/interface/lang2sql.py
index 395b2b7..6a80e1c 100644
--- a/interface/lang2sql.py
+++ b/interface/lang2sql.py
@@ -2,6 +2,12 @@
 from langchain_core.messages import HumanMessage
 from llm_utils.graph import builder
 from langchain.chains.sql_database.prompt import SQL_PROMPTS
+from llm_utils.connect_db import ConnectDB
+
+
+# Clickhouse 연결 (connect once at app start-up)
+db = ConnectDB()
+db.connect_to_clickhouse()
 
 # Streamlit 앱 제목
 st.title("Lang2SQL")
@@ -17,6 +23,22 @@
     options=SQL_PROMPTS.keys(),
     index=0,
 )
+st.sidebar.title("Output Settings")
+st.sidebar.checkbox("Show Total Token Usage", value=True, key="show_total_token_usage")
+st.sidebar.checkbox(
+    "Show Result Description", value=True, key="show_result_description"
+)
+st.sidebar.checkbox("Show SQL", value=True, key="show_sql")
+st.sidebar.checkbox(
+    "Show User Question Reinterpreted by AI",
+    value=True,
+    key="show_question_reinterpreted_by_ai",
+)
+st.sidebar.checkbox(
+    "Show List of Referenced Tables", value=True, key="show_referenced_tables"
+)
+st.sidebar.checkbox("Show Table", value=True, key="show_table")
+st.sidebar.checkbox("Show Chart", value=True, key="show_chart")
 
 
 # Token usage 집계 함수 정의
@@ -43,9 +65,20 @@ def summarize_total_tokens(data):
     total_tokens = summarize_total_tokens(res["messages"])
 
     # 결과 출력
-    st.write("총 토큰 사용량:", total_tokens)
-    # st.write("결과:", res["generated_query"].content)
-    st.write("결과:", "\n\n```sql\n" + res["generated_query"] + "\n```")
-    st.write("결과 설명:\n\n", res["messages"][-1].content)
-    st.write("AI가 재해석한 사용자 질문:\n", res["refined_input"].content)
-    st.write("참고한 테이블 목록:", res["searched_tables"])
+    if st.session_state.get("show_total_token_usage", True):
+        st.write("총 토큰 사용량:", total_tokens)
+    if st.session_state.get("show_sql", True):
+        st.write("결과:", "\n\n```sql\n" + res["generated_query"] + "\n```")
+    if st.session_state.get("show_result_description", True):
+        st.write("결과 설명:\n\n", res["messages"][-1].content)
+    if st.session_state.get("show_question_reinterpreted_by_ai", True):
+        st.write("AI가 재해석한 사용자 질문:\n", res["refined_input"].content)
+    if st.session_state.get("show_referenced_tables", True):
+        st.write("참고한 테이블 목록:", res["searched_tables"])
+    if st.session_state.get("show_table", True):
+        sql = res["generated_query"]
+        df = db.run_sql(sql)
+        if len(df) > 10:
+            st.dataframe(df.head(10))
+        else:
+            st.dataframe(df)
diff --git a/interface/streamlit_app.py b/interface/streamlit_app.py
index 81df482..83c85b1 100644
--- a/interface/streamlit_app.py
+++ b/interface/streamlit_app.py
@@ -1,6 +1,5 @@
 import streamlit as st
 
-
 pg = st.navigation(
     [
         st.Page("lang2sql.py", title="Lang2SQL"),
diff --git a/llm_utils/chains.py b/llm_utils/chains.py
index d9e5e6c..3a222fa 100644
--- a/llm_utils/chains.py
+++ b/llm_utils/chains.py
@@ -12,11 +12,7 @@
 else:
     print(f"⚠️ 환경변수 파일(.env)이 {os.getcwd()}에 없습니다!")
 
-llm = get_llm(
-    model_type="openai",
-    model_name="gpt-4o-mini",
-    openai_api_key=os.getenv("OPENAI_API_KEY"),
-)
+llm = get_llm()
 
 
 def create_query_refiner_chain(llm):
diff --git a/llm_utils/connect_db.py b/llm_utils/connect_db.py
new file mode 100644
index 0000000..aa2c099
--- /dev/null
+++ b/llm_utils/connect_db.py
@@ -0,0 +1,49 @@
+import os
+from typing import Union
+
+import pandas as pd
+from clickhouse_driver import Client
+from dotenv import load_dotenv
+
+# Load connection settings from the environment (.env).
+load_dotenv()
+
+
+class ConnectDB:
+    """Thin wrapper around a ClickHouse native-protocol connection."""
+
+    def __init__(self):
+        # The connection is created lazily by connect_to_clickhouse().
+        self.client = None
+        self.host = os.getenv("CLICKHOUSE_HOST")
+        self.dbname = os.getenv("CLICKHOUSE_DATABASE")
+        self.user = os.getenv("CLICKHOUSE_USER")
+        self.password = os.getenv("CLICKHOUSE_PASSWORD")
+        # Env vars are strings; clickhouse_driver expects a numeric port.
+        # Default to ClickHouse's native TCP port 9000 when unset.
+        self.port = int(os.getenv("CLICKHOUSE_PORT") or 9000)
+
+    def connect_to_clickhouse(self):
+        """Open a native TCP connection using the settings read in __init__."""
+        self.client = Client(
+            host=self.host,
+            port=self.port,
+            user=self.user,
+            password=self.password,
+            database=self.dbname,
+        )
+
+    def run_sql(self, sql: str) -> Union[pd.DataFrame, None]:
+        """Execute *sql* and return the result set as a pandas DataFrame.
+
+        Raises RuntimeError when called before connect_to_clickhouse(),
+        instead of silently returning None.
+        """
+        if self.client is None:
+            raise RuntimeError(
+                "Not connected to ClickHouse; call connect_to_clickhouse() first."
+            )
+        # with_column_types=True returns (rows, [(name, type), ...]).
+        rows, columns = self.client.execute(sql, with_column_types=True)
+        column_names = [col[0] for col in columns]
+        return pd.DataFrame(rows, columns=column_names)
diff --git a/llm_utils/graph.py b/llm_utils/graph.py
index 0aef51d..772cec2 100644
--- a/llm_utils/graph.py
+++ b/llm_utils/graph.py
@@ -62,7 +62,6 @@ def get_table_info_node(state: QueryMakerState):
         documents = get_info_from_db()
         db = FAISS.from_documents(documents, embeddings)
         db.save_local(os.getcwd() + "/table_info_db")
-        print("table_info_db not found")
 
     doc_res = db.similarity_search(state["messages"][-1].content)
     documents_dict = {}
@@ -112,11 +111,7 @@ class SQLResult(BaseModel):
 
 def query_maker_node_with_db_guide(state: QueryMakerState):
     sql_prompt = SQL_PROMPTS[state["user_database_env"]]
-    llm = get_llm(
-        model_type="openai",
-        model_name="gpt-4o-mini",
-        openai_api_key=os.getenv("OPENAI_API_KEY"),
-    )
+    llm = get_llm()
     chain = sql_prompt | llm.with_structured_output(SQLResult)
     res = chain.invoke(
         input={
diff --git a/setup.py b/setup.py
index d5e4805..71a31ac 100644
--- a/setup.py
+++ b/setup.py
@@ -24,6 +24,11 @@
         "streamlit==1.41.1",
         "python-dotenv==1.0.1",
         "faiss-cpu==1.10.0",
+        "clickhouse-driver>=0.2.9,<0.3.0",
+        "langchain-aws>=0.2.21,<0.3.0",
+        "langchain-google-genai>=2.1.3,<3.0.0",
+        "langchain-ollama>=0.3.2,<0.4.0",
+        "langchain-huggingface>=0.1.2,<0.2.0",
     ],
     entry_points={
         "console_scripts": [