Commit a8e853d

feat: Integrate MLflow for SQL evaluation and enhance CLI functionality
- Add MLflow dependency for tracking evaluation metrics
- Implement evaluation command in CLI for SQL generation model
- Create QADataset class for loading and managing evaluation datasets
- Develop Evaluator class to handle SQL evaluation and logging to MLflow
- Add documentation for MLflow setup and evaluation CLI usage
- Introduce example dataset for testing evaluation functionality #15
1 parent 2c11d01 commit a8e853d

File tree

12 files changed: +283 -2 lines changed


.gitignore

Lines changed: 3 additions & 1 deletion

```diff
@@ -3,4 +3,6 @@ __pycache__/
 build/
 lang2sql.egg-info/
 dist/
-.pypirc
+.pypirc
+mlruns
+table_info_db
```

cli/__init__.py

Lines changed: 38 additions & 1 deletion

```diff
@@ -1,6 +1,10 @@
 import click
 import subprocess
 from llm_utils.tools import set_gms_server
+from langchain_core.messages import HumanMessage
+from llm_evaluation.evaluator import Evaluator
+from llm_utils.graph import builder
+import os
 
 
 @click.group()
@@ -11,12 +15,17 @@
 )
 @click.option("--run-streamlit", is_flag=True, help="Run the Streamlit app.")
 @click.option("-p", "--port", type=int, default=8501, help="Streamlit port")
-def cli(ctx, datahub_server, run_streamlit, port):
+@click.option("--mlflow-tracking-uri", default=None, help="MLflow tracking server URI")
+def cli(ctx, datahub_server, run_streamlit, port, mlflow_tracking_uri):
     try:
         set_gms_server(datahub_server)
     except ValueError as e:
         click.echo(str(e))
         ctx.exit(1)
+
+    if mlflow_tracking_uri:
+        os.environ["MLFLOW_TRACKING_URI"] = mlflow_tracking_uri
+
     if run_streamlit:
         run_streamlit_command(port)
 
@@ -33,3 +42,31 @@ def run_streamlit_command(port):
 def run_streamlit(port):
     """Run the Streamlit app."""
     run_streamlit_command(port)
+
+
+@cli.command()
+@click.argument("dataset_path", type=click.Path(exists=True))
+@click.option(
+    "--user-database-env", default="clickhouse", help="User database environment"
+)
+def evaluate(dataset_path, user_database_env):
+    """Evaluate the SQL generation model."""
+    click.echo(f"Starting evaluation with dataset {dataset_path}...")
+
+    evaluator = Evaluator(dataset_path)
+
+    def generated_sql_fn(question: str):
+        graph = builder.compile()
+
+        res = graph.invoke(
+            input={
+                "messages": [HumanMessage(content=question)],
+                "user_database_env": user_database_env,
+                "best_practice_query": "",
+            }
+        )
+
+        return res["generated_query"].content
+
+    results = evaluator.evaluate(generated_sql_fn)
+    click.echo(f"Evaluation complete! {len(results)} queries evaluated")
```
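
Putting the group-level option and the new subcommand together, an end-to-end invocation might look like the sketch below (the tracking URI is a placeholder; the dataset path assumes the repository root as the working directory):

```
lang2sql --mlflow-tracking-uri http://localhost:5000 evaluate llm_evaluation/example.json --user-database-env clickhouse
```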

docs/evaluation.md

Lines changed: 24 additions & 0 deletions

````markdown
# Evaluation

## MLflow

### Running the server

- `docker run --name mlflow_postgres -e POSTGRES_USER=postgres -e POSTGRES_PASSWORD=postgres -e POSTGRES_DB=mlflow_db -p 5430:5432 -d postgres`
- `mlflow server --backend-store-uri postgresql://postgres:postgres@localhost:5430/mlflow_db --default-artifact-root ./mlruns --host 0.0.0.0 --port 5000`
- `backend-store-uri`: the URI of the database where MLflow stores its metadata. For example, with a PostgreSQL database, use the form `postgresql://<username>:<password>@<hostname>:<port>/<database_name>`.
- `default-artifact-root`: the default path where MLflow stores model artifacts. For example, on the local file system this can be a path such as `./mlruns`.

## Running an evaluation

### CLI command

The following CLI command evaluates the SQL generation model:

```
lang2sql evaluate /path/to/dataset.json --user-database-env clickhouse
```

### Example result

![MLflow](./mlflow.png)
````
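
Since `Evaluator` reads `MLFLOW_TRACKING_URI` from the environment and the LLM judge expects `OPENAI_API_KEY`, a minimal environment sketch (placeholder values) is:

```
export MLFLOW_TRACKING_URI=http://localhost:5000
export OPENAI_API_KEY=sk-...
```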

docs/mlflow.png

86.9 KB

llm_evaluation/__init__.py

Whitespace-only changes.

llm_evaluation/dataset.py

Lines changed: 45 additions & 0 deletions

```python
import json
import pandas as pd
import os


class QADataset:
    def __init__(self, dataset_path):
        self.dataset_path = dataset_path
        self.data = self._load_dataset()
        self.results_path = dataset_path.replace(".json", "_results.json")

    def _load_dataset(self):
        """Load the question-answer dataset from a JSON file."""
        with open(self.dataset_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        if isinstance(data, list):
            return pd.DataFrame(data)
        elif isinstance(data, dict):
            return pd.DataFrame.from_dict(data)
        else:
            raise ValueError(
                "Unsupported JSON format. The dataset must be a list or a dictionary."
            )

    def get_samples(self):
        """Yield the question, ground-truth SQL, and evaluation_type for each sample."""
        if "evaluation_type" in self.data.columns:
            eval_types = self.data["evaluation_type"].tolist()
        else:
            eval_types = [None] * len(self.data["inputs"])

        for i in range(len(self.data["inputs"])):
            yield self.data["inputs"][i], self.data["ground_truths"][i], eval_types[i]

    def save_feedback(self, feedback_data):
        """Append an evaluation result to a separate results file."""
        results = []
        if os.path.exists(self.results_path):
            with open(self.results_path, "r", encoding="utf-8") as f:
                results = json.load(f)

        results.append(feedback_data)

        with open(self.results_path, "w", encoding="utf-8") as f:
            json.dump(results, f, ensure_ascii=False, indent=4)
```
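
As a rough usage sketch (not part of the commit), `QADataset` can be exercised directly against the bundled example file; the path assumes the repository root as the working directory:

```python
from llm_evaluation.dataset import QADataset

# Load the bundled example dataset (a dict of parallel lists becomes a DataFrame)
dataset = QADataset("llm_evaluation/example.json")

# Iterate over (question, ground-truth SQL, evaluation_type) triples
for question, ground_truth_sql, evaluation_type in dataset.get_samples():
    print(evaluation_type, question)

# Evaluation results are appended to llm_evaluation/example_results.json
dataset.save_feedback({"question": "demo question", "llm_evaluation_metric": 1.0})
```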

llm_evaluation/evaluator.py

Lines changed: 87 additions & 0 deletions

```python
import time
import mlflow
import os
import importlib.metadata
from langchain_core.messages import HumanMessage

import numpy as np
from llm_evaluation.dataset import QADataset
from llm_evaluation.llm_evaluator import compare_sql_with_llm
from llm_evaluation.mlflow_logger import log_to_mlflow


class Evaluator:
    def __init__(self, dataset_path):
        self.dataset = QADataset(dataset_path)

    def evaluate(self, generated_sql_fn):
        """Evaluate Lang2SQL (the caller provides the SQL generation function).

        Each evaluation sample is recorded as a nested run.
        """
        results = []
        metrics_by_type = {}  # scores grouped by evaluation_type

        # MLflow setup: configure the tracking URI and the experiment name
        mlflow.set_tracking_uri(os.getenv("MLFLOW_TRACKING_URI"))
        try:
            lang2sql_version = importlib.metadata.version("lang2sql")
        except importlib.metadata.PackageNotFoundError:
            lang2sql_version = "unknown"
        experiment_name = f"lang2sql-evaluation-v{lang2sql_version}"
        mlflow.set_experiment(experiment_name)

        # Wrap the whole evaluation in a single parent run
        with mlflow.start_run(run_name="evaluation_run") as parent_run:
            for (
                question,
                ground_truth_sql,
                evaluation_type,
            ) in self.dataset.get_samples():
                start_time = time.time()
                generated_sql = generated_sql_fn(question)
                exec_time = time.time() - start_time

                # LLM evaluation result (currently a single score)
                llm_score = compare_sql_with_llm(
                    generated_sql, ground_truth_sql, question
                )
                # Collect scores by evaluation_type
                if evaluation_type not in metrics_by_type:
                    metrics_by_type[evaluation_type] = []
                metrics_by_type[evaluation_type].append(llm_score)

                feedback_data = {
                    "question": question,
                    "generated_sql": generated_sql,
                    "ground_truth_sql": ground_truth_sql,
                    "llm_evaluation_metric": llm_score,  # per-query metric
                    "execution_time": exec_time,
                    "evaluation_type": evaluation_type,
                }

                # Record each sample as a nested run (named after its evaluation_type)
                with mlflow.start_run(nested=True, run_name=str(evaluation_type)):
                    log_to_mlflow(
                        question,
                        generated_sql,
                        ground_truth_sql,
                        llm_score,
                        evaluation_type,
                    )
                    mlflow.log_metric("execution_time", exec_time)

                self.dataset.save_feedback(feedback_data)
                results.append(feedback_data)

            # Aggregate metrics per evaluation_type (mean, max, min, median)
            aggregated_metrics = {}
            for eval_type, scores in metrics_by_type.items():
                aggregated_metrics[eval_type] = {
                    "mean": float(np.mean(scores)),
                    "max": float(np.max(scores)),
                    "min": float(np.min(scores)),
                    "median": float(np.median(scores)),
                }
                # Log each per-sample score as a step of a metric named after its type
                for idx, score in enumerate(scores):
                    mlflow.log_metric(f"{eval_type}", score, step=idx)

            # Also record aggregated_metrics as a tag (converted to a string)
            mlflow.set_tag("aggregated_metrics", str(aggregated_metrics))

        return results
```
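
A minimal smoke-test sketch, assuming a reachable MLflow tracking server and an `OPENAI_API_KEY`; the stub generator below stands in for the real Lang2SQL graph used by the CLI:

```python
from llm_evaluation.evaluator import Evaluator


def dummy_sql_fn(question: str) -> str:
    # Stand-in for the graph-based SQL generator
    return "SELECT 1"


evaluator = Evaluator("llm_evaluation/example.json")
results = evaluator.evaluate(dummy_sql_fn)
print(f"{len(results)} queries evaluated")
```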

llm_evaluation/example.json

Lines changed: 13 additions & 0 deletions

```json
{
    "evaluation_type": ["subscription", "hired_sdr", "subscription"],
    "inputs": [
        "구독서비스를 이용하는 유니크한 유저수",
        "회사의 고용된 영업 SDR 수가 궁금해",
        "subscription을 시작한 무료 사용자(free_users) 중 각 플랜별 사용자 수를 알고 싶습니다."
    ],
    "ground_truths": [
        "SELECT toStartOfMonth(activity_ts) AS month, COUNT(DISTINCT entity_id) AS unique_users FROM client_stream_active_on_subscription WHERE activity_ts >= date_sub(current_date(), interval 3 month) GROUP BY month ORDER BY month",
        "SELECT COUNT(DISTINCT entity_id) AS total_hired_sdrs FROM company_stream_hired_sdr WHERE activity_ts >= now() - INTERVAL 3 MONTH",
        "SELECT feature_json, COUNT(entity_id) AS user_count FROM client_stream_started_subscription WHERE feature_json LIKE '%\"free_users\"%' GROUP BY feature_json"
    ]
}
```

llm_evaluation/llm_evaluator.py

Lines changed: 50 additions & 0 deletions

```python
import os
from langchain_core.prompts import ChatPromptTemplate
from llm_utils.llm_factory import get_llm


def compare_sql_with_llm(generated_sql, ground_truth_sql, user_query):
    """Evaluate SQL with an LLM (score from 0 to 1)."""

    # Initialize the LLM
    llm = get_llm(
        model_type="openai",
        model_name="gpt-4o-mini",
        openai_api_key=os.getenv("OPENAI_API_KEY"),
    )

    # Build the prompt template; the question and queries are filled in as
    # template variables at invoke time, so braces inside the SQL cannot
    # break the template
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """
You are an SQL expert. Compare the correctness and similarity of the two SQL queries below.
Rate their similarity on a scale from 0 (completely different) to 1 (identical).

Input description (what the SQL should do):
{user_query}

Ground-truth SQL:
{ground_truth_sql}

Generated SQL:
{generated_sql}

Consider both the similarity to the ground truth and whether the generated SQL correctly handles the input description.
Return only a similarity score between 0 and 1, rounded to two decimal places (e.g. 0.75, 0.42, 1.00).
Do not include any explanation or extra text.
""",
            )
        ]
    )

    # Run the LLM chain
    chain = prompt | llm
    response = chain.invoke(
        {
            "user_query": user_query,
            "ground_truth_sql": ground_truth_sql,
            "generated_sql": generated_sql,
        }
    )

    try:
        score = float(response.content.strip())
        return max(0.0, min(score, 1.0))  # clamp to the range [0, 1]
    except ValueError:
        return 0.0  # return 0 if the score cannot be parsed
```
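
An illustrative direct call to the judge (requires `OPENAI_API_KEY`; the queries and question are made-up examples):

```python
from llm_evaluation.llm_evaluator import compare_sql_with_llm

score = compare_sql_with_llm(
    generated_sql="SELECT COUNT(DISTINCT entity_id) FROM client_stream_active_on_subscription",
    ground_truth_sql="SELECT COUNT(DISTINCT entity_id) AS unique_users FROM client_stream_active_on_subscription",
    user_query="Unique number of users on the subscription service",
)
print(score)  # a float between 0.0 and 1.0
```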

llm_evaluation/mlflow_logger.py

Lines changed: 21 additions & 0 deletions

```python
import os
import importlib.metadata
import mlflow
from dotenv import load_dotenv

load_dotenv()


def log_to_mlflow(
    question, generated_sql, ground_truth_sql, llm_metric, evaluation_type=None
):
    """Log evaluation results inside the currently active run.

    llm_metric may be a dict; mlflow.log_param stores values internally as strings.
    """
    mlflow.log_param("question", question)
    mlflow.log_param("generated_sql", generated_sql)
    mlflow.log_param("ground_truth_sql", ground_truth_sql)
    mlflow.log_param("llm_evaluation_metric", llm_metric)
    if evaluation_type is not None:
        mlflow.log_param("evaluation_type", evaluation_type)
```
