|
1 | | -import argparse |
2 | | -import inspect |
3 | | -import json |
| 1 | +from llmsql._cli import LLMSQLCLI |
4 | 2 |
|
5 | 3 |
|
def main() -> None:
    """Entry point for the ``llmsql`` command-line interface.

    Constructs the CLI object, parses the command-line arguments,
    and dispatches execution to the selected sub-command.
    """
    cli = LLMSQLCLI()
    cli.execute(cli.parse_args())
151 | 9 |
|
152 | 10 |
|
153 | 11 | if __name__ == "__main__": |
|
0 commit comments