Skip to content

Commit d77d449

Browse files
Merge pull request #27 from LLMSQL/19-feature-add-transformers-inference
second inference type added; vllm moved to optional deps; README upda…
2 parents 93c0de0 + ca50b95 commit d77d449

20 files changed

+1325
-5472
lines changed

.github/workflows/tests.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ jobs:
3030
${{ runner.os }}-pdm-
3131
3232
- name: Install dependencies (with dev)
33-
run: pdm install --with dev
33+
run: pdm install --with dev,vllm
3434

3535
- name: Run tests with coverage
3636
run: PYTHONPATH=. pdm run pytest --cov=llmsql --cov-report=xml --maxfail=1 --disable-warnings -v
@@ -41,4 +41,4 @@ jobs:
4141
token: ${{ secrets.CODECOV_TOKEN }}
4242
files: ./coverage.xml
4343
flags: unittests
44-
fail_ci_if_error: true
44+
fail_ci_if_error: false

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,6 @@ dist/
77
*.egg-info/
88
.pdm-python
99
.vscode
10+
11+
.coverage
12+
llmsql_workdir

README.md

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,19 @@ Modern LLMs are already strong at **producing SQL queries without finetuning**.
4040
We therefore recommend that most users:
4141

4242
1. **Run inference** directly on the full benchmark:
43-
- Use [`llmsql.LLMSQLVLLMInference`](./llmsql/inference/inference.py) (the main inference class) for generation of SQL predictions with your LLM from HF.
43+
model_or_model_name_or_path="Qwen/Qwen2.5-1.5B-Instruct",
44+
output_file="path_to_your_outputs.jsonl",
45+
- Use [`llmsql.inference_transformers`](./llmsql/inference/inference_transformers.py) (the function for transformers inference) for generation of SQL predictions with your model. If you want to do vllm based inference, use [`llmsql.inference_vllm`](./llmsql/inference/inference_vllm.py). Works with both an HF model id, e.g. `Qwen/Qwen2.5-1.5B-Instruct`, and a model instance passed directly, e.g. `inference_transformers(model_or_model_name_or_path=model, ...)`
4446
- Evaluate results against the benchmark with the [`llmsql.LLMSQLEvaluator`](./llmsql/evaluation/evaluator.py) evaluator class.
4547

4648
2. **Optional finetuning**:
4749
- For research or domain adaptation, we provide finetuning script for HF models. Use `llmsql finetune --help` or read [Finetune Readme](./llmsql/finetune/README.md) to find more about finetuning.
4850

4951
> [!Tip]
5052
> You can find additional manuals in the README files of each folder ([Inference Readme](./llmsql/inference/README.md), [Evaluation Readme](./llmsql/evaluation/README.md), [Finetune Readme](./llmsql/finetune/README.md))
53+
54+
> [!Tip]
55+
> vllm based inference requires the vllm optional dependency group to be installed: `pip install llmsql[vllm]`
5156
---
5257

5358
## Repository Structure
@@ -77,24 +82,21 @@ pip3 install llmsql
7782
### 1. Run Inference
7883

7984
```python
80-
from llmsql import LLMSQLVLLMInference
85+
from llmsql import inference_transformers
8186

82-
# Initialize inference engine
83-
inference = LLMSQLVLLMInference(
84-
model_name="Qwen/Qwen2.5-1.5B-Instruct", # or any Hugging Face causal LM
85-
tensor_parallel_size=1,
86-
)
87-
88-
# Run generation
89-
results = inference.generate(
87+
# Run generation directly with transformers
88+
results = inference_transformers(
89+
model_or_model_name_or_path="Qwen/Qwen2.5-1.5B-Instruct",
9090
output_file="path_to_your_outputs.jsonl",
91-
questions_path="data/questions.jsonl",
92-
tables_path="data/tables.jsonl",
93-
shots=5,
91+
num_fewshots=5,
9492
batch_size=8,
9593
max_new_tokens=256,
96-
temperature=0.7,
94+
do_sample=False,
95+
model_args={
96+
"torch_dtype": "bfloat16",
97+
}
9798
)
99+
98100
```
99101

100102
### 2. Evaluate Results

examples/inference.ipynb

Lines changed: 0 additions & 4728 deletions
This file was deleted.

