Skip to content

Commit ae1589e

Browse files
committed
feat: 평가 질문을 lang2sql로 처리하고 저장하는 기능 추가
1 parent 2eb4472 commit ae1589e

File tree

1 file changed

+92
-0
lines changed

1 file changed

+92
-0
lines changed

evaluation/gen_answer.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
from argparse import ArgumentParser
2+
from langchain_core.messages import HumanMessage
3+
4+
from utils import load_question_json, save_answer_json
5+
6+
from tqdm import tqdm
7+
import uuid
8+
9+
from llm_utils.graph import builder
10+
11+
12+
def get_eval_result(
13+
graph,
14+
name=None,
15+
version=None,
16+
desc="",
17+
debug=False,
18+
input_dir="data/questions",
19+
output_dir="data/q_sql",
20+
):
21+
22+
if name is None:
23+
# random name
24+
name = str(uuid.uuid4())
25+
26+
if version is None:
27+
version = "0.0.1"
28+
29+
results = load_question_json(input_dir)
30+
31+
for i, result in tqdm(enumerate(results), desc="Processing results"):
32+
inputs = []
33+
for question in result["questions"]:
34+
inputs.append(
35+
{
36+
"messages": [HumanMessage(content=question)],
37+
"user_database_env": "duckdb",
38+
"best_practice_query": "",
39+
}
40+
)
41+
response = graph.batch(inputs)
42+
answers = []
43+
for res in response:
44+
refined_input_content = (
45+
res["refined_input"].content
46+
if hasattr(res["refined_input"], "content")
47+
else res["refined_input"]
48+
)
49+
answers.append(
50+
{
51+
"user_database_env": res["user_database_env"],
52+
"answer_SQL": res["generated_query"],
53+
"answer_explanation": res["messages"][-1].content,
54+
"question_refined": refined_input_content,
55+
"searched_tables": res["searched_tables"],
56+
}
57+
)
58+
59+
# debug 모드일 때 결과를 print로 확인
60+
if debug:
61+
print(f"질문: {result['questions']}")
62+
print(f"답변: {answers}")
63+
64+
result["answers"] = answers
65+
result["name"] = name
66+
result["version"] = version
67+
result["desc"] = desc
68+
69+
save_answer_json(result, f"{output_dir}/{name}_{version}", i)
70+
71+
72+
if __name__ == "__main__":
73+
parser = ArgumentParser()
74+
parser.add_argument("--input_dir", type=str, default="data/questions")
75+
parser.add_argument("--output_dir", type=str, default="data/q_sql")
76+
parser.add_argument("--name", type=str, default=None)
77+
parser.add_argument("--version", type=str, default=None)
78+
parser.add_argument("--desc", type=str, default="")
79+
parser.add_argument("--debug", type=bool, default=False)
80+
args = parser.parse_args()
81+
82+
graph = builder.compile() # langgraph 모델 load하여 사용하세요
83+
84+
get_eval_result(
85+
graph,
86+
name=args.name,
87+
version=args.version,
88+
desc=args.desc,
89+
input_dir=args.input_dir,
90+
output_dir=args.output_dir,
91+
debug=args.debug,
92+
)

0 commit comments

Comments
 (0)