|
6 | 6 | from langgraph.graph.message import add_messages |
7 | 7 | from langchain.chains.sql_database.prompt import SQL_PROMPTS |
8 | 8 | from pydantic import BaseModel, Field |
| 9 | +from .agent import manager_agent, manager_agent_edge |
9 | 10 | from .llm_factory import get_llm |
10 | | - |
11 | 11 | from llm_utils.chains import ( |
12 | 12 | query_refiner_chain, |
13 | 13 | query_maker_chain, |
14 | 14 | ) |
15 | | - |
16 | | -from llm_utils.tools import get_info_from_db |
17 | 15 | from llm_utils.retrieval import search_tables |
| 16 | +from langchain.schema import AIMessage |
| 17 | +from .state import QueryMakerState |
18 | 18 |
|
19 | 19 | # 노드 식별자 정의 |
20 | 20 | QUERY_REFINER = "query_refiner" |
21 | 21 | GET_TABLE_INFO = "get_table_info" |
22 | 22 | TOOL = "tool" |
23 | 23 | TABLE_FILTER = "table_filter" |
24 | 24 | QUERY_MAKER = "query_maker" |
| 25 | +MANAGER_AGENT = "manager_agent" |
| 26 | +EXCEPTION_END_NODE = "exception_end_node" |
25 | 27 |
|
26 | 28 |
|
27 | | -# 상태 타입 정의 (추가 상태 정보와 메시지들을 포함) |
28 | | -class QueryMakerState(TypedDict): |
29 | | - messages: Annotated[list, add_messages] |
30 | | - user_database_env: str |
31 | | - searched_tables: dict[str, dict[str, str]] |
32 | | - best_practice_query: str |
33 | | - refined_input: str |
34 | | - generated_query: str |
35 | | - retriever_name: str |
36 | | - top_n: int |
37 | | - device: str |
| 29 | +def exception_end_node(state: QueryMakerState): |
| 30 | + intent_reason = state.get("intent_reason", "SQL 쿼리 생성을 위한 질문을 해주세요") |
| 31 | + end_message_prompt = f""" |
| 32 | +다음과 같은 이유로 답변을 할 수 없습니다! |
| 33 | +``` |
| 34 | +{intent_reason} |
| 35 | +``` |
| 36 | +""" |
| 37 | + return { |
| 38 | + "messages": state["messages"] + [AIMessage(content=end_message_prompt)], |
| 39 | + } |
| 40 | + |
38 | 41 |
|
39 | 42 |
|
40 | 43 | # 노드 함수: QUERY_REFINER 노드 |
41 | 44 | def query_refiner_node(state: QueryMakerState): |
| 45 | + # refined_node의 결과값으로 바로 AIMessages 반환 |
42 | 46 | res = query_refiner_chain.invoke( |
43 | 47 | input={ |
44 | 48 | "user_input": [state["messages"][0].content], |
@@ -67,6 +71,7 @@ def get_table_info_node(state: QueryMakerState): |
67 | 71 |
|
68 | 72 | # 노드 함수: QUERY_MAKER 노드 |
69 | 73 | def query_maker_node(state: QueryMakerState): |
| 74 | + # sturctured output 사용 |
70 | 75 | res = query_maker_chain.invoke( |
71 | 76 | input={ |
72 | 77 | "user_input": [state["messages"][0].content], |
@@ -105,19 +110,33 @@ def query_maker_node_with_db_guide(state: QueryMakerState): |
105 | 110 |
|
106 | 111 | # StateGraph 생성 및 구성 |
107 | 112 | builder = StateGraph(QueryMakerState) |
108 | | -builder.set_entry_point(GET_TABLE_INFO) |
109 | | - |
110 | 113 | # 노드 추가 |
| 114 | +builder.add_node(MANAGER_AGENT, manager_agent) |
111 | 115 | builder.add_node(GET_TABLE_INFO, get_table_info_node) |
112 | 116 | builder.add_node(QUERY_REFINER, query_refiner_node) |
113 | 117 | builder.add_node(QUERY_MAKER, query_maker_node) # query_maker_node_with_db_guide |
| 118 | +builder.add_node(EXCEPTION_END_NODE, exception_end_node) |
114 | 119 | # builder.add_node( |
115 | 120 | # QUERY_MAKER, query_maker_node_with_db_guide |
116 | 121 | # ) # query_maker_node_with_db_guide |
117 | 122 |
|
118 | 123 | # 기본 엣지 설정 |
| 124 | +builder.set_entry_point(MANAGER_AGENT) |
119 | 125 | builder.add_edge(GET_TABLE_INFO, QUERY_REFINER) |
120 | 126 | builder.add_edge(QUERY_REFINER, QUERY_MAKER) |
121 | 127 |
|
| 128 | +# 조건부 엣지 |
| 129 | +builder.add_conditional_edges( |
| 130 | + MANAGER_AGENT, |
| 131 | + manager_agent_edge, |
| 132 | + { |
| 133 | + "end": EXCEPTION_END_NODE, |
| 134 | + "make_query": GET_TABLE_INFO, |
| 135 | + }, |
| 136 | +) |
| 137 | + |
122 | 138 | # QUERY_MAKER 노드 후 종료 |
123 | 139 | builder.add_edge(QUERY_MAKER, END) |
| 140 | + |
| 141 | +# EXCEPTION_END_NODE 노드 후 종료 |
| 142 | +builder.add_edge(EXCEPTION_END_NODE, END) |
0 commit comments