Skip to content

Commit 62ee1d4

Browse files
authored
Merge pull request #138 from #134
그래프 빌더 및 Lang2SQL UI 업데이트 #134
2 parents 9f2cf79 + 5881db6 commit 62ee1d4

File tree

3 files changed

+70
-6
lines changed

3 files changed

+70
-6
lines changed

interface/graph_builder.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,48 @@ def render_sequence(sequence: List[str]) -> str:
128128
# 프리셋에서는 QUERY_MAKER 자동 포함
129129
use_query_maker = True
130130

131+
# GET_TABLE_INFO 설정
132+
st.subheader("GET_TABLE_INFO 설정")
133+
_prev_cfg = st.session_state.get("graph_config", {})
134+
135+
_retriever_options = {
136+
"기본": "벡터 검색 (기본)",
137+
"Reranker": "Reranker 검색 (정확도 향상)",
138+
}
139+
_retriever_keys = list(_retriever_options.keys())
140+
_retriever_default = _prev_cfg.get("retriever_name", "기본")
141+
_retriever_index = (
142+
_retriever_keys.index(_retriever_default)
143+
if _retriever_default in _retriever_keys
144+
else 0
145+
)
146+
147+
retriever_name = st.selectbox(
148+
"테이블 검색기",
149+
options=_retriever_keys,
150+
format_func=lambda x: _retriever_options[x],
151+
index=_retriever_index,
152+
)
153+
154+
top_n = st.slider(
155+
"검색할 테이블 정보 개수",
156+
min_value=1,
157+
max_value=20,
158+
value=int(_prev_cfg.get("top_n", 5)),
159+
step=1,
160+
)
161+
162+
_device_options = ["cpu", "cuda"]
163+
_device_default = _prev_cfg.get("device", "cpu")
164+
_device_index = (
165+
_device_options.index(_device_default) if _device_default in _device_options else 0
166+
)
167+
device = st.selectbox(
168+
"모델 실행 장치",
169+
options=_device_options,
170+
index=_device_index,
171+
)
172+
131173

132174
def build_sequence_with_qm(
133175
preset: str, use_profile: bool, use_context: bool, use_qm: bool
@@ -166,6 +208,9 @@ def build_sequence_with_qm(
166208
"use_profile": use_profile,
167209
"use_context": use_context,
168210
"use_query_maker": use_query_maker,
211+
"retriever_name": retriever_name,
212+
"top_n": top_n,
213+
"device": device,
169214
}
170215

171216
# 선택이 바뀌면 자동으로 세션 그래프 갱신
@@ -174,13 +219,20 @@ def build_sequence_with_qm(
174219
_builder = build_state_graph(sequence)
175220
st.session_state["graph"] = _builder.compile()
176221
st.session_state["graph_config"] = config
222+
# Lang2SQL 메인 UI에서 기본값으로 사용할 옵션 전달
223+
st.session_state["default_retriever_name"] = retriever_name
224+
st.session_state["default_top_n"] = top_n
225+
st.session_state["default_device"] = device
177226
st.info("그래프가 세션에 적용되었습니다.")
178227

179228
# 수동 새로고침 버튼
180229
if st.button("세션 그래프 새로고침"):
181230
_builder = build_state_graph(sequence)
182231
st.session_state["graph"] = _builder.compile()
183232
st.session_state["graph_config"] = config
233+
st.session_state["default_retriever_name"] = retriever_name
234+
st.session_state["default_top_n"] = top_n
235+
st.session_state["default_device"] = device
184236
st.success("세션 그래프가 새로고침되었습니다.")
185237

186238
with st.expander("현재 세션 그래프 설정"):

interface/lang2sql.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -302,29 +302,41 @@ def should_show(_key: str) -> bool:
302302
index=0,
303303
)
304304

305+
_device_options = ["cpu", "cuda"]
306+
_default_device = st.session_state.get("default_device", "cpu")
307+
_device_index = (
308+
_device_options.index(_default_device) if _default_device in _device_options else 0
309+
)
305310
device = st.selectbox(
306311
"모델 실행 장치를 선택하세요:",
307-
options=["cpu", "cuda"],
308-
index=0,
312+
options=_device_options,
313+
index=_device_index,
309314
)
310315

311316
retriever_options = {
312317
"기본": "벡터 검색 (기본)",
313318
"Reranker": "Reranker 검색 (정확도 향상)",
314319
}
315320

321+
_retriever_keys = list(retriever_options.keys())
322+
_default_retriever = st.session_state.get("default_retriever_name", "기본")
323+
_retriever_index = (
324+
_retriever_keys.index(_default_retriever)
325+
if _default_retriever in _retriever_keys
326+
else 0
327+
)
316328
user_retriever = st.selectbox(
317329
"검색기 유형을 선택하세요:",
318-
options=list(retriever_options.keys()),
330+
options=_retriever_keys,
319331
format_func=lambda x: retriever_options[x],
320-
index=0,
332+
index=_retriever_index,
321333
)
322334

323335
user_top_n = st.slider(
324336
"검색할 테이블 정보 개수:",
325337
min_value=1,
326338
max_value=20,
327-
value=5,
339+
value=int(st.session_state.get("default_top_n", 5)),
328340
step=1,
329341
help="검색할 테이블 정보의 개수를 설정합니다. 값이 클수록 더 많은 테이블 정보를 검색하지만 처리 시간이 길어질 수 있습니다.",
330342
)

version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,4 @@
1818
- PATCH는 1로 증가합니다.
1919
"""
2020

21-
__version__ = "0.2.1"
21+
__version__ = "0.2.2"

0 commit comments

Comments
 (0)