diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000..6fac987 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,44 @@ +name: Tests + +on: + push: + branches: [ main ] + pull_request: + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install PDM + run: pip install pdm + + - name: Cache PDM packages + uses: actions/cache@v4 + with: + path: ~/.cache/pdm + key: ${{ runner.os }}-pdm-${{ hashFiles('pdm.lock') }} + restore-keys: | + ${{ runner.os }}-pdm- + + - name: Install dependencies (with dev) + run: pdm install --with dev + + - name: Run tests with coverage + run: PYTHONPATH=. pdm run pytest --cov=llmsql --cov-report=xml --maxfail=1 --disable-warnings -v + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v4 + with: + token: ${{ secrets.CODECOV_TOKEN }} + files: ./coverage.xml + flags: unittests + fail_ci_if_error: true diff --git a/.gitignore b/.gitignore index fa55266..1148116 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,5 @@ dataset/sqlite_tables.db dist/ *.egg-info/ -.pdm-python \ No newline at end of file +.pdm-python +.vscode \ No newline at end of file diff --git a/README.md b/README.md index a0e88f5..a76b0cc 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,6 @@ +[![codecov](https://codecov.io/gh/LLMSQL/llmsql-benchmark/branch/main/graph/badge.svg)](https://codecov.io/gh/LLMSQL/llmsql-benchmark) + + # LLMSQL Patched and improved version of the original large crowd-sourced dataset for developing natural language interfaces for relational databases, [WikiSQL](https://github.com/salesforce/WikiSQL). diff --git a/llmsql/inference/inference.py b/llmsql/inference/inference.py index 5fdc282..2fbbe09 100644 --- a/llmsql/inference/inference.py +++ b/llmsql/inference/inference.py @@ -96,6 +96,10 @@ def __init__( self.workdir_path.mkdir(parents=True, exist_ok=True) self.repo_id = "llmsql-bench/llmsql-benchmark" + if "device" not in llm_kwargs: + llm_kwargs["device"] = "cuda" if torch.cuda.is_available() else "cpu" + self.device = llm_kwargs["device"] + log.info( f"Loading vLLM model {model_name} with tensor_parallel_size={tensor_parallel_size}..." ) diff --git a/pdm.lock b/pdm.lock index 819aa26..8f56957 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:aef9ccf50c6ee815aa2e06755ce4de49d901c961bff4ff0309c6db6a33a0b9d1" +content_hash = "sha256:bfca308fdb3fdba9c187004179050b30d0547ccc05d0e6abbf1790d94c8f06ff" [[metadata.targets]] requires_python = ">=3.10" @@ -220,6 +220,18 @@ files = [ {file = "attrs-25.3.0.tar.gz", hash = "sha256:75d7cefc7fb576747b2c81b4442d4d4a1ce0900973527c011d1030fd3bf4af1b"}, ] +[[package]] +name = "backports-asyncio-runner" +version = "1.2.0" +requires_python = "<3.11,>=3.8" +summary = "Backport of asyncio.Runner, a context manager that controls event loop life cycle." 
+groups = ["dev"] +marker = "python_version < \"3.11\"" +files = [ + {file = "backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5"}, + {file = "backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162"}, +] + [[package]] name = "blake3" version = "1.0.6" @@ -3034,6 +3046,22 @@ files = [ {file = "pytest-8.4.2.tar.gz", hash = "sha256:86c0d0b93306b961d58d62a4db4879f27fe25513d4b969df351abdddb3c30e01"}, ] +[[package]] +name = "pytest-asyncio" +version = "1.2.0" +requires_python = ">=3.9" +summary = "Pytest support for asyncio" +groups = ["dev"] +dependencies = [ + "backports-asyncio-runner<2,>=1.1; python_version < \"3.11\"", + "pytest<9,>=8.2", + "typing-extensions>=4.12; python_version < \"3.13\"", +] +files = [ + {file = "pytest_asyncio-1.2.0-py3-none-any.whl", hash = "sha256:8e17ae5e46d8e7efe51ab6494dd2010f4ca8dae51652aa3c8d55acf50bfb2e99"}, + {file = "pytest_asyncio-1.2.0.tar.gz", hash = "sha256:c609a64a2a8768462d0c99811ddb8bd2583c33fd33cf7f21af1c142e824ffb57"}, +] + [[package]] name = "pytest-cov" version = "7.0.0" diff --git a/pyproject.toml b/pyproject.toml index a98d1c6..9e51e03 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,7 @@ dev = [ "pre-commit>=4.3.0", "types-tqdm>=4.67.0.20250809", "types-PyYAML>=6.0.12.20250915", + "pytest-asyncio>=1.2.0", ] [tool.pdm] diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 27b43e1..0000000 --- a/requirements.txt +++ /dev/null @@ -1,176 +0,0 @@ -accelerate==1.10.1 -aiohappyeyeballs==2.6.1 -aiohttp==3.12.15 -aiosignal==1.4.0 -annotated-types==0.7.0 -anyio==4.10.0 -astor==0.8.1 -asttokens==3.0.0 -attrs==25.3.0 -blake3==1.0.6 -cachetools==6.2.0 -cbor2==5.7.0 -certifi==2025.8.3 -cffi==2.0.0 -charset-normalizer==3.4.3 -click==8.3.0 -cloudpickle==3.1.1 -comm==0.2.3 -compressed-tensors==0.11.0 -cupy-cuda12x==13.6.0 -datasets==4.1.1 -debugpy==1.8.17 -decorator==5.2.1 -depyf==0.19.0 -dill==0.4.0 -diskcache==5.6.3 -distro==1.9.0 -dnspython==2.8.0 -einops==0.8.1 -email-validator==2.3.0 -executing==2.2.1 -fastapi==0.117.1 -fastapi-cli==0.0.13 -fastapi-cloud-cli==0.2.0 -fastrlock==0.8.3 -filelock==3.19.1 -frozendict==2.4.6 -frozenlist==1.7.0 -fsspec==2025.9.0 -gguf==0.17.1 -h11==0.16.0 -hf-xet==1.1.10 -httpcore==1.0.9 -httptools==0.6.4 -httpx==0.28.1 -huggingface-hub==0.35.0 -idna==3.10 -interegular==0.3.3 -ipykernel==6.30.1 -ipython==9.5.0 -ipython_pygments_lexers==1.1.1 -ipywidgets==8.1.7 -jedi==0.19.2 -Jinja2==3.1.6 -jiter==0.11.0 -jsonschema==4.25.1 -jsonschema-specifications==2025.9.1 -jupyter_client==8.6.3 -jupyter_core==5.8.1 -jupyterlab_widgets==3.0.15 -lark==1.2.2 -llguidance==0.7.30 -llvmlite==0.44.0 -lm-format-enforcer==0.11.3 -markdown-it-py==4.0.0 -MarkupSafe==3.0.2 -matplotlib-inline==0.1.7 -mdurl==0.1.2 -mistral_common==1.8.5 -mpmath==1.3.0 -msgpack==1.1.1 -msgspec==0.19.0 -multidict==6.6.4 -multiprocess==0.70.16 -nest-asyncio==1.6.0 -networkx==3.5 -ninja==1.13.0 -numba==0.61.2 -numpy==2.2.6 -nvidia-cublas-cu12==12.8.4.1 -nvidia-cuda-cupti-cu12==12.8.90 -nvidia-cuda-nvrtc-cu12==12.8.93 -nvidia-cuda-runtime-cu12==12.8.90 -nvidia-cudnn-cu12==9.10.2.21 -nvidia-cufft-cu12==11.3.3.83 -nvidia-cufile-cu12==1.13.1.3 -nvidia-curand-cu12==10.3.9.90 -nvidia-cusolver-cu12==11.7.3.90 -nvidia-cusparse-cu12==12.5.8.93 -nvidia-cusparselt-cu12==0.7.1 -nvidia-nccl-cu12==2.27.3 -nvidia-nvjitlink-cu12==12.8.93 -nvidia-nvtx-cu12==12.8.90 -openai==1.108.2 -openai-harmony==0.0.4 
-opencv-python-headless==4.12.0.88 -outlines_core==0.2.11 -packaging==25.0 -pandas==2.3.2 -parso==0.8.5 -partial-json-parser==0.2.1.1.post6 -pexpect==4.9.0 -pillow==11.3.0 -platformdirs==4.4.0 -prometheus-fastapi-instrumentator==7.1.0 -prometheus_client==0.23.1 -prompt_toolkit==3.0.52 -propcache==0.3.2 -protobuf==6.32.1 -psutil==7.1.0 -ptyprocess==0.7.0 -pure_eval==0.2.3 -py-cpuinfo==9.0.0 -pyarrow==21.0.0 -pybase64==1.4.2 -pycountry==24.6.1 -pycparser==2.23 -pydantic==2.11.9 -pydantic-extra-types==2.10.5 -pydantic_core==2.33.2 -Pygments==2.19.2 -python-dateutil==2.9.0.post0 -python-dotenv==1.1.1 -python-json-logger==3.3.0 -python-multipart==0.0.20 -pytz==2025.2 -PyYAML==6.0.2 -pyzmq==27.1.0 -ray==2.49.2 -referencing==0.36.2 -regex==2025.9.18 -requests==2.32.5 -rich==14.1.0 -rich-toolkit==0.15.1 -rignore==0.6.4 -rpds-py==0.27.1 -safetensors==0.6.2 -scipy==1.16.2 -sentencepiece==0.2.1 -sentry-sdk==2.38.0 -setproctitle==1.3.7 -shellingham==1.5.4 -six==1.17.0 -sniffio==1.3.1 -soundfile==0.13.1 -soxr==1.0.0 -stack-data==0.6.3 -starlette==0.48.0 -sympy==1.14.0 -tiktoken==0.11.0 -tokenizers==0.22.1 -torch==2.8.0 -torchaudio==2.8.0 -torchvision==0.23.0 -tornado==6.5.2 -tqdm==4.67.1 -traitlets==5.14.3 -transformers==4.56.2 -triton==3.4.0 -trl==0.23.0 -typer==0.19.1 -typing-inspection==0.4.1 -typing_extensions==4.15.0 -tzdata==2025.2 -urllib3==2.5.0 -uvicorn==0.36.0 -uvloop==0.21.0 -vllm==0.10.2 -watchfiles==1.1.0 -wcwidth==0.2.14 -websockets==15.0.1 -widgetsnbextension==4.0.14 -xformers==0.0.32.post1 -xgrammar==0.1.23 -xxhash==3.5.0 -yarl==1.20.1 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..92f1523 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,70 @@ +import json +import os +import sqlite3 + +import pytest + +import llmsql.inference.inference as inference + + +@pytest.fixture +def temp_dir(tmp_path): + return tmp_path + + +@pytest.fixture +def dummy_db_file(tmp_path): + """Create a temporary SQLite DB file for testing, cleanup afterwards.""" + db_path = tmp_path / "test.db" + conn = sqlite3.connect(db_path) + conn.execute("CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)") + conn.execute("INSERT INTO test (name) VALUES ('Alice'), ('Bob')") + conn.commit() + conn.close() + + yield str(db_path) + + # cleanup + if os.path.exists(db_path): + os.remove(db_path) + + +@pytest.fixture +def mock_llm(monkeypatch): + """Mock vLLM LLM to avoid GPU/model loading.""" + + class DummyOutput: + def __init__(self, text="SELECT 1"): + self.outputs = [type("Obj", (), {"text": text})()] + + class DummyLLM: + def generate(self, prompts, sampling_params): + return [DummyOutput(f"-- SQL for: {p}") for p in prompts] + + monkeypatch.setattr(inference, "LLM", lambda **_: DummyLLM()) + return DummyLLM() + + +@pytest.fixture +def fake_jsonl_files(tmp_path): + """Create fake questions.jsonl and tables.jsonl.""" + qpath = tmp_path / "questions.jsonl" + tpath = tmp_path / "tables.jsonl" + + questions = [ + {"question_id": "q1", "question": "How many users?", "table_id": "t1"}, + {"question_id": "q2", "question": "List names", "table_id": "t1"}, + ] + tables = [ + { + "table_id": "t1", + "header": ["id", "name"], + "types": ["int", "text"], + "rows": [[1, "Alice"], [2, "Bob"]], + } + ] + + qpath.write_text("\n".join(json.dumps(q) for q in questions)) + tpath.write_text("\n".join(json.dumps(t) for t in tables)) + + return str(qpath), str(tpath) diff --git a/tests/evaluation/test_evaluator_stability.py b/tests/evaluation/test_evaluator_stability.py new file mode 100644 index 
0000000..3cfdcf5 --- /dev/null +++ b/tests/evaluation/test_evaluator_stability.py @@ -0,0 +1,67 @@ +import json +from pathlib import Path +import sqlite3 + +import pytest + +from llmsql import LLMSQLEvaluator + + +@pytest.mark.asyncio +async def test_connect_and_close(dummy_db_file): + evaluator = LLMSQLEvaluator() + evaluator.connect(dummy_db_file) + assert isinstance(evaluator.conn, sqlite3.Connection) + evaluator.close() + assert evaluator.conn is None + + +@pytest.mark.asyncio +async def test_download_file_is_called(monkeypatch, temp_dir): + evaluator = LLMSQLEvaluator(workdir_path=temp_dir) + + def fake_download(*args, **kwargs): + file_path = temp_dir / "fake_file.txt" + file_path.write_text("content") + return str(file_path) + + monkeypatch.setattr("llmsql.evaluation.evaluator.hf_hub_download", fake_download) + + path = evaluator._download_file("fake_file.txt") + assert Path(path).exists() + + +@pytest.mark.asyncio +async def test_evaluate_with_mock(monkeypatch, temp_dir, dummy_db_file): + evaluator = LLMSQLEvaluator(workdir_path=temp_dir) + + # Fake questions.jsonl + questions_path = temp_dir / "questions.jsonl" + questions_path.write_text( + json.dumps({"question_id": 1, "question": "Sample question", "sql": "SELECT 1"}) + ) + + # Fake outputs.jsonl + outputs_path = temp_dir / "outputs.jsonl" + outputs_path.write_text(json.dumps({"question_id": 1, "predicted": "SELECT 1"})) + + # Monkeypatch dependencies + monkeypatch.setattr( + "llmsql.evaluation.evaluator.evaluate_sample", + lambda *a, **k: (1, None, {"pred_none": 0, "gold_none": 0, "sql_error": 0}), + ) + monkeypatch.setattr("llmsql.evaluation.evaluator.log_mismatch", lambda **k: None) + monkeypatch.setattr( + "llmsql.evaluation.evaluator.print_summary", lambda *a, **k: None + ) + + report = evaluator.evaluate( + outputs_path=str(outputs_path), + questions_path=str(questions_path), + db_path=dummy_db_file, + show_mismatches=False, + ) + + assert report["total"] == 1 + assert report["matches"] == 1 + assert report["accuracy"] == 1.0 diff --git a/tests/finetune/test_finetune_stability.py b/tests/finetune/test_finetune_stability.py new file mode 100644 index 0000000..5fd1029 --- /dev/null +++ b/tests/finetune/test_finetune_stability.py @@ -0,0 +1,138 @@ +import json +from pathlib import Path + +import pytest + +import llmsql.finetune.finetune as finetune + + +@pytest.mark.asyncio +async def test_parse_args_and_config_with_yaml(tmp_path, monkeypatch): + # create a fake yaml config file + yaml_file = tmp_path / "config.yaml" + yaml_file.write_text( + """ + model_name_or_path: "gpt2" + output_dir: "out" + num_train_epochs: 2 + train_file: "train.jsonl" + val_file: "val.jsonl" + """ + ) + + testargs = ["prog", "--config_file", str(yaml_file)] + monkeypatch.setattr("sys.argv", testargs) + + args = finetune.parse_args_and_config() + assert args.model_name_or_path == "gpt2" + assert args.output_dir == "out" + assert args.num_train_epochs == 2 + assert args.train_file == "train.jsonl" + assert args.val_file == "val.jsonl" + + +@pytest.mark.asyncio +async def test_build_dataset(monkeypatch, tmp_path): + # fake tables + tables = { + "t1": { + "table_id": "t1", + "header": ["col1"], + "types": ["text"], + "rows": [["foo"]], + } + } + # fake questions + q_file = tmp_path / "train.jsonl" + q_file.write_text( + json.dumps({"question": "What?", "sql": "SELECT 1", "table_id": "t1"}) + "\n" + ) + + # fake prompt builder + def fake_builder(question, header, types, example_row): + return f"{question} | {header} | {types} | {example_row}" + 
monkeypatch.setattr( + "llmsql.finetune.finetune.load_jsonl", + lambda f: [json.loads(line) for line in open(f)], + ) + + dataset = finetune.build_dataset(str(q_file), tables, fake_builder) + assert len(dataset) == 1 + sample = dataset[0] + assert "prompt" in sample and "completion" in sample + assert sample["completion"] == "SELECT 1" + + +@pytest.mark.asyncio +async def test_download_file(monkeypatch, tmp_path): + def fake_download(**kwargs): + file_path = tmp_path / kwargs["filename"] + file_path.write_text("content") + return str(file_path) + + monkeypatch.setattr("llmsql.finetune.finetune.hf_hub_download", fake_download) + monkeypatch.setattr("shutil.copy", lambda src, dst: Path(dst).write_text("copied")) + + out = finetune._download_file("file.txt", "repo", str(tmp_path)) + assert Path(out).exists() + + +@pytest.mark.asyncio +async def test_main_runs_with_mocks(tmp_path, monkeypatch): + train_file = tmp_path / "train.jsonl" + val_file = tmp_path / "val.jsonl" + tables_file = tmp_path / "tables.jsonl" + + train_file.write_text( + json.dumps({"question": "Q?", "sql": "SELECT 1", "table_id": "t1"}) + "\n" + ) + val_file.write_text( + json.dumps({"question": "Q?", "sql": "SELECT 2", "table_id": "t1"}) + "\n" + ) + tables_file.write_text( + json.dumps( + {"table_id": "t1", "header": ["c"], "types": ["text"], "rows": [["r"]]} + ) + + "\n" + ) + + monkeypatch.setattr( + "llmsql.finetune.finetune.load_jsonl", + lambda f: [json.loads(line) for line in open(f)], + ) + monkeypatch.setattr( + "llmsql.finetune.finetune.choose_prompt_builder", + lambda shots: lambda q, h, t, r: "PROMPT", + ) + monkeypatch.setattr( + "llmsql.finetune.finetune.AutoModelForCausalLM", + type("FakeModel", (), {"from_pretrained": lambda *a, **k: "MODEL"}), + ) + monkeypatch.setattr( + "llmsql.finetune.finetune.SFTConfig", lambda **kwargs: {"args": kwargs} + ) + + class FakeTrainer: + def __init__(self, model, train_dataset, eval_dataset, args): + pass + + def train(self): + return "trained" + + def save_model(self, path): + Path(path).mkdir(parents=True, exist_ok=True) + + monkeypatch.setattr("llmsql.finetune.finetune.SFTTrainer", FakeTrainer) + + finetune.main( + model_name_or_path="gpt2", + output_dir=str(tmp_path / "out"), + train_file=str(train_file), + val_file=str(val_file), + tables_file=str(tables_file), + shots=1, + num_train_epochs=1, + ) + + assert (tmp_path / "out" / "final_model").exists() diff --git a/tests/inference/test_inference_stability.py b/tests/inference/test_inference_stability.py new file mode 100644 index 0000000..ebae1c1 --- /dev/null +++ b/tests/inference/test_inference_stability.py @@ -0,0 +1,139 @@ +import json +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from llmsql.inference import inference + + +@pytest.mark.asyncio +async def test_download_file(monkeypatch, tmp_path): + """Ensure _download_file calls hf_hub_download correctly.""" + # Patch LLM so __init__ does not try to load a real model + with patch("llmsql.inference.inference.LLM") as mock_llm: + mock_llm.return_value = object() # dummy instance + inf = inference.LLMSQLVLLMInference("dummy-model") + + called = {} + + def fake_download(**kwargs): + called.update(kwargs) + return str(tmp_path / "questions.jsonl") + + monkeypatch.setattr(inference, "hf_hub_download", fake_download) + + path = inf._download_file("questions.jsonl") + + assert "repo_id" in called + assert called["filename"] == "questions.jsonl" + assert path.endswith("questions.jsonl") + + +@pytest.mark.asyncio +async def 
test_generate_with_local_files(monkeypatch, tmp_path): + """Generate should read JSONL, call LLM, and write outputs.""" + # Create fake questions/tables + qpath = tmp_path / "questions.jsonl" + tpath = tmp_path / "tables.jsonl" + + questions = [ + {"question_id": "q1", "question": "What is 1+1?", "table_id": "t1"}, + {"question_id": "q2", "question": "What is 2+2?", "table_id": "t1"}, + ] + tables = [ + {"table_id": "t1", "header": ["col"], "types": ["text"], "rows": [["foo"]]} + ] + + qpath.write_text("\n".join(json.dumps(q) for q in questions)) + tpath.write_text("\n".join(json.dumps(t) for t in tables)) + + out_file = tmp_path / "out.jsonl" + + # Patch utils + monkeypatch.setattr( + inference, + "load_jsonl", + lambda path: [json.loads(line) for line in Path(path).read_text().splitlines()], + ) + monkeypatch.setattr( + inference, "overwrite_jsonl", lambda path: Path(path).write_text("") + ) + monkeypatch.setattr( + inference, + "save_jsonl_lines", + lambda path, lines: Path(path).write_text( + Path(path).read_text() + + "\n".join(json.dumps(line) for line in lines) + + "\n" + ), + ) + monkeypatch.setattr( + inference, + "choose_prompt_builder", + lambda shots: lambda q, h, t, r: f"PROMPT: {q}", + ) + + # Patch LLM with a fake generate + fake_llm = MagicMock() + fake_llm.generate.return_value = [ + MagicMock(outputs=[MagicMock(text="SELECT 2")]), + MagicMock(outputs=[MagicMock(text="SELECT 4")]), + ] + + with patch("llmsql.inference.inference.LLM", return_value=fake_llm): + inf = inference.LLMSQLVLLMInference("dummy-model") + + results = inf.generate( + output_file=str(out_file), + questions_path=str(qpath), + tables_path=str(tpath), + shots=1, + batch_size=1, + max_new_tokens=5, + temperature=0.7, + ) + + assert len(results) == 2 + assert all("question_id" in r and "completion" in r for r in results) + assert out_file.exists() + written = out_file.read_text().strip().splitlines() + assert len(written) == 2 + + +@pytest.mark.asyncio +async def test_generate_downloads_if_missing( + monkeypatch, mock_llm, fake_jsonl_files, tmp_path +): + """If paths not provided, should use _download_file.""" + qpath, tpath = fake_jsonl_files + (tmp_path / "questions.jsonl").unlink() # remove to force download + (tmp_path / "tables.jsonl").unlink() + + inf = inference.LLMSQLVLLMInference("dummy-model", workdir_path=str(tmp_path)) + + monkeypatch.setattr( + inference, + "load_jsonl", + lambda path: [{"question_id": "q1", "question": "x?", "table_id": "t1"}] + if "questions" in path + else [{"table_id": "t1", "header": ["id"], "types": ["int"], "rows": [[1]]}], + ) + monkeypatch.setattr(inference, "overwrite_jsonl", lambda path: None) + monkeypatch.setattr(inference, "save_jsonl_lines", lambda path, lines: None) + monkeypatch.setattr( + inference, "choose_prompt_builder", lambda shots: lambda *a: "PROMPT" + ) + + called = {"q": 0, "t": 0} + + def fake_download(filename, **_): + called["q" if "questions" in filename else "t"] += 1 + return str(tmp_path / filename) + + monkeypatch.setattr(inference, "hf_hub_download", fake_download) + + results = inf.generate(output_file=str(tmp_path / "out.jsonl")) + assert results + assert called["q"] == 1 + assert called["t"] == 1