Skip to content

Commit f309512

Browse files
second inference type added; vllm moved to optional deps; README updated; test written for new changes
1 parent 93c0de0 commit f309512

17 files changed

+1281
-736
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,5 @@ dist/
77
*.egg-info/
88
.pdm-python
99
.vscode
10+
11+
.coverage

README.md

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,19 @@ Modern LLMs are already strong at **producing SQL queries without finetuning**.
4040
We therefore recommend that most users:
4141

4242
1. **Run inference** directly on the full benchmark:
43-
- Use [`llmsql.LLMSQLVLLMInference`](./llmsql/inference/inference.py) (the main inference class) for generation of SQL predictions with your LLM from HF.
43+
model_or_model_name_or_path="Qwen/Qwen2.5-1.5B-Instruct",
44+
output_file="path_to_your_outputs.jsonl",
45+
- Use [`llmsql.inference_transformers`](./llmsql/inference/inference_transformers.py) (the function for transformers inference) for generation of SQL predictions with your model. If you want to do vllm based inference, use [`llmsql.inference_vllm`](./llmsql/inference/inference_vllm.py). Works both with HF model id, e.g. `Qwen/Qwen2.5-1.5B-Instruct` and model instance passed directly, e.g. `inference_transformers(model_or_model_name_or_path=model, ...)`
4446
- Evaluate results against the benchmark with the [`llmsql.LLMSQLEvaluator`](./llmsql/evaluation/evaluator.py) evaluator class.
4547

4648
2. **Optional finetuning**:
4749
- For research or domain adaptation, we provide finetuning script for HF models. Use `llmsql finetune --help` or read [Finetune Readme](./llmsql/finetune/README.md) to find more about finetuning.
4850

4951
> [!Tip]
5052
> You can find additional manuals in the README files of each folder([Inferece Readme](./llmsql/inference/README.md), [Evaluation Readme](./llmsql/evaluation/README.md), [Finetune Readme](./llmsql/finetune/README.md))
53+
54+
> [!Tip]
55+
> vllm based inference require vllm optional dependency group installed: `pip install llmsql[vllm]`
5156
---
5257

5358
## Repository Structure
@@ -77,24 +82,21 @@ pip3 install llmsql
7782
### 1. Run Inference
7883

7984
```python
80-
from llmsql import LLMSQLVLLMInference
85+
from llmsql import inference_transformers
8186

82-
# Initialize inference engine
83-
inference = LLMSQLVLLMInference(
84-
model_name="Qwen/Qwen2.5-1.5B-Instruct", # or any Hugging Face causal LM
85-
tensor_parallel_size=1,
86-
)
87-
88-
# Run generation
89-
results = inference.generate(
87+
# Run generation directly with transformers
88+
results = inference_transformers(
89+
model_or_model_name_or_path="Qwen/Qwen2.5-1.5B-Instruct",
9090
output_file="path_to_your_outputs.jsonl",
91-
questions_path="data/questions.jsonl",
92-
tables_path="data/tables.jsonl",
9391
shots=5,
9492
batch_size=8,
9593
max_new_tokens=256,
96-
temperature=0.7,
94+
do_sample=False,
95+
model_args={
96+
"torch_dtype": "bfloat16",
97+
}
9798
)
99+
98100
```
99101

100102
### 2. Evaluate Results

llmsql/__init__.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,32 @@
1+
"""
2+
LLMSQL — A Text2SQL benchmark for evaluation of Large Language Models
3+
"""
4+
15
__version__ = "0.1.4"
26

37

48
def __getattr__(name: str): # type: ignore
5-
if name == "LLMSQLVLLMInference":
6-
from .inference.inference import LLMSQLVLLMInference
7-
8-
return LLMSQLVLLMInference
9-
elif name == "LLMSQLEvaluator":
9+
if name == "LLMSQLEvaluator":
1010
from .evaluation.evaluator import LLMSQLEvaluator
1111

1212
return LLMSQLEvaluator
13+
elif name == "inference_vllm":
14+
try:
15+
from .inference.inference_vllm import inference_vllm
16+
17+
return inference_vllm
18+
except ModuleNotFoundError as e:
19+
if "vllm" in str(e):
20+
raise ImportError(
21+
"The vLLM backend is not installed. "
22+
"Install it with: pip install llmsql[vllm]"
23+
) from e
24+
raise
25+
elif name == "inference_transformers":
26+
from .inference.inference_transformers import inference_transformers
27+
28+
return inference_transformers
1329
raise AttributeError(f"module {__name__} has no attribute {name!r}")
1430

1531

16-
__all__ = ["LLMSQLEvaluator", "LLMSQLVLLMInference"]
32+
__all__ = ["LLMSQLEvaluator", "inference_vllm", "inference_transformers"]

llmsql/__main__.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
import argparse
2+
import inspect
23
import sys
34

45

56
def main() -> None:
67
parser = argparse.ArgumentParser(prog="llmsql", description="LLMSQL CLI")
78
subparsers = parser.add_subparsers(dest="command")
89

10+
# ================================================================
11+
# Fine-tuning command
12+
# ================================================================
913
ft_parser = subparsers.add_parser(
1014
"finetune",
1115
help="Fine-tune a causal LM on the LLMSQL benchmark.",
@@ -21,13 +25,132 @@ def main() -> None:
2125
help="Path to a YAML config file containing training parameters.",
2226
)
2327

