Skip to content

Commit 56ae3c8

Browse files
authored
Merge pull request #147 from #141, #2
close #2
2 parents 94e7375 + 827db80 commit 56ae3c8

File tree

10 files changed

+407
-13
lines changed

10 files changed

+407
-13
lines changed

interface/lang2sql.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414
from db_utils import get_db_connector
1515
from db_utils.base_connector import BaseConnector
16-
from infra.db.connect_db import ConnectDB
1716
from viz.display_chart import DisplayChart
1817
from engine.query_executor import execute_query as execute_query_common
1918
from llm_utils.llm_response_parser import LLMResponseParser
@@ -30,6 +29,8 @@
3029
"show_sql": "Show SQL",
3130
"show_question_reinterpreted_by_ai": "Show User Question Reinterpreted by AI",
3231
"show_referenced_tables": "Show List of Referenced Tables",
32+
"show_question_gate_result": "Show Question Gate Result",
33+
"show_document_suitability": "Show Document Suitability",
3334
"show_table": "Show Table",
3435
"show_chart": "Show Chart",
3536
}
@@ -103,8 +104,55 @@ def should_show(_key: str) -> bool:
103104
show_sql_section = has_query and should_show("show_sql")
104105
show_result_desc = has_query and should_show("show_result_description")
105106
show_reinterpreted = has_query and should_show("show_question_reinterpreted_by_ai")
107+
show_gate_result = should_show("show_question_gate_result")
108+
show_doc_suitability = should_show("show_document_suitability")
106109
show_table_section = has_query and should_show("show_table")
107110
show_chart_section = has_query and should_show("show_chart")
111+
if show_gate_result and ("question_gate_result" in res):
112+
st.markdown("---")
113+
st.markdown("**Question Gate 결과:**")
114+
details = res.get("question_gate_result")
115+
if details:
116+
try:
117+
import json as _json
118+
119+
st.code(
120+
_json.dumps(details, ensure_ascii=False, indent=2), language="json"
121+
)
122+
except Exception:
123+
st.write(details)
124+
125+
if show_doc_suitability and ("document_suitability" in res):
126+
st.markdown("---")
127+
st.markdown("**문서 적합성 평가:**")
128+
ds = res.get("document_suitability")
129+
if not isinstance(ds, dict):
130+
st.write(ds)
131+
else:
132+
133+
def _as_float(value):
134+
try:
135+
return float(value)
136+
except Exception:
137+
return -1.0
138+
139+
rows = [
140+
{
141+
"table": table_name,
142+
"score": _as_float(info.get("score", -1)),
143+
"matched_columns": ", ".join(info.get("matched_columns", [])),
144+
"missing_entities": ", ".join(info.get("missing_entities", [])),
145+
"reason": info.get("reason", ""),
146+
}
147+
for table_name, info in ds.items()
148+
if isinstance(info, dict)
149+
]
150+
151+
rows.sort(key=lambda r: r["score"], reverse=True)
152+
if rows:
153+
st.dataframe(rows, use_container_width=True)
154+
else:
155+
st.info("문서 적합성 평가 결과가 비어 있습니다.")
108156

109157
if should_show("show_token_usage"):
110158
st.markdown("---")

llm_utils/chains.py

Lines changed: 88 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,23 @@
1+
"""
2+
LLM 체인 생성 모듈.
3+
4+
이 모듈은 Lang2SQL에서 사용하는 다양한 LangChain 기반 체인을 정의합니다.
5+
- Query Maker
6+
- Query Enrichment
7+
- Profile Extraction
8+
- Question Gate (SQL 적합성 분류)
9+
"""
10+
111
import os
212
from langchain_core.prompts import (
313
ChatPromptTemplate,
4-
MessagesPlaceholder,
514
SystemMessagePromptTemplate,
615
)
716
from pydantic import BaseModel, Field
17+
from llm_utils.output_parser.question_suitability import QuestionSuitability
18+
from llm_utils.output_parser.document_suitability import (
19+
DocumentSuitabilityList,
20+
)
821

