
Commit 98dc147

more tests added
1 parent 8d01bf4 commit 98dc147

File tree

4 files changed: +154 -1 lines changed

pdm.lock

Lines changed: 15 additions & 1 deletion
Generated file; diff not rendered by default.

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -61,6 +61,7 @@ dev = [
     "sphinx-autoapi>=3.6.0",
     "sphinx-autobuild>=2024.10.3",
     "sphinx-copybutton>=0.5.2",
+    "pytest-mock>=3.15.1",
 ]
 vllm = [
     "vllm>=0.4.2",

tests/conftest.py

Lines changed: 43 additions & 0 deletions
@@ -1,6 +1,8 @@
 import json
 import os
+from pathlib import Path
 import sqlite3
+from unittest.mock import MagicMock
 
 import pytest
 
@@ -68,3 +70,44 @@ def fake_jsonl_files(tmp_path):
     tpath.write_text("\n".join(json.dumps(t) for t in tables))
 
     return str(qpath), str(tpath)
+
+
+@pytest.fixture
+def mock_utils(mocker, tmp_path):
+    """Mock all underlying I/O + DB functions."""
+    # load questions
+    mocker.patch(
+        "llmsql.evaluation.evaluate.load_jsonl_dict_by_key",
+        return_value={1: {"question_id": 1, "gold": "SELECT 1"}},
+    )
+
+    # predictions loader
+    mocker.patch(
+        "llmsql.evaluation.evaluate.load_jsonl",
+        return_value=[{"question_id": 1, "completion": "SELECT 1"}],
+    )
+
+    # DB connection
+    fake_conn = MagicMock()
+    mocker.patch("llmsql.evaluation.evaluate.connect_sqlite", return_value=fake_conn)
+
+    # evaluate_sample → always correct prediction
+    mocker.patch(
+        "llmsql.evaluation.evaluate.evaluate_sample",
+        return_value=(1, None, {"pred_none": 0, "gold_none": 0, "sql_error": 0}),
+    )
+
+    # rich logging
+    mocker.patch("llmsql.evaluation.evaluate.log_mismatch")
+    mocker.patch("llmsql.evaluation.evaluate.print_summary")
+
+    # download files
+    mocker.patch(
+        "llmsql.evaluation.evaluate.download_benchmark_file",
+        side_effect=lambda filename, wd: str(Path(wd) / filename),
+    )
+
+    # report writer
+    mocker.patch("llmsql.evaluation.evaluate.save_json_report")
+
+    return tmp_path
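The subtlest line in this fixture is the side_effect lambda on download_benchmark_file: side_effect replaces the mock's call behavior, so each "download" just maps a filename into the working directory instead of touching the network. Sketched in isolation (the filename and path are made-up examples, POSIX paths assumed):

from pathlib import Path
from unittest.mock import MagicMock

# side_effect computes the return value per call; here it mirrors the
# fixture's stub, turning (filename, wd) into a local path under wd.
dl = MagicMock(side_effect=lambda filename, wd: str(Path(wd) / filename))
assert dl("questions.jsonl", "/tmp/work") == "/tmp/work/questions.jsonl"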

tests/evaluation/test_evaluator_stability.py

Lines changed: 95 additions & 0 deletions
@@ -168,3 +168,98 @@ async def test_evaluate_with_dict_list(monkeypatch, temp_dir, dummy_db_file):
     assert report["matches"] == 1
     assert report["accuracy"] == 1.0
     assert report["input_mode"] == "dict_list"
+
+
+def test_evaluate_with_list_outputs(mock_utils, mocker):
+    outputs = [{"question_id": 1, "completion": "SELECT 1"}]
+
+    report = evaluate(outputs, workdir_path=str(mock_utils))
+
+    assert report["total"] == 1
+    assert report["matches"] == 1
+    assert report["accuracy"] == 1.0
+    assert report["input_mode"] == "dict_list"
+
+
+def test_evaluate_with_jsonl_path(mock_utils, mocker):
+    # fake jsonl file
+    jsonl_path = mock_utils / "preds.jsonl"
+    jsonl_path.write_text("dummy", encoding="utf-8")
+
+    report = evaluate(str(jsonl_path), workdir_path=str(mock_utils))
+
+    assert report["total"] == 1
+    assert report["input_mode"] == "jsonl_path"
+
+
+def test_missing_workdir_and_no_questions_path_raises():
+    with pytest.raises(ValueError):
+        evaluate(
+            outputs=[{"question_id": 1, "completion": "x"}],
+            workdir_path=None,
+            questions_path=None,
+        )
+
+
+def test_missing_workdir_and_no_db_path_raises():
+    with pytest.raises(ValueError):
+        evaluate(
+            outputs=[{"question_id": 1, "completion": "x"}],
+            workdir_path=None,
+            db_path=None,
+        )
+
+
+def test_download_occurs_if_files_missing(mock_utils, mocker):
+    dl = mocker.patch("llmsql.evaluation.evaluate.download_benchmark_file")
+
+    evaluate(
+        [{"question_id": 1, "completion": "SELECT 1"}],
+        workdir_path=str(mock_utils),
+        questions_path=None,
+        db_path=None,
+    )
+
+    assert dl.call_count == 2  # questions + sqlite
+
+
+def test_saves_report_with_auto_filename(mock_utils, mocker):
+    save = mocker.patch("llmsql.evaluation.evaluate.save_json_report")
+
+    report = evaluate(
+        [{"question_id": 1, "completion": "SELECT 1"}],
+        workdir_path=str(mock_utils),
+        save_report=None,
+    )
+
+    # automatic UUID-based filename
+    args, kwargs = save.call_args
+    auto_filename = args[0]
+    assert auto_filename.startswith("evaluation_results_")
+    assert auto_filename.endswith(".json")
+
+    assert report["total"] == 1
+
+
+def test_mismatch_handling(mock_utils, mocker):
+    """Test branch where a mismatch is returned."""
+    mocker.patch(
+        "llmsql.evaluation.evaluate.evaluate_sample",
+        return_value=(
+            0,
+            {"info": "bad"},
+            {"pred_none": 0, "gold_none": 0, "sql_error": 0},
+        ),
+    )
+
+    log_mis = mocker.patch("llmsql.evaluation.evaluate.log_mismatch")
+
+    report = evaluate(
+        [{"question_id": 1, "completion": "SELECT X"}],
+        workdir_path=str(mock_utils),
+        max_mismatches=3,
+    )
+
+    assert report["matches"] == 0
+    assert len(report["mismatches"]) == 1
+    log_mis.assert_called_once()
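A note on the call_args inspection in test_saves_report_with_auto_filename: Mock.call_args records the most recent call and unpacks into the positional and keyword arguments, which is how the test gets at the auto-generated filename. A standalone sketch (the filename here is a made-up example):

from unittest.mock import MagicMock

save = MagicMock()
save("evaluation_results_1234.json", {"total": 1})
# call_args is the last call, unpackable into (positional, keyword) arguments
args, kwargs = save.call_args
assert args[0].startswith("evaluation_results_")
assert kwargs == {}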
