Skip to content

Commit 58ce2a3

Browse files
committed
feat: Add SQL query generation with database guidance
- Implemented query_maker_node_with_db_guide to generate SQL queries using user input and database context. - Introduced SQLResult model for structured output of SQL queries and explanations. - Updated graph structure to replace the previous query_maker_node with the new functionality.
1 parent ff9b754 commit 58ce2a3

File tree

2 files changed

+35
-2
lines changed

2 files changed

+35
-2
lines changed

llm_utils/graph.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
from typing_extensions import TypedDict, Annotated
55
from langgraph.graph import END, StateGraph
66
from langgraph.graph.message import add_messages
7+
from langchain.chains.sql_database.prompt import SQL_PROMPTS
8+
from pydantic import BaseModel, Field
9+
from .llm_factory import get_llm
710

811
from llm_utils.chains import (
912
query_refiner_chain,
@@ -102,14 +105,44 @@ def query_maker_node(state: QueryMakerState):
102105
return state
103106

104107

108+
class SQLResult(BaseModel):
109+
sql: str = Field(description="SQL 쿼리 문자열")
110+
explanation: str = Field(description="SQL 쿼리 설명")
111+
112+
113+
def query_maker_node_with_db_guide(state: QueryMakerState):
114+
sql_prompt = SQL_PROMPTS[state["user_database_env"]]
115+
llm = get_llm(
116+
model_type="openai",
117+
model_name="gpt-4o-mini",
118+
openai_api_key=os.getenv("OPENAI_API_KEY"),
119+
)
120+
chain = sql_prompt | llm.with_structured_output(SQLResult)
121+
res = chain.invoke(
122+
input={
123+
"input": "\n\n---\n\n".join(
124+
[state["messages"][0].content] + [state["refined_input"].content]
125+
),
126+
"table_info": [json.dumps(state["searched_tables"])],
127+
"top_k": 10,
128+
}
129+
)
130+
state["generated_query"] = res.sql
131+
state["messages"].append(res.explanation)
132+
return state
133+
134+
105135
# StateGraph 생성 및 구성
106136
builder = StateGraph(QueryMakerState)
107137
builder.set_entry_point(QUERY_REFINER)
108138

109139
# 노드 추가
110140
builder.add_node(QUERY_REFINER, query_refiner_node)
111141
builder.add_node(GET_TABLE_INFO, get_table_info_node)
112-
builder.add_node(QUERY_MAKER, query_maker_node)
142+
# builder.add_node(QUERY_MAKER, query_maker_node) # query_maker_node_with_db_guide
143+
builder.add_node(
144+
QUERY_MAKER, query_maker_node_with_db_guide
145+
) # query_maker_node_with_db_guide
113146

114147
# 기본 엣지 설정
115148
builder.add_edge(QUERY_REFINER, GET_TABLE_INFO)

llm_utils/llm_factory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def get_llm(
1717
if model_type == "openai":
1818
return ChatOpenAI(
1919
model=model_name,
20-
openai_api_key=openai_api_key,
20+
api_key=openai_api_key,
2121
**kwargs,
2222
)
2323

0 commit comments

Comments
 (0)