Skip to content

Commit 3cdac63

Browse files
committed
feature: evaluation 질문 결과 시각화 페이지 추가
1 parent 07bf11f commit 3cdac63

File tree

3 files changed

+377
-46
lines changed

3 files changed

+377
-46
lines changed

interface/lang2sql.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import streamlit as st
2+
from langchain_core.messages import HumanMessage
3+
from llm_utils.graph import builder
4+
from langchain.chains.sql_database.prompt import SQL_PROMPTS
5+
6+
# Streamlit 앱 제목
7+
st.title("Lang2SQL")
8+
9+
# 사용자 입력 받기
10+
user_query = st.text_area(
11+
"쿼리를 입력하세요:",
12+
value="고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리",
13+
)
14+
15+
user_database_env = st.selectbox(
16+
"db 환경정보를 입력하세요:",
17+
options=SQL_PROMPTS.keys(),
18+
index=0,
19+
)
20+
21+
22+
# Token usage 집계 함수 정의
23+
def summarize_total_tokens(data):
24+
total_tokens = 0
25+
for item in data:
26+
token_usage = getattr(item, "usage_metadata", {})
27+
total_tokens += token_usage.get("total_tokens", 0)
28+
return total_tokens
29+
30+
31+
# 버튼 클릭 시 실행
32+
if st.button("쿼리 실행"):
33+
# 그래프 컴파일 및 쿼리 실행
34+
graph = builder.compile()
35+
36+
res = graph.invoke(
37+
input={
38+
"messages": [HumanMessage(content=user_query)],
39+
"user_database_env": user_database_env,
40+
"best_practice_query": "",
41+
}
42+
)
43+
total_tokens = summarize_total_tokens(res["messages"])
44+
45+
# 결과 출력
46+
st.write("총 토큰 사용량:", total_tokens)
47+
# st.write("결과:", res["generated_query"].content)
48+
st.write("결과:", "\n\n```sql\n" + res["generated_query"] + "\n```")
49+
st.write("결과 설명:\n\n", res["messages"][-1].content)
50+
st.write("AI가 재해석한 사용자 질문:\n", res["refined_input"].content)
51+
st.write("참고한 테이블 목록:", res["searched_tables"])

interface/streamlit_app.py

Lines changed: 6 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,11 @@
11
import streamlit as st
2-
from langchain_core.messages import HumanMessage
3-
from llm_utils.graph import builder
4-
from langchain.chains.sql_database.prompt import SQL_PROMPTS
52

6-
# Streamlit 앱 제목
7-
st.title("Lang2SQL")
83

9-
# 사용자 입력 받기
10-
user_query = st.text_area(
11-
"쿼리를 입력하세요:",
12-
value="고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리",
4+
pg = st.navigation(
5+
[
6+
st.Page("lang2sql.py", title="Lang2SQL"),
7+
st.Page("viz_eval.py", title="Lang2SQL Evaluation 시각화"),
8+
]
139
)
1410

15-
user_database_env = st.selectbox(
16-
"db 환경정보를 입력하세요:",
17-
options=SQL_PROMPTS.keys(),
18-
index=0,
19-
)
20-
21-
22-
# Token usage 집계 함수 정의
23-
def summarize_total_tokens(data):
24-
total_tokens = 0
25-
for item in data:
26-
token_usage = getattr(item, "usage_metadata", {})
27-
total_tokens += token_usage.get("total_tokens", 0)
28-
return total_tokens
29-
30-
31-
# 버튼 클릭 시 실행
32-
if st.button("쿼리 실행"):
33-
# 그래프 컴파일 및 쿼리 실행
34-
graph = builder.compile()
35-
36-
res = graph.invoke(
37-
input={
38-
"messages": [HumanMessage(content=user_query)],
39-
"user_database_env": user_database_env,
40-
"best_practice_query": "",
41-
}
42-
)
43-
total_tokens = summarize_total_tokens(res["messages"])
44-
45-
# 결과 출력
46-
st.write("총 토큰 사용량:", total_tokens)
47-
# st.write("결과:", res["generated_query"].content)
48-
st.write("결과:", "\n\n```sql\n" + res["generated_query"] + "\n```")
49-
st.write("결과 설명:\n\n", res["messages"][-1].content)
50-
st.write("AI가 재해석한 사용자 질문:\n", res["refined_input"].content)
51-
st.write("참고한 테이블 목록:", res["searched_tables"])
11+
pg.run()

0 commit comments

Comments
 (0)