28+
# ================================================================
29+
# Inference command
30+
# ================================================================
31+
inference_examples = r"""
32+
Examples:
33+
34+
# 1️⃣ Run inference with Transformers backend
35+
llmsql inference --method transformers \
36+
--model-or-model-name-or-path Qwen/Qwen2.5-1.5B-Instruct \
37+
--output-file outputs/preds_transformers.jsonl \
38+
--batch-size 8 \
39+
--shots 5
40+
41+
# 2️⃣ Run inference with vLLM backend
42+
llmsql inference --method vllm \
43+
--model-name Qwen/Qwen2.5-1.5B-Instruct \
44+
--output-file outputs/preds_vllm.jsonl \
45+
--batch-size 8 \
46+
--shots 5
47+
48+
# 3️⃣ Pass model-specific kwargs (for Transformers)
49+
llmsql inference --method transformers \
50+
--model-or-model-name-or-path meta-llama/Llama-3-8b-instruct \
51+
--output-file outputs/llama_preds.jsonl \
52+
--model-args '{"attn_implementation": "flash_attention_2", "torch_dtype": "bfloat16"}'
53+
54+
# 4️⃣ Pass LLM init kwargs (for vLLM)
55+
llmsql inference --method vllm \
56+
--model-name mistralai/Mixtral-8x7B-Instruct-v0.1 \
57+
--output-file outputs/mixtral_preds.jsonl \
58+
--llm-kwargs '{"max_model_len": 4096, "gpu_memory_utilization": 0.9}'
59+
60+
# 5️⃣ Override generation parameters dynamically
61+
llmsql inference --method transformers \
62+
--model-or-model-name-or-path Qwen/Qwen2.5-1.5B-Instruct \
63+
--output-file outputs/temp_0.9.jsonl \
64+
--temperature 0.9 \
65+
--generate-kwargs '{"do_sample": true, "top_p": 0.9, "top_k": 40}'
66+
"""
67+
68+
inf_parser = subparsers.add_parser(
69+
"inference",
70+
help="Run inference using either Transformers or vLLM backend.",
71+
description="Run SQL generation using a chosen inference method "
72+
"(either 'transformers' or 'vllm').",
73+
epilog=inference_examples,
74+
formatter_class=argparse.RawTextHelpFormatter,
75+
)
76+
77+
inf_parser.add_argument(
78+
"--method",
79+
type=str,
80+
required=True,
81+
choices=["transformers", "vllm"],
82+
help="Inference backend to use ('transformers' or 'vllm').",
83+
)
84+
85+
# Catch-all for remaining inference-specific args
86+
inf_parser.add_argument(
87+
"--args",
88+
nargs=argparse.REMAINDER,
89+
help="Additional arguments passed to the chosen inference function "
90+
"(e.g., --model-name MODEL --output-file FILE --batch-size 8 ...).",
91+
)
92+
93+
# ================================================================
94+
# Parse CLI
95+
# ================================================================
2496
args, extra = parser.parse_known_args()
2597

98+
# ------------------------------------------------
99+
# Fine-tune
100+
# ------------------------------------------------
26101
if args.command == "finetune":
27102
from llmsql.finetune import finetune
28103

29104
sys.argv = ["llmsql-finetune"] + extra + ["--config_file", args.config_file]
30105
finetune.run_cli()
106+
107+
# ------------------------------------------------
108+
# Inference
109+
# ------------------------------------------------
110+
elif args.command == "inference":
111+
if args.method == "vllm":
112+
from llmsql import inference_vllm as inference_fn
113+
elif args.method == "transformers":
114+
from llmsql import inference_transformers as inference_fn # type: ignore
115+
else:
116+
raise ValueError(f"Unknown inference method: {args.method}")
117+
118+
# Dynamically create parser from the function signature
119+
fn_parser = argparse.ArgumentParser(
120+
prog=f"llmsql inference --method {args.method}",
121+
description=f"Run inference using {args.method} backend",
122+
)
123+
124+
sig = inspect.signature(inference_fn)
125+
for name, param in sig.parameters.items():
126+
if name.startswith("**"):
127+
continue # skip **kwargs
128+
arg_name = f"--{name.replace('_', '-')}"
129+
default = param.default
130+
if default is inspect.Parameter.empty:
131+
fn_parser.add_argument(arg_name, required=True)
132+
else:
133+
if isinstance(default, bool):
134+
fn_parser.add_argument(
135+
arg_name,
136+
action="store_true" if not default else "store_false",
137+
help=f"(default: {default})",
138+
)
139+
else:
140+
fn_parser.add_argument(
141+
arg_name, type=type(default), default=default
142+
)
143+
144+
fn_args = fn_parser.parse_args(args.args or [])
145+
fn_kwargs = vars(fn_args)
146+
147+
print(f"🔹 Running {args.method} inference with arguments:")
148+
for k, v in fn_kwargs.items():
149+
print(f" {k}: {v}")
150+
151+
results = inference_fn(**fn_kwargs)
152+
print(f"✅ Inference complete. Generated {len(results)} results.")
153+
31154
else:
32155
parser.print_help()
33156

llmsql/config/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
REPO_ID = "llmsql-bench/llmsql-benchmark"
2+
DEFAULT_WORKDIR_PATH = "llmsql_workdir"

0 commit comments

Comments
 (0)