922
from llm_utils.llm import get_llm
1023

@@ -15,6 +28,12 @@
1528

1629

1730
class QuestionProfile(BaseModel):
31+
"""
32+
자연어 질문의 특징을 구조화해 표현하는 프로파일 모델.
33+
34+
이 프로파일은 이후 컨텍스트 보강 및 SQL 생성 시 힌트로 사용됩니다.
35+
"""
36+
1837
is_timeseries: bool = Field(description="시계열 분석 필요 여부")
1938
is_aggregation: bool = Field(description="집계 함수 필요 여부")
2039
has_filter: bool = Field(description="조건 필터 필요 여부")
@@ -26,6 +45,15 @@ class QuestionProfile(BaseModel):
2645

2746
# QueryMakerChain
2847
def create_query_maker_chain(llm):
48+
"""
49+
SQL 쿼리 생성을 위한 체인을 생성합니다.
50+
51+
Args:
52+
llm: LangChain 호환 LLM 인스턴스
53+
54+
Returns:
55+
Runnable: 입력 프롬프트를 받아 SQL을 생성하는 체인
56+
"""
2957
prompt = get_prompt_template("query_maker_prompt")
3058
query_maker_prompt = ChatPromptTemplate.from_messages(
3159
[
@@ -36,6 +64,15 @@ def create_query_maker_chain(llm):
3664

3765

3866
def create_query_enrichment_chain(llm):
67+
"""
68+
사용자 질문을 메타데이터로 보강하기 위한 체인을 생성합니다.
69+
70+
Args:
71+
llm: LangChain 호환 LLM 인스턴스
72+
73+
Returns:
74+
Runnable: 보강된 질문 텍스트를 반환하는 체인
75+
"""
3976
prompt = get_prompt_template("query_enrichment_prompt")
4077

4178
enrichment_prompt = ChatPromptTemplate.from_messages(
@@ -49,6 +86,15 @@ def create_query_enrichment_chain(llm):
4986

5087

5188
def create_profile_extraction_chain(llm):
89+
"""
90+
질문으로부터 `QuestionProfile`을 추출하는 체인을 생성합니다.
91+
92+
Args:
93+
llm: LangChain 호환 LLM 인스턴스
94+
95+
Returns:
96+
Runnable: `QuestionProfile` 구조화 출력을 반환하는 체인
97+
"""
5298
prompt = get_prompt_template("profile_extraction_prompt")
5399

54100
profile_prompt = ChatPromptTemplate.from_messages(
@@ -61,9 +107,47 @@ def create_profile_extraction_chain(llm):
61107
return chain
62108

63109

110+
def create_question_gate_chain(llm):
    """
    Build the Question Gate (question suitability) chain.

    The chain is a system-message-only ChatPromptTemplate piped into the LLM
    configured for `QuestionSuitability` structured output.

    Args:
        llm: LangChain-compatible LLM instance

    Returns:
        Runnable: invoke({"question": str}) -> QuestionSuitability
    """

    system_message = SystemMessagePromptTemplate.from_template(
        get_prompt_template("question_gate_prompt")
    )
    gate_prompt = ChatPromptTemplate.from_messages([system_message])
    structured_llm = llm.with_structured_output(QuestionSuitability)
    return gate_prompt | structured_llm
129+
130+
131+
def create_document_suitability_chain(llm):
    """
    Build the document suitability evaluation chain.

    The chain takes the question and the retrieved tables as input and
    returns per-table suitability results as `DocumentSuitabilityList`
    structured output.

    Args:
        llm: LangChain-compatible LLM instance

    Returns:
        Runnable: invoke({"question": str, "tables": dict}) -> {"results": DocumentSuitability[]}
    """

    system_message = SystemMessagePromptTemplate.from_template(
        get_prompt_template("document_suitability_prompt")
    )
    doc_prompt = ChatPromptTemplate.from_messages([system_message])
    structured_llm = llm.with_structured_output(DocumentSuitabilityList)
    return doc_prompt | structured_llm
147+
148+
64149
query_maker_chain = create_query_maker_chain(llm)
65150
profile_extraction_chain = create_profile_extraction_chain(llm)
66151
query_enrichment_chain = create_query_enrichment_chain(llm)
67-
68-
if __name__ == "__main__":
69-
pass
152+
question_gate_chain = create_question_gate_chain(llm)
153+
document_suitability_chain = create_document_suitability_chain(llm)

llm_utils/graph_utils/base.py

Lines changed: 95 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
1-
import os
21
import json
32

43
from typing_extensions import TypedDict, Annotated
5-
from langgraph.graph import END, StateGraph
64
from langgraph.graph.message import add_messages
75

86

97
from llm_utils.chains import (
108
query_maker_chain,
119
profile_extraction_chain,
1210
query_enrichment_chain,
11+
question_gate_chain,
12+
document_suitability_chain,
1313
)
1414

15-
from llm_utils.tools import get_info_from_db
1615
from llm_utils.retrieval import search_tables
17-
from llm_utils.graph_utils.profile_utils import profile_to_text
1816

1917
# 노드 식별자 정의
18+
QUESTION_GATE = "question_gate"
19+
EVALUATE_DOCUMENT_SUITABILITY = "evaluate_document_suitability"
2020
GET_TABLE_INFO = "get_table_info"
2121
TOOL = "tool"
2222
TABLE_FILTER = "table_filter"
@@ -30,12 +30,39 @@ class QueryMakerState(TypedDict):
3030
messages: Annotated[list, add_messages]
3131
user_database_env: str
3232
searched_tables: dict[str, dict[str, str]]
33+
document_suitability: dict
3334
best_practice_query: str
3435
question_profile: dict
3536
generated_query: str
3637
retriever_name: str
3738
top_n: int
3839
device: str
40+
question_gate_result: dict
41+
42+
43+
# 노드 함수: QUESTION_GATE 노드
44+
def question_gate_node(state: QueryMakerState):
    """
    Gate node: run the question through `question_gate_chain` and record
    the structured verdict on the graph state.

    The first message's content is used as the question. The fields
    `reason`, `missing_entities` and `requires_data_science` are copied from
    the chain result into `state["question_gate_result"]`; a field missing
    from the result falls back to "", [] and False respectively.

    Args:
        state (QueryMakerState): graph state

    Returns:
        QueryMakerState: the same state with `question_gate_result` set
    """

    verdict = question_gate_chain.invoke({"question": state["messages"][0].content})

    # (field name, fallback) pairs, in the order they appear in the result dict.
    field_defaults = (
        ("reason", ""),
        ("missing_entities", []),
        ("requires_data_science", False),
    )
    state["question_gate_result"] = {
        field: getattr(verdict, field, fallback) for field, fallback in field_defaults
    }
    return state
3966

4067

4168
# 노드 함수: PROFILE_EXTRACTION 노드
@@ -132,6 +159,70 @@ def get_table_info_node(state: QueryMakerState):
132159
return state
133160

134161

162+
# 노드 함수: DOCUMENT_SUITABILITY 노드
163+
def document_suitability_node(state: QueryMakerState):
    """
    Score the table candidates gathered by GET_TABLE_INFO against the question.

    Invokes `document_suitability_chain` with the question
    (`messages[0].content`) and `searched_tables`, then stores a
    `{table_name: {"score", "reason", "matched_columns", "missing_entities"}}`
    mapping under the `document_suitability` state key. Entries without a
    table name are dropped.

    Args:
        state (QueryMakerState): graph state

    Returns:
        QueryMakerState: the same state with `document_suitability` set
    """

    # No candidate tables: skip the LLM call and record an empty result.
    if not state.get("searched_tables"):
        state["document_suitability"] = {}
        return state

    res = document_suitability_chain.invoke(
        {
            "question": state["messages"][0].content,
            "tables": state["searched_tables"],
        }
    )

    normalized = {}
    for item in _suitability_items(res):
        entry = _suitability_entry(item)
        table_name = entry.get("table_name")
        if not table_name:
            continue
        normalized[table_name] = {
            "score": _suitability_score(entry.get("score", 0)),
            # `or` guards against explicit None values so downstream code
            # can join/iterate these fields without type checks.
            "reason": entry.get("reason") or "",
            "matched_columns": entry.get("matched_columns") or [],
            "missing_entities": entry.get("missing_entities") or [],
        }

    state["document_suitability"] = normalized

    return state


def _suitability_items(res):
    """Return the per-table result list from a chain response (dict, model object, or None)."""
    if isinstance(res, dict):
        items = res.get("results")
    else:
        items = getattr(res, "results", None)
        if not items and hasattr(res, "model_dump"):
            items = res.model_dump().get("results")
    # A missing or None "results" value is treated as "no results".
    return items or []


def _suitability_entry(item):
    """Convert one result item (model object, dict, or attribute-only object) to a dict."""
    if hasattr(item, "model_dump"):
        return item.model_dump()
    if isinstance(item, dict):
        return item
    return {
        "table_name": getattr(item, "table_name", ""),
        "score": getattr(item, "score", 0),
        "reason": getattr(item, "reason", ""),
        "matched_columns": getattr(item, "matched_columns", []),
        "missing_entities": getattr(item, "missing_entities", []),
    }


def _suitability_score(value):
    """Convert a raw score to float, falling back to 0.0 when it is not numeric."""
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        return 0.0
224+
225+
135226
# 노드 함수: QUERY_MAKER 노드
136227
def query_maker_node(state: QueryMakerState):
137228
# 사용자 원 질문 + (있다면) 컨텍스트 보강 결과를 하나의 문자열로 결합

llm_utils/graph_utils/basic_graph.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,13 @@
33
from langgraph.graph import StateGraph, END
44
from llm_utils.graph_utils.base import (
55
QueryMakerState,
6+
QUESTION_GATE,
67
GET_TABLE_INFO,
8+
EVALUATE_DOCUMENT_SUITABILITY,
79
QUERY_MAKER,
10+
question_gate_node,
811
get_table_info_node,
12+
document_suitability_node,
913
query_maker_node,
1014
)
1115

@@ -16,14 +20,31 @@
1620

1721
# StateGraph 생성 및 구성
1822
builder = StateGraph(QueryMakerState)
19-
builder.set_entry_point(GET_TABLE_INFO)
23+
builder.set_entry_point(QUESTION_GATE)
2024

2125
# 노드 추가
26+
builder.add_node(QUESTION_GATE, question_gate_node)
2227
builder.add_node(GET_TABLE_INFO, get_table_info_node)
28+
builder.add_node(EVALUATE_DOCUMENT_SUITABILITY, document_suitability_node)
2329
builder.add_node(QUERY_MAKER, query_maker_node)
2430

31+
32+
def _route_after_gate(state: QueryMakerState):
    """
    Choose the node that runs after QUESTION_GATE.

    `state` is not inspected: every question is forwarded to GET_TABLE_INFO.

    NOTE(review): the conditional-edge mapping also lists END, but this
    function never returns it, so that branch is currently unreachable.

    Args:
        state (QueryMakerState): graph state (unused)

    Returns:
        str: the GET_TABLE_INFO node identifier
    """
    return GET_TABLE_INFO
34+
35+
36+
builder.add_conditional_edges(
37+
QUESTION_GATE,
38+
_route_after_gate,
39+
{
40+
GET_TABLE_INFO: GET_TABLE_INFO,
41+
END: END,
42+
},
43+
)
44+
2545
# 기본 엣지 설정
26-
builder.add_edge(GET_TABLE_INFO, QUERY_MAKER)
46+
builder.add_edge(GET_TABLE_INFO, EVALUATE_DOCUMENT_SUITABILITY)
47+
builder.add_edge(EVALUATE_DOCUMENT_SUITABILITY, QUERY_MAKER)
2748

2849
# QUERY_MAKER 노드 후 종료
2950
builder.add_edge(QUERY_MAKER, END)

0 commit comments

Comments
 (0)