Skip to content

Commit 4e84565

Browse files
committed
그래프 빌더 및 Lang2SQL UI 업데이트 #134
1 parent bbec798 commit 4e84565

File tree

2 files changed

+73
-5
lines changed

2 files changed

+73
-5
lines changed

interface/graph_builder.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,50 @@ def render_sequence(sequence: List[str]) -> str:
128128
# 프리셋에서는 QUERY_MAKER 자동 포함
129129
use_query_maker = True
130130

131+
# GET_TABLE_INFO 설정
132+
st.subheader("GET_TABLE_INFO 설정")
133+
_prev_cfg = st.session_state.get("graph_config", {})
134+
135+
_retriever_options = {
136+
"기본": "벡터 검색 (기본)",
137+
"Reranker": "Reranker 검색 (정확도 향상)",
138+
}
139+
_retriever_keys = list(_retriever_options.keys())
140+
_retriever_default = _prev_cfg.get("retriever_name", "기본")
141+
_retriever_index = (
142+
_retriever_keys.index(_retriever_default)
143+
if _retriever_default in _retriever_keys
144+
else 0
145+
)
146+
147+
retriever_name = st.selectbox(
148+
"테이블 검색기",
149+
options=_retriever_keys,
150+
format_func=lambda x: _retriever_options[x],
151+
index=_retriever_index,
152+
)
153+
154+
top_n = st.slider(
155+
"검색할 테이블 정보 개수",
156+
min_value=1,
157+
max_value=20,
158+
value=int(_prev_cfg.get("top_n", 5)),
159+
step=1,
160+
)
161+
162+
_device_options = ["cpu", "cuda"]
163+
_device_default = _prev_cfg.get("device", "cpu")
164+
_device_index = (
165+
_device_options.index(_device_default)
166+
if _device_default in _device_options
167+
else 0
168+
)
169+
device = st.selectbox(
170+
"모델 실행 장치",
171+
options=_device_options,
172+
index=_device_index,
173+
)
174+
131175

132176
def build_sequence_with_qm(
133177
preset: str, use_profile: bool, use_context: bool, use_qm: bool
@@ -166,6 +210,9 @@ def build_sequence_with_qm(
166210
"use_profile": use_profile,
167211
"use_context": use_context,
168212
"use_query_maker": use_query_maker,
213+
"retriever_name": retriever_name,
214+
"top_n": top_n,
215+
"device": device,
169216
}
170217

171218
# 선택이 바뀌면 자동으로 세션 그래프 갱신
@@ -174,13 +221,20 @@ def build_sequence_with_qm(
174221
_builder = build_state_graph(sequence)
175222
st.session_state["graph"] = _builder.compile()
176223
st.session_state["graph_config"] = config
224+
# Lang2SQL 메인 UI에서 기본값으로 사용할 옵션 전달
225+
st.session_state["default_retriever_name"] = retriever_name
226+
st.session_state["default_top_n"] = top_n
227+
st.session_state["default_device"] = device
177228
st.info("그래프가 세션에 적용되었습니다.")
178229

179230
# 수동 새로고침 버튼
180231
if st.button("세션 그래프 새로고침"):
181232
_builder = build_state_graph(sequence)
182233
st.session_state["graph"] = _builder.compile()
183234
st.session_state["graph_config"] = config
235+
st.session_state["default_retriever_name"] = retriever_name
236+
st.session_state["default_top_n"] = top_n
237+
st.session_state["default_device"] = device
184238
st.success("세션 그래프가 새로고침되었습니다.")
185239

186240
with st.expander("현재 세션 그래프 설정"):

interface/lang2sql.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -302,29 +302,43 @@ def should_show(_key: str) -> bool:
302302
index=0,
303303
)
304304

305+
_device_options = ["cpu", "cuda"]
306+
_default_device = st.session_state.get("default_device", "cpu")
307+
_device_index = (
308+
_device_options.index(_default_device)
309+
if _default_device in _device_options
310+
else 0
311+
)
305312
device = st.selectbox(
306313
"모델 실행 장치를 선택하세요:",
307-
options=["cpu", "cuda"],
308-
index=0,
314+
options=_device_options,
315+
index=_device_index,
309316
)
310317

311318
retriever_options = {
312319
"기본": "벡터 검색 (기본)",
313320
"Reranker": "Reranker 검색 (정확도 향상)",
314321
}
315322

323+
_retriever_keys = list(retriever_options.keys())
324+
_default_retriever = st.session_state.get("default_retriever_name", "기본")
325+
_retriever_index = (
326+
_retriever_keys.index(_default_retriever)
327+
if _default_retriever in _retriever_keys
328+
else 0
329+
)
316330
user_retriever = st.selectbox(
317331
"검색기 유형을 선택하세요:",
318-
options=list(retriever_options.keys()),
332+
options=_retriever_keys,
319333
format_func=lambda x: retriever_options[x],
320-
index=0,
334+
index=_retriever_index,
321335
)
322336

323337
user_top_n = st.slider(
324338
"검색할 테이블 정보 개수:",
325339
min_value=1,
326340
max_value=20,
327-
value=5,
341+
value=int(st.session_state.get("default_top_n", 5)),
328342
step=1,
329343
help="검색할 테이블 정보의 개수를 설정합니다. 값이 클수록 더 많은 테이블 정보를 검색하지만 처리 시간이 길어질 수 있습니다.",
330344
)

0 commit comments

Comments
 (0)