105 changes: 105 additions & 0 deletions tests/evaluation/test_evaluator_stability.py
@@ -65,3 +65,108 @@ async def test_evaluate_with_mock(monkeypatch, temp_dir, dummy_db_file):
assert report["total"] == 1
assert report["matches"] == 1
assert report["accuracy"] == 1.0


@pytest.mark.asyncio
async def test_connect_with_nonexistent_db():
"""Test that connecting to non-existent database raises FileNotFoundError."""
evaluator = LLMSQLEvaluator()
with pytest.raises(FileNotFoundError, match="Database not found"):
evaluator.connect("/nonexistent/path/to/database.db")


@pytest.mark.asyncio
async def test_evaluate_saves_report(monkeypatch, temp_dir, dummy_db_file):
"""Test that save_report parameter creates a JSON report file."""
evaluator = LLMSQLEvaluator(workdir_path=temp_dir)

# Setup test files
questions_path = temp_dir / "questions.jsonl"
questions_path.write_text(
json.dumps({"question_id": 1, "question": "Test", "sql": "SELECT 1"})
)

outputs_path = temp_dir / "outputs.jsonl"
outputs_path.write_text(json.dumps({"question_id": 1, "completion": "SELECT 1"}))

report_path = temp_dir / "report.json"

# Mock dependencies
monkeypatch.setattr(
"llmsql.evaluation.evaluator.evaluate_sample",
lambda *a, **k: (1, None, {"pred_none": 0, "gold_none": 0, "sql_error": 0}),
)
monkeypatch.setattr("llmsql.evaluation.evaluator.log_mismatch", lambda **k: None)
monkeypatch.setattr(
"llmsql.evaluation.evaluator.print_summary", lambda *a, **k: None
)

evaluator.evaluate(
outputs_path=str(outputs_path),
questions_path=str(questions_path),
db_path=dummy_db_file,
save_report=str(report_path),
show_mismatches=False,
)

# Verify report file was created
assert report_path.exists()
with open(report_path, encoding="utf-8") as f:
saved_report = json.load(f)
assert saved_report["total"] == 1
assert saved_report["accuracy"] == 1.0


@pytest.mark.asyncio
async def test_evaluate_with_mismatches(monkeypatch, temp_dir, dummy_db_file):
"""Test that mismatches are logged when show_mismatches=True."""
evaluator = LLMSQLEvaluator(workdir_path=temp_dir)

# Setup test files
questions_path = temp_dir / "questions.jsonl"
questions_path.write_text(
json.dumps({"question_id": 1, "question": "Test", "sql": "SELECT 1"})
)

outputs_path = temp_dir / "outputs.jsonl"
outputs_path.write_text(json.dumps({"question_id": 1, "completion": "SELECT 2"}))

mismatch_logged = []

def mock_log_mismatch(**kwargs):
mismatch_logged.append(kwargs)

# Mock dependencies - return mismatch
monkeypatch.setattr(
"llmsql.evaluation.evaluator.evaluate_sample",
lambda *a, **k: (
0,
{
"question_id": 1,
"question": "Test",
"gold_sql": "SELECT 1",
"model_output": "SELECT 2",
"gold_results": [(1,)],
"prediction_results": [(2,)],
},
{"pred_none": 0, "gold_none": 0, "sql_error": 0},
),
)
monkeypatch.setattr("llmsql.evaluation.evaluator.log_mismatch", mock_log_mismatch)
monkeypatch.setattr(
"llmsql.evaluation.evaluator.print_summary", lambda *a, **k: None
)

report = evaluator.evaluate(
outputs_path=str(outputs_path),
questions_path=str(questions_path),
db_path=dummy_db_file,
show_mismatches=True,
max_mismatches=5,
)

# Verify mismatch was logged
assert len(mismatch_logged) == 1
assert mismatch_logged[0]["question_id"] == 1
assert report["matches"] == 0
assert len(report["mismatches"]) == 1
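
For context, the evaluator behavior these tests exercise can be sketched roughly as follows; this is an illustrative approximation based on the assertions above, not the actual llmsql.evaluation.evaluator code:

import json
import sqlite3
from pathlib import Path


class EvaluatorSketch:
    """Illustrative stand-in for the connect/report handling tested above."""

    def connect(self, db_path: str) -> sqlite3.Connection:
        # Fail fast on a missing file instead of letting sqlite3 silently
        # create an empty database at the given path.
        if not Path(db_path).exists():
            raise FileNotFoundError(f"Database not found: {db_path}")
        return sqlite3.connect(db_path)

    def save_report(self, report: dict, path: str) -> None:
        # Persist the summary dict ("total", "matches", "accuracy",
        # "mismatches") as UTF-8 JSON, as test_evaluate_saves_report expects.
        with open(path, "w", encoding="utf-8") as f:
            json.dump(report, f, ensure_ascii=False, indent=2)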
122 changes: 122 additions & 0 deletions tests/finetune/test_finetune_stability.py
@@ -136,3 +136,125 @@ def save_model(self, path):
)

assert (tmp_path / "out" / "final_model").exists()


@pytest.mark.asyncio
async def test_flatten_nested_config():
"""Test that nested YAML config is flattened correctly."""
from llmsql.finetune.finetune import parse_args_and_config

# Create a config with nested structure
config = {
"training": {"num_epochs": 10, "learning_rate": 0.001},
"model": {"name": "gpt2", "size": "small"},
"output_dir": "test_out",
}

import sys
import tempfile

import yaml

with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
yaml.dump(config, f)
config_path = f.name

old_argv = sys.argv
try:
sys.argv = ["prog", "--config_file", config_path]
args = parse_args_and_config()

# Flattened keys should use underscore separator
assert args.training_num_epochs == 10
assert args.training_learning_rate == 0.001
assert args.model_name == "gpt2"
assert args.model_size == "small"
assert args.output_dir == "test_out"
finally:
sys.argv = old_argv
import os

os.unlink(config_path)
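
The flattening behavior asserted above can be illustrated with a short sketch; the helper name and recursive strategy are assumptions, not the internals of parse_args_and_config:

def _flatten(cfg: dict, prefix: str = "") -> dict:
    # Collapse nested config sections into argparse-style attribute names
    # joined with "_", e.g. {"training": {"num_epochs": 10}} becomes
    # {"training_num_epochs": 10}; scalar values pass through unchanged.
    flat = {}
    for key, value in cfg.items():
        name = f"{prefix}{key}"
        if isinstance(value, dict):
            flat.update(_flatten(value, prefix=f"{name}_"))
        else:
            flat[name] = value
    return flat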


@pytest.mark.asyncio
async def test_main_with_wandb_env_vars(tmp_path, monkeypatch):
"""Test that WandB environment variables are set correctly."""
import os

# Clear any existing WandB env vars
for key in [
"WANDB_PROJECT",
"WANDB_RUN_NAME",
"WANDB_API_KEY",
"WANDB_MODE",
]:
if key in os.environ:
del os.environ[key]

train_file = tmp_path / "train.jsonl"
val_file = tmp_path / "val.jsonl"
tables_file = tmp_path / "tables.jsonl"

train_file.write_text(
json.dumps({"question": "Q?", "sql": "SELECT 1", "table_id": "t1"}) + "\n"
)
val_file.write_text(
json.dumps({"question": "Q?", "sql": "SELECT 2", "table_id": "t1"}) + "\n"
)
tables_file.write_text(
json.dumps(
{"table_id": "t1", "header": ["c"], "types": ["text"], "rows": [["r"]]}
)
+ "\n"
)

# Mock all the heavy dependencies
monkeypatch.setattr(
"llmsql.finetune.finetune.load_jsonl",
lambda f: [json.loads(line) for line in open(f)],
)
monkeypatch.setattr(
"llmsql.finetune.finetune.choose_prompt_builder",
lambda shots: lambda q, h, t, r: "PROMPT",
)
monkeypatch.setattr(
"llmsql.finetune.finetune.AutoModelForCausalLM",
type("FakeModel", (), {"from_pretrained": lambda *a, **k: "MODEL"}),
)
monkeypatch.setattr(
"llmsql.finetune.finetune.SFTConfig", lambda **kwargs: {"args": kwargs}
)

class FakeTrainer:
def __init__(self, model, train_dataset, eval_dataset, args):
pass

def train(self):
return "trained"

def save_model(self, path):
Path(path).mkdir(parents=True, exist_ok=True)

monkeypatch.setattr("llmsql.finetune.finetune.SFTTrainer", FakeTrainer)

# Call main with WandB parameters
finetune.main(
model_name_or_path="gpt2",
output_dir=str(tmp_path / "out"),
train_file=str(train_file),
val_file=str(val_file),
tables_file=str(tables_file),
shots=1,
num_train_epochs=1,
wandb_project="test_project",
wandb_run_name="test_run",
wandb_key="test_key",
wandb_offline=True,
)

# Verify environment variables were set
assert os.environ.get("WANDB_PROJECT") == "test_project"
assert os.environ.get("WANDB_RUN_NAME") == "test_run"
assert os.environ.get("WANDB_API_KEY") == "test_key"
assert os.environ.get("WANDB_MODE") == "offline"
52 changes: 52 additions & 0 deletions tests/test_init.py
@@ -0,0 +1,52 @@
"""Tests for llmsql package initialization and lazy imports."""

import pytest


class TestLazyImport:
"""Test lazy import mechanism in __init__.py."""

def test_lazy_import_evaluator(self) -> None:
"""Test that LLMSQLEvaluator can be imported via lazy loading."""
from llmsql import LLMSQLEvaluator

assert LLMSQLEvaluator is not None
# Verify it's the correct class
assert LLMSQLEvaluator.__name__ == "LLMSQLEvaluator"

def test_lazy_import_inference_skipped_if_vllm_missing(self) -> None:
"""Test that LLMSQLVLLMInference import fails gracefully without vllm."""
try:
from llmsql import LLMSQLVLLMInference

# If vllm is installed, should succeed
assert LLMSQLVLLMInference is not None
except ModuleNotFoundError as e:
# If vllm not installed, should raise ModuleNotFoundError
assert "vllm" in str(e).lower()

def test_invalid_attribute_raises_error(self) -> None:
"""Test that accessing invalid attribute raises AttributeError."""
import llmsql

with pytest.raises(
AttributeError, match="module .* has no attribute 'NonExistentClass'"
):
_ = llmsql.NonExistentClass # type: ignore

def test_version_attribute(self) -> None:
"""Test that __version__ is accessible."""
import llmsql

assert hasattr(llmsql, "__version__")
assert isinstance(llmsql.__version__, str)
# Should match semantic versioning pattern
assert len(llmsql.__version__.split(".")) >= 2

def test_all_exports(self) -> None:
"""Test that __all__ contains expected exports."""
import llmsql

assert hasattr(llmsql, "__all__")
assert "LLMSQLEvaluator" in llmsql.__all__
assert "LLMSQLVLLMInference" in llmsql.__all__