|
4 | 4 | MessagesPlaceholder, |
5 | 5 | SystemMessagePromptTemplate, |
6 | 6 | ) |
| 7 | +from pydantic import BaseModel, Field |
7 | 8 |
|
8 | 9 | from .llm_factory import get_llm |
9 | 10 |
|
10 | 11 | from dotenv import load_dotenv |
11 | 12 | from prompt.template_loader import get_prompt_template |
12 | 13 |
|
| 14 | + |
13 | 15 | env_path = os.path.join(os.getcwd(), ".env") |
14 | 16 |
|
15 | 17 | if os.path.exists(env_path): |
|
20 | 22 | llm = get_llm() |
21 | 23 |
|
22 | 24 |
|
class QuestionProfile(BaseModel):
    """Structured profile of a user question, produced via LLM structured output.

    Each boolean flags an analytical property that downstream query-building
    steps may act on; ``intent_type`` labels the overall intent of the
    question. The ``Field`` descriptions are Korean runtime strings — they
    are part of the schema shown to the LLM and are intentionally left as-is.
    """

    # Whether the question requires time-series analysis.
    is_timeseries: bool = Field(description="시계열 분석 필요 여부")
    # Whether an aggregate function (SUM/AVG/COUNT, ...) is needed.
    is_aggregation: bool = Field(description="집계 함수 필요 여부")
    # Whether a filtering condition is needed.
    has_filter: bool = Field(description="조건 필터 필요 여부")
    # Whether grouping is needed.
    is_grouped: bool = Field(description="그룹화 필요 여부")
    # Whether sorting / ranking is needed.
    has_ranking: bool = Field(description="정렬/순위 필요 여부")
    # Whether the question compares across time periods.
    has_temporal_comparison: bool = Field(description="기간 비교 포함 여부")
    # Main intent category; the prompt below constrains it to
    # trend / lookup / comparison / distribution.
    intent_type: str = Field(description="질문의 주요 의도 유형")
| 33 | + |
| 34 | + |
23 | 35 | def create_query_refiner_chain(llm): |
24 | 36 | prompt = get_prompt_template("query_refiner_prompt") |
25 | 37 | tool_choice_prompt = ChatPromptTemplate.from_messages( |
@@ -101,6 +113,33 @@ def create_query_refiner_with_profile_chain(llm): |
101 | 113 | return tool_choice_prompt | llm |
102 | 114 |
|
103 | 115 |
|
from langchain.prompts import PromptTemplate

# NOTE(review): mid-file import — consider relocating to the top-level
# import block with the other langchain imports.

# Raw template for question-profile extraction. Written without leading /
# trailing newlines so no .strip() is needed; the rendered string is
# identical to the original. The single {question} placeholder is the only
# input variable.
_PROFILE_TEMPLATE = """You are an assistant that analyzes a user question and extracts the following profiles as JSON:
- is_timeseries (boolean)
- is_aggregation (boolean)
- has_filter (boolean)
- is_grouped (boolean)
- has_ranking (boolean)
- has_temporal_comparison (boolean)
- intent_type (one of: trend, lookup, comparison, distribution)

Return only valid JSON matching the QuestionProfile schema.

Question:
{question}"""

# Prompt used by create_profile_extraction_chain; from_template infers the
# ["question"] input variable, matching the original explicit declaration.
profile_prompt = PromptTemplate.from_template(_PROFILE_TEMPLATE)
| 136 | + |
| 137 | + |
def create_profile_extraction_chain(llm):
    """Build a runnable that turns a user question into a ``QuestionProfile``.

    Pipes the module-level ``profile_prompt`` into *llm* constrained to
    structured output, so invoking the returned chain with
    ``{"question": ...}`` yields a parsed ``QuestionProfile`` instance.
    """
    structured_llm = llm.with_structured_output(QuestionProfile)
    return profile_prompt | structured_llm
| 141 | + |
| 142 | + |
# Module-level chain singletons, built once against the shared `llm` at
# import time so consumers can import them directly.
# NOTE(review): create_profile_extraction_chain is defined above but no
# module-level instance is created here — confirm callers construct it
# themselves.
query_refiner_chain = create_query_refiner_chain(llm)
query_maker_chain = create_query_maker_chain(llm)
query_refiner_with_profile_chain = create_query_refiner_with_profile_chain(llm)
|
0 commit comments