|
1 | 1 | import streamlit as st |
2 | | -from langchain_core.messages import HumanMessage |
3 | | -from llm_utils.graph import builder |
4 | | -from langchain.chains.sql_database.prompt import SQL_PROMPTS |
5 | 2 |
|
6 | | -# Streamlit 앱 제목 |
7 | | -st.title("Lang2SQL") |
8 | 3 |
|
9 | | -# 사용자 입력 받기 |
10 | | -user_query = st.text_area( |
11 | | - "쿼리를 입력하세요:", |
12 | | - value="고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리", |
| 4 | +pg = st.navigation( |
| 5 | + [ |
| 6 | + st.Page("lang2sql.py", title="Lang2SQL"), |
| 7 | + st.Page("viz_eval.py", title="Lang2SQL Evaluation 시각화"), |
| 8 | + ] |
13 | 9 | ) |
14 | 10 |
|
15 | | -user_database_env = st.selectbox( |
16 | | - "db 환경정보를 입력하세요:", |
17 | | - options=SQL_PROMPTS.keys(), |
18 | | - index=0, |
19 | | -) |
20 | | - |
21 | | - |
22 | | -# Token usage 집계 함수 정의 |
23 | | -def summarize_total_tokens(data): |
24 | | - total_tokens = 0 |
25 | | - for item in data: |
26 | | - token_usage = getattr(item, "usage_metadata", {}) |
27 | | - total_tokens += token_usage.get("total_tokens", 0) |
28 | | - return total_tokens |
29 | | - |
30 | | - |
31 | | -# 버튼 클릭 시 실행 |
32 | | -if st.button("쿼리 실행"): |
33 | | - # 그래프 컴파일 및 쿼리 실행 |
34 | | - graph = builder.compile() |
35 | | - |
36 | | - res = graph.invoke( |
37 | | - input={ |
38 | | - "messages": [HumanMessage(content=user_query)], |
39 | | - "user_database_env": user_database_env, |
40 | | - "best_practice_query": "", |
41 | | - } |
42 | | - ) |
43 | | - total_tokens = summarize_total_tokens(res["messages"]) |
44 | | - |
45 | | - # 결과 출력 |
46 | | - st.write("총 토큰 사용량:", total_tokens) |
47 | | - # st.write("결과:", res["generated_query"].content) |
48 | | - st.write("결과:", "\n\n```sql\n" + res["generated_query"] + "\n```") |
49 | | - st.write("결과 설명:\n\n", res["messages"][-1].content) |
50 | | - st.write("AI가 재해석한 사용자 질문:\n", res["refined_input"].content) |
51 | | - st.write("참고한 테이블 목록:", res["searched_tables"]) |
| 11 | +pg.run() |
0 commit comments