diff --git a/examples/inference_transformers.ipynb b/examples/inference_transformers.ipynb
index 85fc868..a59bfa7 100644
--- a/examples/inference_transformers.ipynb
+++ b/examples/inference_transformers.ipynb
@@ -62,7 +62,26 @@
    ],
    "source": [
     "from llmsql import inference_transformers\n",
-    "results = inference_transformers(model_or_model_name_or_path=\"EleutherAI/pythia-14m\", output_file=\"test_output.jsonl\", batch_size=5000, do_sample=False)"
+    "\n",
+    "# Example 1: Basic usage (same as before)\n",
+    "results = inference_transformers(\n",
+    "    model_or_model_name_or_path=\"EleutherAI/pythia-14m\",\n",
+    "    output_file=\"test_output.jsonl\",\n",
+    "    batch_size=5000,\n",
+    "    do_sample=False,\n",
+    ")\n",
+    "\n",
+    "# # Example 2: Using the new kwargs for advanced options\n",
+    "# results = inference_transformers(\n",
+    "#     model_or_model_name_or_path=\"EleutherAI/pythia-14m\",\n",
+    "#     output_file=\"test_output.jsonl\",\n",
+    "#     batch_size=5000,\n",
+    "#     do_sample=False,\n",
+    "#     # Advanced model loading options\n",
+    "#     model_kwargs={\"low_cpu_mem_usage\": True, \"attn_implementation\": \"flash_attention_2\"},\n",
+    "#     # Advanced generation options\n",
+    "#     generation_kwargs={\"repetition_penalty\": 1.1, \"length_penalty\": 1.0},\n",
+    "# )"
    ]
   }
  ],
diff --git a/examples/inference_vllm.ipynb b/examples/inference_vllm.ipynb
index 1622a9b..6e26683 100644
--- a/examples/inference_vllm.ipynb
+++ b/examples/inference_vllm.ipynb
@@ -173,7 +173,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": null,
    "id": "edc910ac",
    "metadata": {},
    "outputs": [
@@ -333,12 +333,34 @@
    ],
    "source": [
     "from llmsql import inference_vllm\n",
+    "\n",
+    "# Basic usage (backward compatible)\n",
     "results = inference_vllm(\n",
-    "    \"Qwen/Qwen2.5-1.5B-Instruct\",\n",
-    "    \"test_results.jsonl\",\n",
+    "    model_name=\"EleutherAI/pythia-14m\",\n",
+    "    output_file=\"test_output.jsonl\",\n",
+    "    batch_size=5000,\n",
     "    do_sample=False,\n",
-    "    batch_size=20000\n",
-    ")"
+    ")\n",
+    "\n",
+    "# # Advanced usage with new kwargs\n",
+    "# results = inference_vllm(\n",
+    "#     model_name=\"EleutherAI/pythia-14m\",\n",
+    "#     output_file=\"test_output.jsonl\",\n",
+    "#     batch_size=5000,\n",
+    "#     do_sample=False,\n",
+    "#     # vLLM-specific options\n",
+    "#     llm_kwargs={\n",
+    "#         \"gpu_memory_utilization\": 0.9,\n",
+    "#         \"max_model_len\": 4096,\n",
+    "#         \"quantization\": \"awq\",\n",
+    "#     },\n",
+    "#     # Advanced sampling options\n",
+    "#     sampling_kwargs={\n",
+    "#         \"top_p\": 0.95,\n",
+    "#         \"frequency_penalty\": 0.1,\n",
+    "#         \"presence_penalty\": 0.1,\n",
+    "#     },\n",
+    "# )"
    ]
   },
   {
diff --git a/llmsql/inference/README.md b/llmsql/inference/README.md
index 13c4e48..352d13d 100644
--- a/llmsql/inference/README.md
+++ b/llmsql/inference/README.md
@@ -33,21 +33,10 @@ pip install llmsql[vllm]
 from llmsql import inference_transformers
 
 results = inference_transformers(
-    model_or_model_name_or_path="Qwen/Qwen2.5-1.5B-Instruct",
-    output_file="outputs/preds_transformers.jsonl",
-    questions_path="data/questions.jsonl",
-    tables_path="data/tables.jsonl",
-    num_fewshots=5,
-    batch_size=8,
-    max_new_tokens=256,
-    temperature=0.7,
-    model_args={
-        "attn_implementation": "flash_attention_2",
-        "torch_dtype": "bfloat16",
-    },
-    generate_kwargs={
-        "do_sample": False,
-    },
+    model_or_model_name_or_path="EleutherAI/pythia-14m",
+    output_file="test_output.jsonl",
+    batch_size=5000,
+    do_sample=False,
 )
 ```
 
@@ -59,19 +48,10 @@ results = inference_transformers(
 from llmsql import inference_vllm
 
 results = inference_vllm(
-    model_name="Qwen/Qwen2.5-1.5B-Instruct",
model_name="Qwen/Qwen2.5-1.5B-Instruct", - output_file="outputs/preds_vllm.jsonl", - questions_path="data/questions.jsonl", - tables_path="data/tables.jsonl", - num_fewshots=5, - batch_size=8, - max_new_tokens=256, + model_name="EleutherAI/pythia-14m", + output_file="test_output.jsonl", + batch_size=5000, do_sample=False, - llm_kwargs={ - "tensor_parallel_size": 1, - "gpu_memory_utilization": 0.9, - "max_model_len": 4096, - }, ) ``` @@ -97,8 +77,7 @@ llmsql inference --method transformers \ --model-or-model-name-or-path Qwen/Qwen2.5-1.5B-Instruct \ --output-file outputs/preds.jsonl \ --batch-size 8 \ - --temperature 0.9 \ - --generate-kwargs '{"do_sample": false, "top_p": 0.95}' + --temperature 0.0 \ ``` 👉 Run `llmsql inference --help` for more detailed examples and parameter options. @@ -113,18 +92,49 @@ Runs inference using the Hugging Face `transformers` backend. **Parameters:** -| Argument | Type | Description | -| ------------------------------- | ------- | -------------------------------------------------------------- | -| `model_or_model_name_or_path` | `str` | Model name or local path (any causal LM). | -| `output_file` | `str` | Path to write predictions as JSONL. | -| `questions_path`, `tables_path` | `str` | Benchmark files (auto-downloaded if missing). | -| `num_fewshots` | `int` | Number of few-shot examples (0, 1, 5). | -| `batch_size` | `int` | Batch size for inference. | -| `max_new_tokens` | `int` | Maximum length of generated SQL queries. | -| `temperature` | `float` | Sampling temperature. | -| `do_sample` | `bool` | Whether to use sampling. | -| `model_args` | `dict` | Extra kwargs passed to `AutoModelForCausalLM.from_pretrained`. | -| `generate_kwargs` | `dict` | Extra kwargs passed to `model.generate()`. | +#### Model Loading + +| Argument | Type | Default | Description | +| ------------------------------- | --------------------- | ------------- | -------------------------------------------------------------- | +| `model_or_model_name_or_path` | `str \| AutoModelForCausalLM` | *required* | Model object, HuggingFace model name, or local path. | +| `tokenizer_or_name` | `str \| Any \| None` | `None` | Tokenizer object, name, or None (infers from model). | +| `trust_remote_code` | `bool` | `True` | Whether to trust remote code when loading models. | +| `dtype` | `torch.dtype` | `torch.float16` | Model precision (e.g., `torch.float16`, `torch.bfloat16`). | +| `device_map` | `str \| dict \| None` | `"auto"` | Device placement strategy for multi-GPU. | +| `hf_token` | `str \| None` | `None` | Hugging Face authentication token. | +| `model_kwargs` | `dict \| None` | `None` | Additional kwargs for `AutoModelForCausalLM.from_pretrained()`. | +| `tokenizer_kwargs` | `dict \| None` | `None` | Additional kwargs for `AutoTokenizer.from_pretrained()`. | + +#### Prompt & Chat + +| Argument | Type | Default | Description | +| ------------------------------- | -------------- | ------- | ------------------------------------------------ | +| `chat_template` | `str \| None` | `None` | Optional chat template string to apply. | + +#### Generation + +| Argument | Type | Default | Description | +| ------------------------------- | -------------- | ------- | ------------------------------------------------ | +| `max_new_tokens` | `int` | `256` | Maximum tokens to generate per sequence. | +| `temperature` | `float` | `0.0` | Sampling temperature (0.0 = greedy). | +| `do_sample` | `bool` | `False` | Whether to use sampling vs greedy decoding. 
+| `top_p` | `float` | `1.0` | Nucleus sampling parameter. |
+| `top_k` | `int` | `50` | Top-k sampling parameter. |
+| `generation_kwargs` | `dict \| None` | `None` | Additional kwargs for `model.generate()`. |
+
+#### Benchmark
+
+| Argument | Type | Default | Description |
+| ------------------------------- | ------- | ------------------------- | ------------------------------------------------ |
+| `output_file` | `str` | `"llm_sql_predictions.jsonl"` | Path to write predictions as JSONL. |
+| `questions_path` | `str \| None` | `None` | Path to questions.jsonl (auto-downloads if missing). |
+| `tables_path` | `str \| None` | `None` | Path to tables.jsonl (auto-downloads if missing). |
+| `workdir_path` | `str` | `"llmsql_workdir"` | Working directory for downloaded files. |
+| `num_fewshots` | `int` | `5` | Number of few-shot examples (0, 1, or 5). |
+| `batch_size` | `int` | `8` | Batch size for inference. |
+| `seed` | `int` | `42` | Random seed for reproducibility. |
+
+**Note:** Explicit parameters (e.g., `dtype`, `trust_remote_code`) override any values specified in `model_kwargs` or `tokenizer_kwargs`.
 
 ---
 
@@ -134,18 +144,39 @@ Runs inference using the [vLLM](https://github.com/vllm-project/vllm) backend fo
 
 **Parameters:**
 
-| Argument | Type | Description |
-| ------------------------------- | ------- | ------------------------------------------------ |
-| `model_name` | `str` | Hugging Face model name or path. |
-| `output_file` | `str` | Path to write predictions as JSONL. |
-| `questions_path`, `tables_path` | `str` | Benchmark files (auto-downloaded if missing). |
-| `num_fewshots` | `int` | Number of few-shot examples (0, 1, 5). |
-| `batch_size` | `int` | Number of prompts per batch. |
-| `max_new_tokens` | `int` | Maximum tokens per generation. |
-| `temperature` | `float` | Sampling temperature. |
-| `do_sample` | `bool` | Whether to sample or use greedy decoding. |
-| `llm_kwargs` | `dict` | Extra kwargs forwarded to `vllm.LLM`. |
-| `sampling_kwargs` | `dict` | Extra kwargs forwarded to `vllm.SamplingParams`. |
+#### Model Loading
+
+| Argument | Type | Default | Description |
+| ------------------------------- | -------------- | ------- | ------------------------------------------------ |
+| `model_name` | `str` | *required* | Hugging Face model name or local path. |
+| `trust_remote_code` | `bool` | `True` | Whether to trust remote code when loading. |
+| `tensor_parallel_size` | `int` | `1` | Number of GPUs for tensor parallelism. |
+| `hf_token` | `str \| None` | `None` | Hugging Face authentication token. |
+| `llm_kwargs` | `dict \| None` | `None` | Additional kwargs for `vllm.LLM()`. |
+| `use_chat_template` | `bool` | `True` | Whether to apply the tokenizer's chat template to prompts. |
+
+#### Generation
+
+| Argument | Type | Default | Description |
+| ------------------------------- | -------------- | ------- | ------------------------------------------------ |
+| `max_new_tokens` | `int` | `256` | Maximum tokens to generate per sequence. |
+| `temperature` | `float` | `1.0` | Sampling temperature (0.0 = greedy). |
+| `do_sample` | `bool` | `True` | Whether to use sampling vs greedy decoding. |
+| `sampling_kwargs` | `dict \| None` | `None` | Additional kwargs for `vllm.SamplingParams()`. |
+
+#### Benchmark
+
+| Argument | Type | Default | Description |
+| ------------------------------- | -------------- | ----------------------------- | ------------------------------------------------ |
+| `output_file` | `str` | `"llm_sql_predictions.jsonl"` | Path to write predictions as JSONL. |
+| `questions_path` | `str \| None` | `None` | Path to questions.jsonl (auto-downloads if missing). |
+| `tables_path` | `str \| None` | `None` | Path to tables.jsonl (auto-downloads if missing). |
+| `workdir_path` | `str` | `"llmsql_workdir"` | Working directory for downloaded files. |
+| `num_fewshots` | `int` | `5` | Number of few-shot examples (0, 1, or 5). |
+| `batch_size` | `int` | `8` | Number of prompts per batch. |
+| `seed` | `int` | `42` | Random seed for reproducibility. |
+
+**Note:** Explicit parameters (e.g., `tensor_parallel_size`, `trust_remote_code`) override any values specified in `llm_kwargs` or `sampling_kwargs`.
 
 ---
 
diff --git a/llmsql/inference/inference_transformers.py b/llmsql/inference/inference_transformers.py
index 04d408a..904e55f 100644
--- a/llmsql/inference/inference_transformers.py
+++ b/llmsql/inference/inference_transformers.py
@@ -1,3 +1,25 @@
+"""
+LLMSQL Transformers Inference Function
+
+This module provides a single function `inference_transformers()` that performs
+text-to-SQL generation using large language models via the transformers backend.
+
+Example:
+
+    from llmsql import inference_transformers
+
+    results = inference_transformers(
+        model_or_model_name_or_path="EleutherAI/pythia-14m",
+        output_file="test_output.jsonl",
+        batch_size=5000,
+        do_sample=False,
+        # Advanced model loading options
+        model_kwargs={"low_cpu_mem_usage": True, "attn_implementation": "flash_attention_2"},
+        # Advanced generation options
+        generation_kwargs={"repetition_penalty": 1.1, "length_penalty": 1.0},
+    )
+"""
+
 from pathlib import Path
 from typing import Any
 
@@ -27,25 +49,31 @@ def inference_transformers(
     model_or_model_name_or_path: str | AutoModelForCausalLM,
     tokenizer_or_name: str | Any | None = None,
     *,
-    chat_template: str | None = None,
-    model_args: dict[str, Any] | None = None,
-    hf_token: str | None = None,
-    output_file: str = "outputs/predictions.jsonl",
-    questions_path: str | None = None,
-    tables_path: str | None = None,
-    workdir_path: str = DEFAULT_WORKDIR_PATH,
-    num_fewshots: int = 5,
+    # --- Model Loading Parameters ---
     trust_remote_code: bool = True,
-    batch_size: int = 8,
+    dtype: torch.dtype = torch.float16,
+    device_map: str | dict[str, int] | None = "auto",
+    hf_token: str | None = None,
+    model_kwargs: dict[str, Any] | None = None,
+    # --- Tokenizer Loading Parameters ---
+    tokenizer_kwargs: dict[str, Any] | None = None,
+    # --- Prompt & Chat Parameters ---
+    chat_template: str | None = None,
+    # --- Generation Parameters ---
     max_new_tokens: int = 256,
     temperature: float = 0.0,
     do_sample: bool = False,
     top_p: float = 1.0,
     top_k: int = 50,
+    generation_kwargs: dict[str, Any] | None = None,
+    # --- Benchmark Parameters ---
+    output_file: str = "llm_sql_predictions.jsonl",
+    questions_path: str | None = None,
+    tables_path: str | None = None,
+    workdir_path: str = DEFAULT_WORKDIR_PATH,
+    num_fewshots: int = 5,
+    batch_size: int = 8,
     seed: int = 42,
-    dtype: torch.dtype = torch.float16,
-    device_map: str | dict[str, int] | None = "auto",
-    generate_kwargs: dict[str, Any] | None = None,
 ) -> list[dict[str, str]]:
     """
     Inference a causal model (Transformers) on the LLMSQL benchmark.
@@ -53,27 +81,45 @@ def inference_transformers(
     Args:
         model_or_model_name_or_path: Model object or HF model name/path.
         tokenizer_or_name: Tokenizer object or HF tokenizer name/path.
+
+        # Model Loading:
+        trust_remote_code: Whether to trust remote code (default: True).
+        dtype: Torch dtype for model (default: float16).
+        device_map: Device placement strategy (default: "auto").
+        hf_token: Hugging Face authentication token.
+        model_kwargs: Additional arguments for AutoModelForCausalLM.from_pretrained().
+                      Note: 'dtype', 'device_map', 'trust_remote_code', 'token'
+                      are handled separately and will override values here.
+
+        # Tokenizer Loading:
+        tokenizer_kwargs: Additional arguments for AutoTokenizer.from_pretrained(). 'padding_side' defaults to "left".
+                          Note: 'trust_remote_code', 'token' are handled separately and will override values here.
+
+
+        # Prompt & Chat:
         chat_template: Optional chat template to apply before tokenization.
-        model_args: Optional kwargs passed to `from_pretrained` if needed.
-        hf_token: Hugging Face token (optional).
-        output_file: Output JSONL file for completions.
+
+        # Generation:
+        max_new_tokens: Maximum tokens to generate per sequence.
+        temperature: Sampling temperature (0.0 = greedy).
+        do_sample: Whether to use sampling vs greedy decoding.
+        top_p: Nucleus sampling parameter.
+        top_k: Top-k sampling parameter.
+        generation_kwargs: Additional arguments for model.generate().
+                           Note: 'max_new_tokens', 'temperature', 'do_sample',
+                           'top_p', 'top_k' are handled separately.
+
+        # Benchmark:
+        output_file: Output JSONL file path for completions.
         questions_path: Path to benchmark questions JSONL.
         tables_path: Path to benchmark tables JSONL.
-        workdir_path: Work directory (default: "llmsql_workdir").
-        num_fewshots: 0, 1, or 5 — prompt builder choice.
+        workdir_path: Working directory path.
+        num_fewshots: Number of few-shot examples (0, 1, or 5).
         batch_size: Batch size for inference.
-        max_new_tokens: Max tokens to generate.
-        temperature: Sampling temperature.
-        do_sample: Whether to sample or use greedy decoding.
-        top_p: Nucleus sampling parameter.
-        top_k: Top-k sampling parameter.
-        seed: Random seed.
-        dtype: Torch dtype (default: float16).
-        device_map: Device map ("auto" for multi-GPU).
-        **generate_kwargs: Extra arguments for `model.generate`.
+        seed: Random seed for reproducibility.
 
     Returns:
-        List[dict[str, str]]: Generated SQL results.
+        List of generated SQL results with metadata.
""" # --- Setup --- _setup_seed(seed=seed) @@ -81,56 +127,66 @@ def inference_transformers( workdir = Path(workdir_path) workdir.mkdir(parents=True, exist_ok=True) - if generate_kwargs is None: - generate_kwargs = {} + model_kwargs = model_kwargs or {} + tokenizer_kwargs = tokenizer_kwargs or {} + generation_kwargs = generation_kwargs or {} - model_args = model_args or {} - if "torch_dtype" in model_args: - dtype = model_args.pop("torch_dtype") - if "trust_remote_code" in model_args: - trust_remote_code = model_args.pop("trust_remote_code") - - # --- Load model --- + # --- Load Model --- if isinstance(model_or_model_name_or_path, str): - model_args = model_args or {} - log.info(f"Loading model from: {model_or_model_name_or_path}") + load_args = { + "torch_dtype": dtype, + "device_map": device_map, + "trust_remote_code": trust_remote_code, + "token": hf_token, + **model_kwargs, + } + + print(f"Loading model from: {model_or_model_name_or_path}") model = AutoModelForCausalLM.from_pretrained( model_or_model_name_or_path, - torch_dtype=dtype, - device_map=device_map, - token=hf_token, - trust_remote_code=trust_remote_code, - **model_args, + **load_args, ) else: model = model_or_model_name_or_path - log.info(f"Using provided model object: {type(model)}") + print(f"Using provided model object: {type(model)}") - # --- Load tokenizer --- + # --- Load Tokenizer --- if tokenizer_or_name is None: if isinstance(model_or_model_name_or_path, str): - tokenizer = AutoTokenizer.from_pretrained( - model_or_model_name_or_path, - token=hf_token, - trust_remote_code=True, - padding_side="left" - ) + tok_name = model_or_model_name_or_path else: - raise ValueError("Tokenizer must be provided if model is passed directly.") + raise ValueError( + "tokenizer_or_name must be provided when passing a model object directly." 
+            )
     elif isinstance(tokenizer_or_name, str):
-        tokenizer = AutoTokenizer.from_pretrained(
-            tokenizer_or_name,
-            token=hf_token,
-            trust_remote_code=True,
-            padding_side="left"
-        )
+        tok_name = tokenizer_or_name
     else:
+        # Already a tokenizer object
         tokenizer = tokenizer_or_name
+        tok_name = None
+
+    if tok_name:
+        load_tok_args = {
+            "padding_side": "left",
+            **tokenizer_kwargs,
+            "trust_remote_code": trust_remote_code,
+            "token": hf_token,
+        }
+        tokenizer = AutoTokenizer.from_pretrained(tok_name, **load_tok_args)
 
-    # ensure pad token exists
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
+    gen_params = {
+        "max_new_tokens": max_new_tokens,
+        "temperature": temperature,
+        "do_sample": do_sample,
+        "top_p": top_p,
+        "top_k": top_k,
+        "pad_token_id": tokenizer.pad_token_id,
+        **generation_kwargs,
+    }
+
     model.eval()
 
     # --- Load necessary files ---
@@ -185,13 +241,7 @@
 
         outputs = model.generate(
             **inputs,
-            max_new_tokens=max_new_tokens,
-            temperature=temperature if do_sample else 0.0,
-            do_sample=do_sample,
-            top_p=top_p,
-            top_k=top_k,
-            pad_token_id=tokenizer.pad_token_id,
-            **generate_kwargs,
+            **gen_params,
         )
 
         input_lengths = [len(ids) for ids in inputs["input_ids"]]
diff --git a/llmsql/inference/inference_vllm.py b/llmsql/inference/inference_vllm.py
index caef064..36b1c84 100644
--- a/llmsql/inference/inference_vllm.py
+++ b/llmsql/inference/inference_vllm.py
@@ -21,6 +21,8 @@
 )
 """
 
+from __future__ import annotations
+
 import os
 
 os.environ["VLLM_USE_V1"] = "0"
@@ -49,43 +51,65 @@
 
 def inference_vllm(
     model_name: str,
-    output_file: str,
+    *,
+    # === Model Loading Parameters ===
+    trust_remote_code: bool = True,
+    tensor_parallel_size: int = 1,
+    hf_token: str | None = None,
+    llm_kwargs: dict[str, Any] | None = None,
+    use_chat_template: bool = True,
+    # === Generation Parameters ===
+    max_new_tokens: int = 256,
+    temperature: float = 1.0,
+    do_sample: bool = True,
+    sampling_kwargs: dict[str, Any] | None = None,
+    # === Benchmark Parameters ===
+    output_file: str = "llm_sql_predictions.jsonl",
     questions_path: str | None = None,
     tables_path: str | None = None,
-    hf_token: str | None = None,
-    tensor_parallel_size: int = 1,
-    seed: int = 42,
     workdir_path: str = DEFAULT_WORKDIR_PATH,
     num_fewshots: int = 5,
     batch_size: int = 8,
-    max_new_tokens: int = 256,
-    temperature: float = 1.0,
-    do_sample: bool = True,
-    llm_kwargs: dict[str, Any] | None = None,
+    seed: int = 42,
 ) -> list[dict[str, str]]:
     """
     Run SQL generation using vLLM.
 
     Args:
         model_name: Hugging Face model name or path.
+
+        # Model Loading:
+        trust_remote_code: Whether to trust remote code (default: True).
+        tensor_parallel_size: Number of GPUs for tensor parallelism (default: 1).
+        hf_token: Hugging Face authentication token.
+        llm_kwargs: Additional arguments for vllm.LLM().
+                    Note: 'model', 'tokenizer', 'tensor_parallel_size',
+                    'trust_remote_code' are handled separately and will
+                    override values here.
+
+        # Generation:
+        max_new_tokens: Maximum tokens to generate per sequence.
+        temperature: Sampling temperature (0.0 = greedy).
+        do_sample: Whether to use sampling vs greedy decoding.
+        sampling_kwargs: Additional arguments for vllm.SamplingParams().
+                         Note: 'temperature', 'max_tokens' are handled
+                         separately and will override values here.
+
+        # Benchmark:
         output_file: Path to write outputs (will be overwritten).
-        questions_path: Path to questions.jsonl (optional, auto-download if missing).
-        tables_path: Path to tables.jsonl (optional, auto-download if missing).
-        hf_token: Hugging Face auth token.
-        tensor_parallel_size: Degree of tensor parallelism (for multi-GPU).
-        seed: Random seed.
-        workdir_path: Directory to store any downloaded data.
-        num_fewshots: Number of examples per prompt (0, 1, or 5).
+        questions_path: Path to questions.jsonl (auto-downloads if missing).
+        tables_path: Path to tables.jsonl (auto-downloads if missing).
+        workdir_path: Directory to store downloaded data.
+        num_fewshots: Number of few-shot examples (0, 1, or 5).
         batch_size: Number of questions per generation batch.
-        max_new_tokens: Max tokens to generate.
-        temperature: Sampling temperature.
-        do_sample: Whether to sample or use greedy decoding.
-        **llm_kwargs: Extra kwargs forwarded to vllm.LLM().
+        seed: Random seed for reproducibility.
 
     Returns:
         List of dicts containing `question_id` and generated `completion`.
     """
     # --- setup ---
+    llm_kwargs = llm_kwargs or {}
+    sampling_kwargs = sampling_kwargs or {}
     _setup_seed(seed=seed)
     hf_token = hf_token or os.environ.get("HF_TOKEN")
 
@@ -101,19 +125,21 @@
 
     tables = {t["table_id"]: t for t in tables_list}
 
     # --- init model ---
-    llm_kwargs = llm_kwargs or {}
-    if "tensor_parallel_size" in llm_kwargs:
-        tensor_parallel_size = llm_kwargs.pop("tensor_parallel_size")
+    llm_init_args = {
+        **llm_kwargs,  # user kwargs first; the explicit params below take precedence
+        "model": model_name,
+        "tokenizer": model_name,
+        "tensor_parallel_size": tensor_parallel_size,
+        "trust_remote_code": trust_remote_code,
+    }
 
     log.info(f"Loading vLLM model '{model_name}' (tp={tensor_parallel_size})...")
-    llm = LLM(
-        model=model_name,
-        tokenizer=model_name,
-        tensor_parallel_size=tensor_parallel_size,
-        trust_remote_code=True,
-        **llm_kwargs,
-    )
+    llm = LLM(**llm_init_args)
+
+    tokenizer = llm.get_tokenizer()
+    if use_chat_template:
+        use_chat_template = bool(getattr(tokenizer, "chat_template", None))
 
     # --- prepare output file ---
     overwrite_jsonl(output_file)
@@ -121,11 +147,16 @@
     # --- prompt builder and sampling params ---
     prompt_builder = choose_prompt_builder(num_fewshots)
 
-    temperature = 0.0 if not do_sample else temperature
-    sampling_params = SamplingParams(
-        temperature=temperature,
-        max_tokens=max_new_tokens,
-    )
+
+    effective_temperature = 0.0 if not do_sample else temperature
+
+    sampling_params_args = {
+        **sampling_kwargs,
+        "temperature": effective_temperature,
+        "max_tokens": max_new_tokens,
+    }
+
+    sampling_params = SamplingParams(**sampling_params_args)
 
     # --- main inference loop ---
     all_results: list[dict[str, str]] = []
@@ -138,10 +169,22 @@
         for q in batch:
             tbl = tables[q["table_id"]]
             example_row = tbl["rows"][0] if tbl["rows"] else []
-            prompts.append(
-                prompt_builder(q["question"], tbl["header"], tbl["types"], example_row)
+
+            raw_text = prompt_builder(
+                q["question"], tbl["header"], tbl["types"], example_row
             )
 
+            if use_chat_template:
+                messages = [{"role": "user", "content": raw_text}]
+
+                final_prompt = tokenizer.apply_chat_template(
+                    messages, tokenize=False, add_generation_prompt=True
+                )
+            else:
+                final_prompt = raw_text
+
+            prompts.append(final_prompt)
+
         outputs = llm.generate(prompts, sampling_params)
 
         batch_results: list[dict[str, str]] = []
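A minimal, illustrative sketch of the kwargs-precedence rule that the README tables and docstrings above describe (this snippet is not part of the patch; the dictionary values are placeholders):

```python
# Later keys win in a Python dict literal, so unpacking the user-supplied
# kwargs first and listing the explicit parameters afterwards gives the
# documented behaviour: explicit parameters override user kwargs.
user_llm_kwargs = {"trust_remote_code": False, "gpu_memory_utilization": 0.9}

llm_init_args = {
    **user_llm_kwargs,          # user-provided extras first
    "model": "EleutherAI/pythia-14m",
    "trust_remote_code": True,  # explicit parameter takes precedence
}

assert llm_init_args["trust_remote_code"] is True       # explicit value wins
assert llm_init_args["gpu_memory_utilization"] == 0.9   # user extra is kept
```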