|
1 | | -# import argparse |
2 | | -# import json |
3 | | -# import textwrap |
4 | | -# from typing import Any |
5 | | - |
6 | | -# from llmsql import inference_transformers, inference_vllm |
7 | | -# from llmsql.loggers.logging_config import log |
8 | | - |
9 | | - |
10 | | -# def _json_arg(value: str) -> dict[str, Any]: |
11 | | -# """Parse a JSON string into dict.""" |
12 | | -# if value is None: |
13 | | -# return None |
14 | | -# try: |
15 | | -# return json.loads(value) |
16 | | -# except json.JSONDecodeError: |
17 | | -# return json.JSONDecodeError(f"Invalid JSON: {value}") |
18 | | - |
19 | | - |
20 | | -# class InferenceCommand: |
21 | | -# """CLI registration + dispatch for `llmsql inference`.""" |
22 | | - |
23 | | -# @staticmethod |
24 | | -# def register(subparsers): |
25 | | -# inference_examples = textwrap.dedent(""" |
26 | | -# Examples: |
27 | | - |
28 | | -# # 1️⃣ Run inference with Transformers backend |
29 | | -# llmsql inference --method transformers \\ |
30 | | -# --model-or-model-name-or-path Qwen/Qwen2.5-1.5B-Instruct \\ |
31 | | -# --output-file outputs/preds_transformers.jsonl \\ |
32 | | -# --batch-size 8 \\ |
33 | | -# --num-fewshots 5 |
34 | | - |
35 | | -# # 2️⃣ Run inference with vLLM backend |
36 | | -# llmsql inference --method vllm \\ |
37 | | -# --model-name Qwen/Qwen2.5-1.5B-Instruct \\ |
38 | | -# --output-file outputs/preds_vllm.jsonl \\ |
39 | | -# --batch-size 8 \\ |
40 | | -# --num-fewshots 5 |
41 | | - |
42 | | -# # 3️⃣ Pass model init kwargs (Transformers) |
43 | | -# llmsql inference --method transformers \\ |
44 | | -# --model-or-model-name-or-path meta-llama/Llama-3-8b-instruct \\ |
45 | | -# --output-file outputs/llama_preds.jsonl \\ |
46 | | -# --model-kwargs '{"attn_implementation": "flash_attention_2", "torch_dtype": "bfloat16"}' |
47 | | - |
48 | | -# # 4️⃣ Pass LLM init kwargs (vLLM) |
49 | | -# llmsql inference --method vllm \\ |
50 | | -# --model-name mistralai/Mixtral-8x7B-Instruct-v0.1 \\ |
51 | | -# --output-file outputs/mixtral_preds.jsonl \\ |
52 | | -# --llm-kwargs '{"max_model_len": 4096, "gpu_memory_utilization": 0.9}' |
53 | | - |
54 | | -# # 5️⃣ Override generation parameters dynamically (Transformers) |
55 | | -# llmsql inference --method transformers \\ |
56 | | -# --model-or-model-name-or-path Qwen/Qwen2.5-1.5B-Instruct \\ |
57 | | -# --output-file outputs/temp_0.9.jsonl \\ |
58 | | -# --temperature 0.9 \\ |
59 | | -# --generation-kwargs '{"do_sample": true, "top_p": 0.9, "top_k": 40}' |
60 | | - |
61 | | -# See the `llmsql.inference_transformers()` and `llmsql.inference_vllm()` for more information about the arguments. |
62 | | -# """) |
63 | | - |
64 | | -# parser = subparsers.add_parser( |
65 | | -# "inference", |
66 | | -# help="Run inference using Transformers or vLLM backend.", |
67 | | -# description="Run SQL generation using a chosen inference backend.", |
68 | | -# epilog=inference_examples, |
69 | | -# formatter_class=argparse.RawTextHelpFormatter, |
70 | | -# ) |
71 | | - |
72 | | -# parser.add_argument( |
73 | | -# "--method", |
74 | | -# required=True, |
75 | | -# choices=["vllm", "transformers"], |
76 | | -# help="Backend: 'vllm' or 'transformers'", |
77 | | -# ) |
78 | | - |
79 | | -# # This parser only parses --method. The backend-specific parsers come later. |
80 | | -# parser.set_defaults(func=InferenceCommand.dispatch) |
81 | | - |
82 | | -# # ------------------------------------------------------------------ |
83 | | -# @staticmethod |
84 | | -# def dispatch(args): |
85 | | -# """Dispatch to vLLM or Transformers backend.""" |
86 | | -# if args.method == "vllm": |
87 | | -# InferenceCommand._run_vllm() |
88 | | -# else: |
89 | | -# InferenceCommand._run_transformers() |
90 | | - |
91 | | -# # ------------------------------------------------------------------ |
92 | | -# @staticmethod |
93 | | -# def _run_vllm(): |
94 | | -# parser = argparse.ArgumentParser( |
95 | | -# prog="llmsql inference --method vllm", |
96 | | -# description="Inference using vLLM backend", |
97 | | -# help="Something for vllm", |
98 | | -# ) |
99 | | - |
100 | | -# parser.add_argument("--model-name", required=True) |
101 | | - |
102 | | -# parser.add_argument("--trust-remote-code", action="store_true", default=True) |
103 | | -# parser.add_argument("--tensor-parallel-size", type=int, default=1) |
104 | | -# parser.add_argument("--hf-token", type=str, default=None) |
105 | | -# parser.add_argument("--llm-kwargs", type=_json_arg, default=None) |
106 | | -# parser.add_argument("--use-chat-template", action="store_true", default=True) |
107 | | - |
108 | | -# parser.add_argument("--max-new-tokens", type=int, default=256) |
109 | | -# parser.add_argument("--temperature", type=float, default=1.0) |
110 | | -# parser.add_argument("--do-sample", action="store_true", default=True) |
111 | | -# parser.add_argument("--sampling-kwargs", type=_json_arg, default=None) |
112 | | - |
113 | | -# parser.add_argument("--output-file", default="llm_sql_predictions.jsonl") |
114 | | -# parser.add_argument("--questions-path", type=str, default=None) |
115 | | -# parser.add_argument("--tables-path", type=str, default=None) |
116 | | -# parser.add_argument("--workdir-path", default=None) |
117 | | -# parser.add_argument("--num-fewshots", type=int, default=5) |
118 | | -# parser.add_argument("--batch-size", type=int, default=8) |
119 | | -# parser.add_argument("--seed", type=int, default=42) |
120 | | - |
121 | | -# args = parser.parse_args() |
122 | | -# results = inference_vllm(**vars(args)) |
123 | | -# log.info(f"Generated {len(results)} results.") |
124 | | - |
125 | | -# # ------------------------------------------------------------------ |
126 | | -# @staticmethod |
127 | | -# def _run_transformers(): |
128 | | -# parser = argparse.ArgumentParser( |
129 | | -# prog="llmsql inference --method transformers", |
130 | | -# description="Inference using Transformers backend", |
131 | | -# help="Something for transformers", |
132 | | -# ) |
133 | | - |
134 | | -# parser.add_argument("--model-or-model-name-or-path", required=True) |
135 | | -# parser.add_argument("--tokenizer-or-name", default=None) |
136 | | - |
137 | | -# parser.add_argument("--trust-remote-code", action="store_true", default=True) |
138 | | -# parser.add_argument("--dtype", default="float16") |
139 | | -# parser.add_argument("--device-map", default="auto") |
140 | | -# parser.add_argument("--hf-token", type=str, default=None) |
141 | | -# parser.add_argument("--model-kwargs", type=_json_arg, default=None) |
142 | | - |
143 | | -# parser.add_argument("--tokenizer-kwargs", type=_json_arg, default=None) |
144 | | - |
145 | | -# parser.add_argument("--chat-template", type=str, default=None) |
146 | | - |
147 | | -# parser.add_argument("--max-new-tokens", type=int, default=256) |
148 | | -# parser.add_argument("--temperature", type=float, default=0.0) |
149 | | -# parser.add_argument("--do-sample", action="store_true", default=False) |
150 | | -# parser.add_argument("--top-p", type=float, default=1.0) |
151 | | -# parser.add_argument("--top-k", type=int, default=50) |
152 | | -# parser.add_argument("--generation-kwargs", type=_json_arg, default=None) |
153 | | - |
154 | | -# parser.add_argument("--output-file", default="llm_sql_predictions.jsonl") |
155 | | -# parser.add_argument("--questions-path", type=str, default=None) |
156 | | -# parser.add_argument("--tables-path", type=str, default=None) |
157 | | -# parser.add_argument("--workdir-path", default=None) |
158 | | -# parser.add_argument("--num-fewshots", type=int, default=5) |
159 | | -# parser.add_argument("--batch-size", type=int, default=8) |
160 | | -# parser.add_argument("--seed", type=int, default=42) |
161 | | - |
162 | | -# args = parser.parse_args() |
163 | | -# results = inference_transformers(**vars(args)) |
164 | | -# print(f"Generated {len(results)} results.") |
165 | | - |
166 | 1 | import argparse |
167 | 2 | import json |
168 | 3 | import textwrap |
|
0 commit comments