
Commit 762088f

Merge pull request #23 from K-dash/add-utils-tests-coverage
Add comprehensive tests for utils modules to improve code coverage
2 parents 782b6db + c5189ed commit 762088f

7 files changed: +1115, -0 lines changed


tests/evaluation/test_evaluator_stability.py

Lines changed: 105 additions & 0 deletions
@@ -65,3 +65,108 @@ async def test_evaluate_with_mock(monkeypatch, temp_dir, dummy_db_file):
    assert report["total"] == 1
    assert report["matches"] == 1
    assert report["accuracy"] == 1.0


@pytest.mark.asyncio
async def test_connect_with_nonexistent_db():
    """Test that connecting to non-existent database raises FileNotFoundError."""
    evaluator = LLMSQLEvaluator()
    with pytest.raises(FileNotFoundError, match="Database not found"):
        evaluator.connect("/nonexistent/path/to/database.db")
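
The only contract this first test pins down is the exception type and the "Database not found" message prefix raised for a missing path. A guard that would satisfy it could look like the sketch below; the method body, the self.conn attribute, and the use of sqlite3 are illustrative assumptions, not the project's actual implementation.

# Hypothetical connect() guard consistent with the test above (illustrative only)
import sqlite3
from pathlib import Path

def connect(self, db_path: str):
    if not Path(db_path).exists():
        raise FileNotFoundError(f"Database not found: {db_path}")
    self.conn = sqlite3.connect(db_path)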

@pytest.mark.asyncio
async def test_evaluate_saves_report(monkeypatch, temp_dir, dummy_db_file):
    """Test that save_report parameter creates a JSON report file."""
    evaluator = LLMSQLEvaluator(workdir_path=temp_dir)

    # Setup test files
    questions_path = temp_dir / "questions.jsonl"
    questions_path.write_text(
        json.dumps({"question_id": 1, "question": "Test", "sql": "SELECT 1"})
    )

    outputs_path = temp_dir / "outputs.jsonl"
    outputs_path.write_text(json.dumps({"question_id": 1, "completion": "SELECT 1"}))

    report_path = temp_dir / "report.json"

    # Mock dependencies
    monkeypatch.setattr(
        "llmsql.evaluation.evaluator.evaluate_sample",
        lambda *a, **k: (1, None, {"pred_none": 0, "gold_none": 0, "sql_error": 0}),
    )
    monkeypatch.setattr("llmsql.evaluation.evaluator.log_mismatch", lambda **k: None)
    monkeypatch.setattr(
        "llmsql.evaluation.evaluator.print_summary", lambda *a, **k: None
    )

    evaluator.evaluate(
        outputs_path=str(outputs_path),
        questions_path=str(questions_path),
        db_path=dummy_db_file,
        save_report=str(report_path),
        show_mismatches=False,
    )

    # Verify report file was created
    assert report_path.exists()
    with open(report_path, encoding="utf-8") as f:
        saved_report = json.load(f)
    assert saved_report["total"] == 1
    assert saved_report["accuracy"] == 1.0
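
Taken together with the other tests in this file, the assertions above imply a small report schema: evaluate returns (and, with save_report, serializes to JSON) a dict whose keys include total, matches, accuracy, and mismatches. For this single-sample run the saved file would deserialize to roughly the following; the inline comments are interpretation, not documented API.

# Approximate shape of report.json implied by the assertions (illustrative)
{
    "total": 1,        # number of evaluated samples
    "matches": 1,      # samples whose predicted SQL reproduced the gold results
    "accuracy": 1.0,   # consistent with matches / total for the 1-of-1 case
    "mismatches": [],  # filled with per-question records when outputs disagree
}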

@pytest.mark.asyncio
async def test_evaluate_with_mismatches(monkeypatch, temp_dir, dummy_db_file):
    """Test that mismatches are logged when show_mismatches=True."""
    evaluator = LLMSQLEvaluator(workdir_path=temp_dir)

    # Setup test files
    questions_path = temp_dir / "questions.jsonl"
    questions_path.write_text(
        json.dumps({"question_id": 1, "question": "Test", "sql": "SELECT 1"})
    )

    outputs_path = temp_dir / "outputs.jsonl"
    outputs_path.write_text(json.dumps({"question_id": 1, "completion": "SELECT 2"}))

    mismatch_logged = []

    def mock_log_mismatch(**kwargs):
        mismatch_logged.append(kwargs)

    # Mock dependencies - return mismatch
    monkeypatch.setattr(
        "llmsql.evaluation.evaluator.evaluate_sample",
        lambda *a, **k: (
            0,
            {
                "question_id": 1,
                "question": "Test",
                "gold_sql": "SELECT 1",
                "model_output": "SELECT 2",
                "gold_results": [(1,)],
                "prediction_results": [(2,)],
            },
            {"pred_none": 0, "gold_none": 0, "sql_error": 0},
        ),
    )
    monkeypatch.setattr("llmsql.evaluation.evaluator.log_mismatch", mock_log_mismatch)
    monkeypatch.setattr(
        "llmsql.evaluation.evaluator.print_summary", lambda *a, **k: None
    )

    report = evaluator.evaluate(
        outputs_path=str(outputs_path),
        questions_path=str(questions_path),
        db_path=dummy_db_file,
        show_mismatches=True,
        max_mismatches=5,
    )

    # Verify mismatch was logged
    assert len(mismatch_logged) == 1
    assert mismatch_logged[0]["question_id"] == 1
    assert report["matches"] == 0
    assert len(report["mismatches"]) == 1
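
Both evaluator tests stub llmsql.evaluation.evaluator.evaluate_sample with the same three-element tuple, so the mocked contract is worth spelling out: a 0/1 match flag, an optional per-question mismatch record (None on a match), and a dict of error counters (pred_none, gold_none, sql_error). A stand-in that names the pieces explicitly, assuming this is indeed the function's return convention:

# The return convention the mocks above rely on (names are illustrative)
def fake_evaluate_sample(*args, **kwargs):
    matched = 0  # 1 when predicted and gold query results agree
    mismatch_record = {  # None when matched == 1
        "question_id": 1,
        "question": "Test",
        "gold_sql": "SELECT 1",
        "model_output": "SELECT 2",
        "gold_results": [(1,)],
        "prediction_results": [(2,)],
    }
    counters = {"pred_none": 0, "gold_none": 0, "sql_error": 0}
    return matched, mismatch_record, counters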

tests/finetune/test_finetune_stability.py

Lines changed: 122 additions & 0 deletions
@@ -136,3 +136,125 @@ def save_model(self, path):
    )

    assert (tmp_path / "out" / "final_model").exists()


@pytest.mark.asyncio
async def test_flatten_nested_config():
    """Test that nested YAML config is flattened correctly."""
    from llmsql.finetune.finetune import parse_args_and_config

    # Create a config with nested structure
    config = {
        "training": {"num_epochs": 10, "learning_rate": 0.001},
        "model": {"name": "gpt2", "size": "small"},
        "output_dir": "test_out",
    }

    import sys
    import tempfile

    import yaml

    with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
        yaml.dump(config, f)
        config_path = f.name

    old_argv = sys.argv
    try:
        sys.argv = ["prog", "--config_file", config_path]
        args = parse_args_and_config()

        # Flattened keys should use underscore separator
        assert args.training_num_epochs == 10
        assert args.training_learning_rate == 0.001
        assert args.model_name == "gpt2"
        assert args.model_size == "small"
        assert args.output_dir == "test_out"
    finally:
        sys.argv = old_argv
        import os

        os.unlink(config_path)
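
The attribute names in the assertions encode the flattening rule this test expects: nested config keys are joined with an underscore (training.num_epochs becomes training_num_epochs) while top-level scalars such as output_dir pass through unchanged. parse_args_and_config is the real entry point under test; the flatten_dict helper below is only a sketch of that joining behavior, not the project's code.

# Illustrative underscore-joining flattener (assumption, not llmsql's implementation)
def flatten_dict(nested: dict, parent_key: str = "", sep: str = "_") -> dict:
    flat: dict = {}
    for key, value in nested.items():
        new_key = f"{parent_key}{sep}{key}" if parent_key else key
        if isinstance(value, dict):
            flat.update(flatten_dict(value, new_key, sep=sep))
        else:
            flat[new_key] = value
    return flat

# flatten_dict({"training": {"num_epochs": 10}, "output_dir": "test_out"})
# -> {"training_num_epochs": 10, "output_dir": "test_out"}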

@pytest.mark.asyncio
async def test_main_with_wandb_env_vars(tmp_path, monkeypatch):
    """Test that WandB environment variables are set correctly."""
    import os

    # Clear any existing WandB env vars
    for key in [
        "WANDB_PROJECT",
        "WANDB_RUN_NAME",
        "WANDB_API_KEY",
        "WANDB_MODE",
    ]:
        if key in os.environ:
            del os.environ[key]

    train_file = tmp_path / "train.jsonl"
    val_file = tmp_path / "val.jsonl"
    tables_file = tmp_path / "tables.jsonl"

    train_file.write_text(
        json.dumps({"question": "Q?", "sql": "SELECT 1", "table_id": "t1"}) + "\n"
    )
    val_file.write_text(
        json.dumps({"question": "Q?", "sql": "SELECT 2", "table_id": "t1"}) + "\n"
    )
    tables_file.write_text(
        json.dumps(
            {"table_id": "t1", "header": ["c"], "types": ["text"], "rows": [["r"]]}
        )
        + "\n"
    )

    # Mock all the heavy dependencies
    monkeypatch.setattr(
        "llmsql.finetune.finetune.load_jsonl",
        lambda f: [json.loads(line) for line in open(f)],
    )
    monkeypatch.setattr(
        "llmsql.finetune.finetune.choose_prompt_builder",
        lambda shots: lambda q, h, t, r: "PROMPT",
    )
    monkeypatch.setattr(
        "llmsql.finetune.finetune.AutoModelForCausalLM",
        type("FakeModel", (), {"from_pretrained": lambda *a, **k: "MODEL"}),
    )
    monkeypatch.setattr(
        "llmsql.finetune.finetune.SFTConfig", lambda **kwargs: {"args": kwargs}
    )

    class FakeTrainer:
        def __init__(self, model, train_dataset, eval_dataset, args):
            pass

        def train(self):
            return "trained"

        def save_model(self, path):
            Path(path).mkdir(parents=True, exist_ok=True)

    monkeypatch.setattr("llmsql.finetune.finetune.SFTTrainer", FakeTrainer)

    # Call main with WandB parameters
    finetune.main(
        model_name_or_path="gpt2",
        output_dir=str(tmp_path / "out"),
        train_file=str(train_file),
        val_file=str(val_file),
        tables_file=str(tables_file),
        shots=1,
        num_train_epochs=1,
        wandb_project="test_project",
        wandb_run_name="test_run",
        wandb_key="test_key",
        wandb_offline=True,
    )

    # Verify environment variables were set
    assert os.environ.get("WANDB_PROJECT") == "test_project"
    assert os.environ.get("WANDB_RUN_NAME") == "test_run"
    assert os.environ.get("WANDB_API_KEY") == "test_key"
    assert os.environ.get("WANDB_MODE") == "offline"

tests/test_init.py

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
"""Tests for llmsql package initialization and lazy imports."""

import pytest


class TestLazyImport:
    """Test lazy import mechanism in __init__.py."""

    def test_lazy_import_evaluator(self) -> None:
        """Test that LLMSQLEvaluator can be imported via lazy loading."""
        from llmsql import LLMSQLEvaluator

        assert LLMSQLEvaluator is not None
        # Verify it's the correct class
        assert LLMSQLEvaluator.__name__ == "LLMSQLEvaluator"

    def test_lazy_import_inference_skipped_if_vllm_missing(self) -> None:
        """Test that LLMSQLVLLMInference import fails gracefully without vllm."""
        try:
            from llmsql import LLMSQLVLLMInference

            # If vllm is installed, should succeed
            assert LLMSQLVLLMInference is not None
        except ModuleNotFoundError as e:
            # If vllm not installed, should raise ModuleNotFoundError
            assert "vllm" in str(e).lower()

    def test_invalid_attribute_raises_error(self) -> None:
        """Test that accessing invalid attribute raises AttributeError."""
        import llmsql

        with pytest.raises(
            AttributeError, match="module .* has no attribute 'NonExistentClass'"
        ):
            _ = llmsql.NonExistentClass  # type: ignore

    def test_version_attribute(self) -> None:
        """Test that __version__ is accessible."""
        import llmsql

        assert hasattr(llmsql, "__version__")
        assert isinstance(llmsql.__version__, str)
        # Should match semantic versioning pattern
        assert len(llmsql.__version__.split(".")) >= 2

    def test_all_exports(self) -> None:
        """Test that __all__ contains expected exports."""
        import llmsql

        assert hasattr(llmsql, "__all__")
        assert "LLMSQLEvaluator" in llmsql.__all__
        assert "LLMSQLVLLMInference" in llmsql.__all__
