Skip to content

Commit 9e0318d

Browse files
committed
feat(streamlit) : 사이드 바에 프로파일 추출 & 컨텍스트 보강 워크플로우 사용 체크박스 추가, 그래프 연결
1 parent 7837996 commit 9e0318d

File tree

1 file changed

+23
-4
lines changed

1 file changed

+23
-4
lines changed

interface/lang2sql.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from llm_utils.connect_db import ConnectDB
1313
from llm_utils.graph import builder
14+
from llm_utils.enriched_graph import builder as enriched_builder
1415

1516
DEFAULT_QUERY = "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리"
1617
SIDEBAR_OPTIONS = {
@@ -65,7 +66,10 @@ def execute_query(
6566
# 세션 상태에서 그래프 가져오기
6667
graph = st.session_state.get("graph")
6768
if graph is None:
68-
graph = builder.compile()
69+
graph_builder = (
70+
enriched_builder if st.session_state.get("use_enriched") else builder
71+
)
72+
graph = graph_builder.compile()
6973
st.session_state["graph"] = graph
7074

7175
res = graph.invoke(
@@ -124,14 +128,29 @@ def display_result(
124128

125129
st.title("Lang2SQL")
126130

131+
# 워크플로우 선택(UI)
132+
use_enriched = st.sidebar.checkbox(
133+
"프로파일 추출 & 컨텍스트 보강 워크플로우 사용", value=False
134+
)
135+
127136
# 세션 상태 초기화
128-
if "graph" not in st.session_state:
129-
st.session_state["graph"] = builder.compile()
137+
if (
138+
"graph" not in st.session_state
139+
or st.session_state.get("use_enriched") != use_enriched
140+
):
141+
graph_builder = enriched_builder if use_enriched else builder
142+
st.session_state["graph"] = graph_builder.compile()
143+
144+
# 프로파일 추출 & 컨텍스트 보강 그래프
145+
st.session_state["use_enriched"] = use_enriched
130146
st.info("Lang2SQL이 성공적으로 시작되었습니다.")
131147

132148
# 새로고침 버튼 추가
133149
if st.sidebar.button("Lang2SQL 새로고침"):
134-
st.session_state["graph"] = builder.compile()
150+
graph_builder = (
151+
enriched_builder if st.session_state.get("use_enriched") else builder
152+
)
153+
st.session_state["graph"] = graph_builder.compile()
135154
st.sidebar.success("Lang2SQL이 성공적으로 새로고침되었습니다.")
136155

137156
user_query = st.text_area(

0 commit comments

Comments
 (0)