
Commit 8d01bf4

new test
1 parent 1545385 commit 8d01bf4

1 file changed: +85 −0 lines changed

tests/evaluation/test_evaluator_stability.py

Lines changed: 85 additions & 0 deletions
@@ -83,3 +83,88 @@ async def test_evaluate_saves_report(monkeypatch, temp_dir, dummy_db_file):
         saved_report = json.load(f)
     assert saved_report["total"] == 1
     assert saved_report["accuracy"] == 1.0
+
+
+@pytest.mark.asyncio
+async def test_evaluate_with_jsonl_file(monkeypatch, temp_dir, dummy_db_file):
+    # Create fake questions.jsonl
+    questions_path = temp_dir / "questions.jsonl"
+    questions_path.write_text(
+        json.dumps(
+            {"question_id": 1, "table_id": 1, "question": "Sample", "sql": "SELECT 1"}
+        )
+    )
+
+    # Create fake outputs.jsonl
+    outputs_path = temp_dir / "outputs.jsonl"
+    outputs_path.write_text(json.dumps({"question_id": 1, "completion": "SELECT 1"}))
+
+    # Mock dependencies
+    monkeypatch.setattr(
+        "llmsql.utils.evaluation_utils.evaluate_sample",
+        lambda *a, **k: (1, None, {"pred_none": 0, "gold_none": 0, "sql_error": 0}),
+    )
+    monkeypatch.setattr("llmsql.utils.rich_utils.log_mismatch", lambda **k: None)
+    monkeypatch.setattr("llmsql.utils.rich_utils.print_summary", lambda *a, **k: None)
+    monkeypatch.setattr(
+        "llmsql.utils.evaluation_utils.connect_sqlite", lambda db: dummy_db_file
+    )
+    monkeypatch.setattr(
+        "llmsql.utils.utils.save_json_report", lambda path, report: None
+    )
+
+    report = evaluate(
+        outputs=str(outputs_path),
+        questions_path=str(questions_path),
+        db_path=dummy_db_file,
+        show_mismatches=False,
+    )
+
+    assert report["total"] == 1
+    assert report["matches"] == 1
+    assert report["accuracy"] == 1.0
+    assert report["input_mode"] == "jsonl_path"
+
+
+@pytest.mark.asyncio
+async def test_evaluate_with_dict_list(monkeypatch, temp_dir, dummy_db_file):
+    # Prepare fake questions dict
+    questions_path = temp_dir / "questions.jsonl"
+    questions_path.write_text(
+        json.dumps(
+            {"question_id": 1, "table_id": 1, "question": "Sample", "sql": "SELECT 1"}
+        )
+    )
+
+    # Output as a list of dicts
+    outputs_list = [{"question_id": 1, "completion": "SELECT 1"}]
+
+    # Mock dependencies
+    monkeypatch.setattr(
+        "llmsql.utils.evaluation_utils.evaluate_sample",
+        lambda *a, **k: (1, None, {"pred_none": 0, "gold_none": 0, "sql_error": 0}),
+    )
+    monkeypatch.setattr("llmsql.utils.rich_utils.log_mismatch", lambda **k: None)
+    monkeypatch.setattr("llmsql.utils.rich_utils.print_summary", lambda *a, **k: None)
+    monkeypatch.setattr(
+        "llmsql.utils.evaluation_utils.connect_sqlite", lambda db: dummy_db_file
+    )
+    monkeypatch.setattr(
+        "llmsql.utils.utils.load_jsonl_dict_by_key",
+        lambda path, key: {1: {"question_id": 1}},
+    )
+    monkeypatch.setattr(
+        "llmsql.utils.utils.save_json_report", lambda path, report: None
+    )
+
+    report = evaluate(
+        outputs=outputs_list,
+        questions_path=str(questions_path),
+        db_path=dummy_db_file,
+        show_mismatches=False,
+    )
+
+    assert report["total"] == 1
+    assert report["matches"] == 1
+    assert report["accuracy"] == 1.0
+    assert report["input_mode"] == "dict_list"