llmsql/__init__.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,32 @@
1+
"""
2+
LLMSQL — A Text2SQL benchmark for evaluation of Large Language Models
3+
"""
4+
15
__version__ = "0.1.4"
26

37

48
def __getattr__(name: str): # type: ignore
5-
if name == "LLMSQLVLLMInference":
6-
from .inference.inference import LLMSQLVLLMInference
7-
8-
return LLMSQLVLLMInference
9-
elif name == "LLMSQLEvaluator":
9+
if name == "LLMSQLEvaluator":
1010
from .evaluation.evaluator import LLMSQLEvaluator
1111

1212
return LLMSQLEvaluator
13+
elif name == "inference_vllm":
14+
try:
15+
from .inference.inference_vllm import inference_vllm
16+
17+
return inference_vllm
18+
except ModuleNotFoundError as e:
19+
if "vllm" in str(e):
20+
raise ImportError(
21+
"The vLLM backend is not installed. "
22+
"Install it with: pip install llmsql[vllm]"
23+
) from e
24+
raise
25+
elif name == "inference_transformers":
26+
from .inference.inference_transformers import inference_transformers
27+
28+
return inference_transformers
1329
raise AttributeError(f"module {__name__} has no attribute {name!r}")
1430

1531

16-
__all__ = ["LLMSQLEvaluator", "LLMSQLVLLMInference"]
32+
__all__ = ["LLMSQLEvaluator", "inference_vllm", "inference_transformers"]

llmsql/__main__.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
import argparse
2+
import inspect
3+
import json
24
import sys
35

46

57
def main() -> None:
68
parser = argparse.ArgumentParser(prog="llmsql", description="LLMSQL CLI")
79
subparsers = parser.add_subparsers(dest="command")
810

11+
# ================================================================
12+
# Fine-tuning command
13+
# ================================================================
914
ft_parser = subparsers.add_parser(
1015
"finetune",
1116
help="Fine-tune a causal LM on the LLMSQL benchmark.",
@@ -21,13 +26,154 @@ def main() -> None:
2126
help="Path to a YAML config file containing training parameters.",
2227
)
2328

29+
# ================================================================
30+
# Inference command
31+
# ================================================================
32+
inference_examples = r"""
33+
Examples:
34+
35+
# 1️⃣ Run inference with Transformers backend
36+
llmsql inference --method transformers \
37+
--model-or-model-name-or-path Qwen/Qwen2.5-1.5B-Instruct \
38+
--output-file outputs/preds_transformers.jsonl \
39+
--batch-size 8 \
40+
--num-fewshots 5
41+
42+
# 2️⃣ Run inference with vLLM backend
43+
llmsql inference --method vllm \
44+
--model-name Qwen/Qwen2.5-1.5B-Instruct \
45+
--output-file outputs/preds_vllm.jsonl \
46+
--batch-size 8 \
47+
--num-fewshots 5
48+
49+
# 3️⃣ Pass model-specific kwargs (for Transformers)
50+
llmsql inference --method transformers \
51+
--model-or-model-name-or-path meta-llama/Llama-3-8b-instruct \
52+
--output-file outputs/llama_preds.jsonl \
53+
--model-args '{"attn_implementation": "flash_attention_2", "torch_dtype": "bfloat16"}'
54+
55+
# 4️⃣ Pass LLM init kwargs (for vLLM)
56+
llmsql inference --method vllm \
57+
--model-name mistralai/Mixtral-8x7B-Instruct-v0.1 \
58+
--output-file outputs/mixtral_preds.jsonl \
59+
--llm-kwargs '{"max_model_len": 4096, "gpu_memory_utilization": 0.9}'
60+
61+
# 5️⃣ Override generation parameters dynamically
62+
llmsql inference --method transformers \
63+
--model-or-model-name-or-path Qwen/Qwen2.5-1.5B-Instruct \
64+
--output-file outputs/temp_0.9.jsonl \
65+
--temperature 0.9 \
66+
--generate-kwargs '{"do_sample": true, "top_p": 0.9, "top_k": 40}'
67+
"""
68+
69+
inf_parser = subparsers.add_parser(
70+
"inference",
71+
help="Run inference using either Transformers or vLLM backend.",
72+
description="Run SQL generation using a chosen inference method "
73+
"(either 'transformers' or 'vllm').",
74+
epilog=inference_examples,
75+
formatter_class=argparse.RawTextHelpFormatter,
76+
)
77+
78+
inf_parser.add_argument(
79+
"--method",
80+
type=str,
81+
required=True,
82+
choices=["transformers", "vllm"],
83+
help="Inference backend to use ('transformers' or 'vllm').",
84+
)
85+
86+
# ================================================================
87+
# Parse CLI
88+
# ================================================================
2489
args, extra = parser.parse_known_args()
2590

