|
11 | 11 |
|
12 | 12 | from llm_utils.connect_db import ConnectDB |
13 | 13 | from llm_utils.graph import builder |
| 14 | +from llm_utils.enriched_graph import builder as enriched_builder |
14 | 15 |
|
15 | 16 | DEFAULT_QUERY = "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리" |
16 | 17 | SIDEBAR_OPTIONS = { |
@@ -65,7 +66,10 @@ def execute_query( |
65 | 66 | # 세션 상태에서 그래프 가져오기 |
66 | 67 | graph = st.session_state.get("graph") |
67 | 68 | if graph is None: |
68 | | - graph = builder.compile() |
| 69 | + graph_builder = ( |
| 70 | + enriched_builder if st.session_state.get("use_enriched") else builder |
| 71 | + ) |
| 72 | + graph = graph_builder.compile() |
69 | 73 | st.session_state["graph"] = graph |
70 | 74 |
|
71 | 75 | res = graph.invoke( |
@@ -124,14 +128,29 @@ def display_result( |
124 | 128 |
|
125 | 129 | st.title("Lang2SQL") |
126 | 130 |
|
| 131 | +# 워크플로우 선택(UI) |
| 132 | +use_enriched = st.sidebar.checkbox( |
| 133 | + "프로파일 추출 & 컨텍스트 보강 워크플로우 사용", value=False |
| 134 | +) |
| 135 | + |
127 | 136 | # 세션 상태 초기화 |
128 | | -if "graph" not in st.session_state: |
129 | | - st.session_state["graph"] = builder.compile() |
| 137 | +if ( |
| 138 | + "graph" not in st.session_state |
| 139 | + or st.session_state.get("use_enriched") != use_enriched |
| 140 | +): |
| 141 | + graph_builder = enriched_builder if use_enriched else builder |
| 142 | + st.session_state["graph"] = graph_builder.compile() |
| 143 | + |
| 144 | + # 프로파일 추출 & 컨텍스트 보강 그래프 |
| 145 | + st.session_state["use_enriched"] = use_enriched |
130 | 146 | st.info("Lang2SQL이 성공적으로 시작되었습니다.") |
131 | 147 |
|
132 | 148 | # 새로고침 버튼 추가 |
133 | 149 | if st.sidebar.button("Lang2SQL 새로고침"): |
134 | | - st.session_state["graph"] = builder.compile() |
| 150 | + graph_builder = ( |
| 151 | + enriched_builder if st.session_state.get("use_enriched") else builder |
| 152 | + ) |
| 153 | + st.session_state["graph"] = graph_builder.compile() |
135 | 154 | st.sidebar.success("Lang2SQL이 성공적으로 새로고침되었습니다.") |
136 | 155 |
|
137 | 156 | user_query = st.text_area( |
|
0 commit comments