Skip to content

Commit 3c56994

Browse files
CLI commands refactored
1 parent fb4a5f4 commit 3c56994

File tree

7 files changed

+786
-147
lines changed

7 files changed

+786
-147
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ dataset/sqlite_tables.db
33
*__pycache__
44
.env
55
dist/
6+
testenv
67

78
*.egg-info/
89
.pdm-python

llmsql/__main__.py

Lines changed: 5 additions & 147 deletions
Original file line numberDiff line numberDiff line change
@@ -1,153 +1,11 @@
1-
import argparse
2-
import inspect
3-
import json
1+
from llmsql._cli import LLMSQLCLI
42

53

64
def main() -> None:
    """Entry point for the ``llmsql`` console command.

    Builds the CLI parser object, parses the command line, and dispatches
    execution to the selected subcommand handler.
    """
    cli = LLMSQLCLI()
    namespace = cli.parse_args()
    cli.execute(namespace)
1519

15210

15311
if __name__ == "__main__":

llmsql/_cli/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
"""
CLI subcommands to run from the terminal.
"""

from .llmsql_cli import LLMSQLCLI

__all__ = ["LLMSQLCLI"]

llmsql/_cli/evaluation.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import argparse
2+
import json
3+
4+
from llmsql.config.config import DEFAULT_WORKDIR_PATH
5+
from llmsql.evaluation.evaluate import evaluate
6+
7+
8+
class EvaluationCommand:
    """CLI wrapper for the `evaluate()` function.

    `register()` attaches the `evaluate` subcommand to the top-level parser;
    `execute()` is installed as that subcommand's dispatcher via
    ``set_defaults(func=...)``.
    """

    @staticmethod
    def _str2bool(value: str) -> bool:
        """Parse a boolean token from the command line.

        ``type=bool`` on an argparse option is a classic bug: any non-empty
        string — including ``"False"`` — is truthy, so the flag could never be
        turned off. This helper maps common spellings to a real boolean and
        rejects anything else.

        Raises:
            argparse.ArgumentTypeError: if *value* is not a recognizable boolean.
        """
        token = value.strip().lower()
        if token in {"1", "true", "t", "yes", "y"}:
            return True
        if token in {"0", "false", "f", "no", "n"}:
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    @staticmethod
    def register(subparsers: argparse._SubParsersAction) -> None:
        """Register the `evaluate` subcommand and its options on *subparsers*."""
        eval_parser = subparsers.add_parser(
            "evaluate",
            help="Evaluate SQL predictions against the LLMSQL benchmark.",
            description="Evaluate predicted SQL against the LLMSQL benchmark.",
            formatter_class=argparse.RawTextHelpFormatter,
        )

        eval_parser.add_argument(
            "--outputs",
            required=True,
            help=(
                "Path to predictions JSONL file OR inline JSON list.\n"
                "Examples:\n"
                "  --outputs outputs/preds.jsonl\n"
                '  --outputs \'[{"id":1,"sql":"SELECT ..."}]\''
            ),
        )

        eval_parser.add_argument(
            "--workdir-path",
            type=str,
            default=None,
            help=(
                f"Optional. Where to store help .db files. Default: {DEFAULT_WORKDIR_PATH}"
            ),
        )

        eval_parser.add_argument(
            "--questions-path",
            type=str,
            default=None,
            help=(
                "Optional. Where is the questions of the benchmark stored. If not provided, questions will be downloaded."
            ),
        )

        eval_parser.add_argument(
            "--db-path",
            type=str,
            default=None,
            help=(
                "Optional. Where is the db file of the benchmark stored. If not provided, db file will be downloaded."
            ),
        )

        eval_parser.add_argument(
            "--save-report",
            type=str,
            default=None,
            help=(
                "Optional. Manual save path. If None → auto-generated with name 'evaluation_results_{uuid}.json'."
            ),
        )

        # Boolean toggle. NOTE: the previous `type=bool` made
        # `--show-mismatches False` evaluate to True (non-empty string is
        # truthy); `_str2bool` parses the token for real while keeping the
        # same "option takes one value" interface.
        eval_parser.add_argument(
            "--show-mismatches",
            type=EvaluationCommand._str2bool,
            default=True,
            help="Optional. Show mismatches during evaluation. Default: True.",
        )

        eval_parser.add_argument(
            "--max-mismatches",
            type=int,
            default=5,
            help="Optional. Number of mismatches to print. Default: 5",
        )

        # Dispatcher: `llmsql evaluate ...` routes to execute().
        eval_parser.set_defaults(func=EvaluationCommand.execute)

    # ----------------------------------------------------------
    @staticmethod
    def execute(args: argparse.Namespace) -> None:
        """Run evaluation and print the resulting report as indented JSON."""
        # Try inline JSON first; fall back to treating --outputs as a file
        # path. json.JSONDecodeError is a ValueError subclass, so catching
        # (ValueError, TypeError) keeps the deliberate fallback without
        # masking unrelated failures.
        try:
            parsed = json.loads(args.outputs)
            outputs = parsed if isinstance(parsed, list) else args.outputs
        except (ValueError, TypeError):
            outputs = args.outputs

        result = evaluate(
            outputs=outputs,
            workdir_path=args.workdir_path,
            questions_path=args.questions_path,
            db_path=args.db_path,
            save_report=args.save_report,
            show_mismatches=args.show_mismatches,
            max_mismatches=args.max_mismatches,
        )

        print(json.dumps(result, indent=2))

0 commit comments

Comments
 (0)