91+
# ------------------------------------------------
92+
# Fine-tune
93+
# ------------------------------------------------
2694
if args.command == "finetune":
2795
from llmsql.finetune import finetune
2896

2997
sys.argv = ["llmsql-finetune"] + extra + ["--config_file", args.config_file]
3098
finetune.run_cli()
99+
100+
# ------------------------------------------------
101+
# Inference
102+
# ------------------------------------------------
103+
elif args.command == "inference":
104+
if args.method == "vllm":
105+
from llmsql import inference_vllm as inference_fn
106+
elif args.method == "transformers":
107+
from llmsql import inference_transformers as inference_fn # type: ignore
108+
else:
109+
raise ValueError(f"Unknown inference method: {args.method}")
110+
111+
# Dynamically create parser from the function signature
112+
fn_parser = argparse.ArgumentParser(
113+
prog=f"llmsql inference --method {args.method}",
114+
description=f"Run inference using {args.method} backend",
115+
)
116+
117+
sig = inspect.signature(inference_fn)
118+
for name, param in sig.parameters.items():
119+
if param.kind == inspect.Parameter.VAR_KEYWORD:
120+
fn_parser.add_argument(
121+
"--llm-kwargs",
122+
default="{}",
123+
help="Additional LLM kwargs as a JSON string, e.g. '{\"top_p\": 0.9}'",
124+
)
125+
fn_parser.add_argument(
126+
"--generate-kwargs",
127+
default="{}",
128+
help="",
129+
)
130+
continue
131+
arg_name = f"--{name.replace('_', '-')}"
132+
default = param.default
133+
if default is inspect.Parameter.empty:
134+
fn_parser.add_argument(arg_name, required=True)
135+
else:
136+
if isinstance(default, bool):
137+
fn_parser.add_argument(
138+
arg_name,
139+
action="store_true" if not default else "store_false",
140+
help=f"(default: {default})",
141+
)
142+
elif default is None:
143+
fn_parser.add_argument(arg_name, type=str, default=None)
144+
else:
145+
fn_parser.add_argument(
146+
arg_name, type=type(default), default=default
147+
)
148+
149+
fn_args = fn_parser.parse_args(extra)
150+
fn_kwargs = vars(fn_args)
151+
152+
if "llm_kwargs" in fn_kwargs and isinstance(fn_kwargs["llm_kwargs"], str):
153+
try:
154+
fn_kwargs["llm_kwargs"] = json.loads(fn_kwargs["llm_kwargs"])
155+
except json.JSONDecodeError:
156+
print("⚠️ Could not parse --llm-kwargs JSON, passing as string.")
157+
158+
if fn_kwargs.get("model_args") is not None:
159+
try:
160+
fn_kwargs["model_args"] = json.loads(fn_kwargs["model_args"])
161+
except json.JSONDecodeError:
162+
raise
163+
164+
if fn_kwargs.get("generate_kwargs") is not None:
165+
try:
166+
fn_kwargs["generate_kwargs"] = json.loads(fn_kwargs["generate_kwargs"])
167+
except json.JSONDecodeError:
168+
raise
169+
170+
print(f"🔹 Running {args.method} inference with arguments:")
171+
for k, v in fn_kwargs.items():
172+
print(f" {k}: {v}")
173+
174+
results = inference_fn(**fn_kwargs)
175+
print(f"✅ Inference complete. Generated {len(results)} results.")
176+
31177
else:
32178
parser.print_help()
33179

llmsql/config/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
REPO_ID = "llmsql-bench/llmsql-benchmark"
2+
DEFAULT_WORKDIR_PATH = "llmsql_workdir"

0 commit comments

Comments
 (0)