Commit d8442db

Merge pull request #47 from LLMSQL/46-add-custom-chat-template-handaling
46 add custom chat template handaling
2 parents 008de41 + e53f76b commit d8442db

5 files changed: +326 -161 lines

examples/inference_transformers.ipynb

Lines changed: 20 additions & 1 deletion
@@ -62,7 +62,26 @@
 ],
 "source": [
 "from llmsql import inference_transformers\n",
-"results = inference_transformers(model_or_model_name_or_path=\"EleutherAI/pythia-14m\", output_file=\"test_output.jsonl\", batch_size=5000, do_sample=False)"
+"\n",
+"# Example 1: Basic usage (same as before)\n",
+"results = inference_transformers(\n",
+"    model_or_model_name_or_path=\"EleutherAI/pythia-14m\",\n",
+"    output_file=\"test_output.jsonl\",\n",
+"    batch_size=5000,\n",
+"    do_sample=False,\n",
+")\n",
+"\n",
+"# # Example 2: Using the new kwargs for advanced options\n",
+"# results = inference_transformers(\n",
+"#     model_or_model_name_or_path=\"EleutherAI/pythia-14m\",\n",
+"#     output_file=\"test_output.jsonl\",\n",
+"#     batch_size=5000,\n",
+"#     do_sample=False,\n",
+"#     # Advanced model loading options\n",
+"#     model_kwargs={\"low_cpu_mem_usage\": True, \"attn_implementation\": \"flash_attention_2\"},\n",
+"#     # Advanced generation options\n",
+"#     generation_kwargs={\"repetition_penalty\": 1.1, \"length_penalty\": 1.0},\n",
+"# )"
 ]
 }
 ],

examples/inference_vllm.ipynb

Lines changed: 27 additions & 5 deletions
@@ -173,7 +173,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 1,
+"execution_count": null,
 "id": "edc910ac",
 "metadata": {},
 "outputs": [
@@ -333,12 +333,34 @@
 ],
 "source": [
 "from llmsql import inference_vllm\n",
+"\n",
+"# Basic usage (backward compatible)\n",
 "results = inference_vllm(\n",
-"    \"Qwen/Qwen2.5-1.5B-Instruct\",\n",
-"    \"test_results.jsonl\",\n",
+"    model_name=\"EleutherAI/pythia-14m\",\n",
+"    output_file=\"test_output.jsonl\",\n",
+"    batch_size=5000,\n",
 "    do_sample=False,\n",
-"    batch_size=20000\n",
-")"
+")\n",
+"\n",
+"# # Advanced usage with new kwargs\n",
+"# results = inference_vllm(\n",
+"#     model_name=\"EleutherAI/pythia-14m\",\n",
+"#     output_file=\"test_output.jsonl\",\n",
+"#     batch_size=5000,\n",
+"#     do_sample=False,\n",
+"#     # vLLM-specific options\n",
+"#     llm_kwargs={\n",
+"#         \"gpu_memory_utilization\": 0.9,\n",
+"#         \"max_model_len\": 4096,\n",
+"#         \"quantization\": \"awq\",\n",
+"#     },\n",
+"#     # Advanced sampling options\n",
+"#     sampling_kwargs={\n",
+"#         \"top_p\": 0.95,\n",
+"#         \"frequency_penalty\": 0.1,\n",
+"#         \"presence_penalty\": 0.1,\n",
+"#     },\n",
+"# )"
 ]
 },
 {

llmsql/inference/README.md

Lines changed: 84 additions & 53 deletions
@@ -33,21 +33,10 @@ pip install llmsql[vllm]
 from llmsql import inference_transformers

 results = inference_transformers(
-    model_or_model_name_or_path="Qwen/Qwen2.5-1.5B-Instruct",
-    output_file="outputs/preds_transformers.jsonl",
-    questions_path="data/questions.jsonl",
-    tables_path="data/tables.jsonl",
-    num_fewshots=5,
-    batch_size=8,
-    max_new_tokens=256,
-    temperature=0.7,
-    model_args={
-        "attn_implementation": "flash_attention_2",
-        "torch_dtype": "bfloat16",
-    },
-    generate_kwargs={
-        "do_sample": False,
-    },
+    model_or_model_name_or_path="EleutherAI/pythia-14m",
+    output_file="test_output.jsonl",
+    batch_size=5000,
+    do_sample=False,
 )
 ```
@@ -59,19 +48,10 @@ results = inference_transformers(
 from llmsql import inference_vllm

 results = inference_vllm(
-    model_name="Qwen/Qwen2.5-1.5B-Instruct",
-    output_file="outputs/preds_vllm.jsonl",
-    questions_path="data/questions.jsonl",
-    tables_path="data/tables.jsonl",
-    num_fewshots=5,
-    batch_size=8,
-    max_new_tokens=256,
+    model_name="EleutherAI/pythia-14m",
+    output_file="test_output.jsonl",
+    batch_size=5000,
     do_sample=False,
-    llm_kwargs={
-        "tensor_parallel_size": 1,
-        "gpu_memory_utilization": 0.9,
-        "max_model_len": 4096,
-    },
 )
 ```

@@ -97,8 +77,7 @@ llmsql inference --method transformers \
     --model-or-model-name-or-path Qwen/Qwen2.5-1.5B-Instruct \
     --output-file outputs/preds.jsonl \
     --batch-size 8 \
-    --temperature 0.9 \
-    --generate-kwargs '{"do_sample": false, "top_p": 0.95}'
+    --temperature 0.0
 ```

 👉 Run `llmsql inference --help` for more detailed examples and parameter options.
@@ -113,18 +92,49 @@ Runs inference using the Hugging Face `transformers` backend.

 **Parameters:**

-| Argument | Type | Description |
-| ------------------------------- | ------- | -------------------------------------------------------------- |
-| `model_or_model_name_or_path` | `str` | Model name or local path (any causal LM). |
-| `output_file` | `str` | Path to write predictions as JSONL. |
-| `questions_path`, `tables_path` | `str` | Benchmark files (auto-downloaded if missing). |
-| `num_fewshots` | `int` | Number of few-shot examples (0, 1, 5). |
-| `batch_size` | `int` | Batch size for inference. |
-| `max_new_tokens` | `int` | Maximum length of generated SQL queries. |
-| `temperature` | `float` | Sampling temperature. |
-| `do_sample` | `bool` | Whether to use sampling. |
-| `model_args` | `dict` | Extra kwargs passed to `AutoModelForCausalLM.from_pretrained`. |
-| `generate_kwargs` | `dict` | Extra kwargs passed to `model.generate()`. |
+#### Model Loading
+
+| Argument | Type | Default | Description |
+| ------------------------------- | --------------------- | ------------- | -------------------------------------------------------------- |
+| `model_or_model_name_or_path` | `str \| AutoModelForCausalLM` | *required* | Model object, Hugging Face model name, or local path. |
+| `tokenizer_or_name` | `str \| Any \| None` | `None` | Tokenizer object, name, or `None` (inferred from the model). |
+| `trust_remote_code` | `bool` | `True` | Whether to trust remote code when loading models. |
+| `dtype` | `torch.dtype` | `torch.float16` | Model precision (e.g., `torch.float16`, `torch.bfloat16`). |
+| `device_map` | `str \| dict \| None` | `"auto"` | Device placement strategy for multi-GPU. |
+| `hf_token` | `str \| None` | `None` | Hugging Face authentication token. |
+| `model_kwargs` | `dict \| None` | `None` | Additional kwargs for `AutoModelForCausalLM.from_pretrained()`. |
+| `tokenizer_kwargs` | `dict \| None` | `None` | Additional kwargs for `AutoTokenizer.from_pretrained()`. |
+
+#### Prompt & Chat
+
+| Argument | Type | Default | Description |
+| ------------------------------- | -------------- | ------- | ------------------------------------------------ |
+| `chat_template` | `str \| None` | `None` | Optional chat template string to apply. |
+
+#### Generation
+
+| Argument | Type | Default | Description |
+| ------------------------------- | -------------- | ------- | ------------------------------------------------ |
+| `max_new_tokens` | `int` | `256` | Maximum tokens to generate per sequence. |
+| `temperature` | `float` | `0.0` | Sampling temperature (0.0 = greedy). |
+| `do_sample` | `bool` | `False` | Whether to use sampling vs greedy decoding. |
+| `top_p` | `float` | `1.0` | Nucleus sampling parameter. |
+| `top_k` | `int` | `50` | Top-k sampling parameter. |
+| `generation_kwargs` | `dict \| None` | `None` | Additional kwargs for `model.generate()`. |
+
+#### Benchmark
+
+| Argument | Type | Default | Description |
+| ------------------------------- | ------- | ------------------------- | ------------------------------------------------ |
+| `output_file` | `str` | `"outputs/predictions.jsonl"` | Path to write predictions as JSONL. |
+| `questions_path` | `str \| None` | `None` | Path to questions.jsonl (auto-downloads if missing). |
+| `tables_path` | `str \| None` | `None` | Path to tables.jsonl (auto-downloads if missing). |
+| `workdir_path` | `str` | `"llmsql_workdir"` | Working directory for downloaded files. |
+| `num_fewshots` | `int` | `5` | Number of few-shot examples (0, 1, or 5). |
+| `batch_size` | `int` | `8` | Batch size for inference. |
+| `seed` | `int` | `42` | Random seed for reproducibility. |
+
+**Note:** Explicit parameters (e.g., `dtype`, `trust_remote_code`) override any values specified in `model_kwargs` or `tokenizer_kwargs`.

 ---

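Editor's note: the PR's headline feature, `chat_template`, is documented in the Prompt & Chat table above but not exercised in any example. A minimal sketch of how it might be passed, assuming the template is a Jinja-style string of the kind Hugging Face tokenizers accept; the template text and surrounding values are illustrative, not taken from the repo:

```python
from llmsql import inference_transformers

# Hypothetical Jinja-style template (assumption: llmsql hands this string to
# the tokenizer as a standard Hugging Face chat template).
CUSTOM_TEMPLATE = (
    "{% for message in messages %}"
    "{{ message['role'] }}: {{ message['content'] }}\n"
    "{% endfor %}"
)

results = inference_transformers(
    model_or_model_name_or_path="EleutherAI/pythia-14m",
    output_file="test_output.jsonl",
    batch_size=8,
    do_sample=False,
    chat_template=CUSTOM_TEMPLATE,
    # Per the note above, the explicit argument wins over the same key passed
    # through model_kwargs, so remote code stays disabled here.
    trust_remote_code=False,
    model_kwargs={"trust_remote_code": True},
)
```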
@@ -134,18 +144,39 @@ Runs inference using the [vLLM](https://github.com/vllm-project/vllm) backend fo

 **Parameters:**

-| Argument | Type | Description |
-| ------------------------------- | ------- | ------------------------------------------------ |
-| `model_name` | `str` | Hugging Face model name or path. |
-| `output_file` | `str` | Path to write predictions as JSONL. |
-| `questions_path`, `tables_path` | `str` | Benchmark files (auto-downloaded if missing). |
-| `num_fewshots` | `int` | Number of few-shot examples (0, 1, 5). |
-| `batch_size` | `int` | Number of prompts per batch. |
-| `max_new_tokens` | `int` | Maximum tokens per generation. |
-| `temperature` | `float` | Sampling temperature. |
-| `do_sample` | `bool` | Whether to sample or use greedy decoding. |
-| `llm_kwargs` | `dict` | Extra kwargs forwarded to `vllm.LLM`. |
-| `sampling_kwargs` | `dict` | Extra kwargs forwarded to `vllm.SamplingParams`. |
+#### Model Loading
+
+| Argument | Type | Default | Description |
+| ------------------------------- | -------------- | ------- | ------------------------------------------------ |
+| `model_name` | `str` | *required* | Hugging Face model name or local path. |
+| `trust_remote_code` | `bool` | `True` | Whether to trust remote code when loading. |
+| `tensor_parallel_size` | `int` | `1` | Number of GPUs for tensor parallelism. |
+| `hf_token` | `str \| None` | `None` | Hugging Face authentication token. |
+| `llm_kwargs` | `dict \| None` | `None` | Additional kwargs for `vllm.LLM()`. |
156+
| `llm_kwargs` | `bool` | `True` | Whether to use chat template of the tokenizer |
157+
158+
#### Generation
159+
160+
| Argument | Type | Default | Description |
161+
| ------------------------------- | -------------- | ------- | ------------------------------------------------ |
162+
| `max_new_tokens` | `int` | `256` | Maximum tokens to generate per sequence. |
163+
| `temperature` | `float` | `1.0` | Sampling temperature (0.0 = greedy). |
164+
| `do_sample` | `bool` | `True` | Whether to use sampling vs greedy decoding. |
165+
| `sampling_kwargs` | `dict \| None` | `None` | Additional kwargs for `vllm.SamplingParams()`. |
166+
167+
#### Benchmark
168+
169+
| Argument | Type | Default | Description |
170+
| ------------------------------- | -------------- | ----------------------------- | ------------------------------------------------ |
171+
| `output_file` | `str` | `"outputs/predictions.jsonl"` | Path to write predictions as JSONL. |
172+
| `questions_path` | `str \| None` | `None` | Path to questions.jsonl (auto-downloads if missing). |
173+
| `tables_path` | `str \| None` | `None` | Path to tables.jsonl (auto-downloads if missing). |
174+
| `workdir_path` | `str` | `"llmsql_workdir"` | Working directory for downloaded files. |
175+
| `num_fewshots` | `int` | `5` | Number of few-shot examples (0, 1, or 5). |
176+
| `batch_size` | `int` | `8` | Number of prompts per batch. |
177+
| `seed` | `int` | `42` | Random seed for reproducibility. |
178+
179+
**Note:** Explicit parameters (e.g., `tensor_parallel_size`, `trust_remote_code`) override any values specified in `llm_kwargs` or `sampling_kwargs`.
149180

150181
---
151182

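Editor's note: the two pass-through dicts documented above map onto vLLM's own entry points (`llm_kwargs` → `vllm.LLM()`, `sampling_kwargs` → `vllm.SamplingParams()`). A rough sketch of the equivalent raw vLLM calls; the exact forwarding is an assumption based on the table descriptions:

```python
from vllm import LLM, SamplingParams

# Approximately what inference_vllm(..., llm_kwargs=..., sampling_kwargs=...)
# would construct (assumed mapping, based on the tables above).
llm = LLM(
    model="EleutherAI/pythia-14m",
    trust_remote_code=True,      # explicit parameter; overrides llm_kwargs
    tensor_parallel_size=1,
    gpu_memory_utilization=0.9,  # forwarded from llm_kwargs
    max_model_len=2048,          # forwarded from llm_kwargs
)
params = SamplingParams(
    temperature=0.0,             # 0.0 = greedy, i.e. do_sample=False
    max_tokens=256,              # assumed counterpart of max_new_tokens
    top_p=0.95,                  # forwarded from sampling_kwargs
)
outputs = llm.generate(["<prompt built from question and table schema>"], params)
```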