
Commit 121a14b

inference_transformers corrected;
1 parent 008de41 commit 121a14b

File tree

2 files changed (+114, -68 lines)


examples/inference_transformers.ipynb

Lines changed: 20 additions & 1 deletion
@@ -62,7 +62,26 @@
    ],
    "source": [
     "from llmsql import inference_transformers\n",
-    "results = inference_transformers(model_or_model_name_or_path=\"EleutherAI/pythia-14m\", output_file=\"test_output.jsonl\", batch_size=5000, do_sample=False)"
+    "\n",
+    "# Example 1: Basic usage (same as before)\n",
+    "results = inference_transformers(\n",
+    "    model_or_model_name_or_path=\"EleutherAI/pythia-14m\",\n",
+    "    output_file=\"test_output.jsonl\",\n",
+    "    batch_size=5000,\n",
+    "    do_sample=False,\n",
+    ")\n",
+    "\n",
+    "# # Example 2: Using the new kwargs for advanced options\n",
+    "# results = inference_transformers(\n",
+    "#     model_or_model_name_or_path=\"EleutherAI/pythia-14m\",\n",
+    "#     output_file=\"test_output.jsonl\",\n",
+    "#     batch_size=5000,\n",
+    "#     do_sample=False,\n",
+    "#     # Advanced model loading options\n",
+    "#     model_kwargs={\"low_cpu_mem_usage\": True, \"attn_implementation\": \"flash_attention_2\"},\n",
+    "#     # Advanced generation options\n",
+    "#     generation_kwargs={\"repetition_penalty\": 1.1, \"length_penalty\": 1.0},\n",
+    "# )"
    ]
   }
  ],
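
Besides the two notebook examples, the updated signature still accepts pre-loaded objects. A minimal sketch of that calling pattern, assuming the same model name and import path as the notebook (note that a tokenizer must now be passed explicitly whenever a model object is passed directly):

from transformers import AutoModelForCausalLM, AutoTokenizer

from llmsql import inference_transformers

# Load the model and tokenizer yourself, then hand the objects to the benchmark.
# The model name and the output/batch settings mirror the notebook example above.
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-14m")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m", padding_side="left")

results = inference_transformers(
    model_or_model_name_or_path=model,
    tokenizer_or_name=tokenizer,  # required when a model object is supplied
    output_file="test_output.jsonl",
    batch_size=5000,
    do_sample=False,
)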

llmsql/inference/inference_transformers.py

Lines changed: 94 additions & 67 deletions
@@ -6,7 +6,6 @@
 from tqdm import tqdm
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-from llmsql.config.config import DEFAULT_WORKDIR_PATH
 from llmsql.loggers.logging_config import log
 from llmsql.utils.inference_utils import _maybe_download, _setup_seed
 from llmsql.utils.utils import (
@@ -27,110 +26,144 @@ def inference_transformers(
     model_or_model_name_or_path: str | AutoModelForCausalLM,
     tokenizer_or_name: str | Any | None = None,
     *,
-    chat_template: str | None = None,
-    model_args: dict[str, Any] | None = None,
-    hf_token: str | None = None,
-    output_file: str = "outputs/predictions.jsonl",
-    questions_path: str | None = None,
-    tables_path: str | None = None,
-    workdir_path: str = DEFAULT_WORKDIR_PATH,
-    num_fewshots: int = 5,
+    # --- Model Loading Parameters ---
     trust_remote_code: bool = True,
-    batch_size: int = 8,
+    dtype: torch.dtype = torch.float16,
+    device_map: str | dict[str, int] | None = "auto",
+    hf_token: str | None = None,
+    model_kwargs: dict[str, Any] | None = None,
+    # --- Tokenizer Loading Parameters ---
+    tokenizer_kwargs: dict[str, Any] | None = None,
+    # --- Prompt & Chat Parameters ---
+    chat_template: str | None = None,
+    # --- Generation Parameters ---
     max_new_tokens: int = 256,
     temperature: float = 0.0,
     do_sample: bool = False,
     top_p: float = 1.0,
     top_k: int = 50,
+    generation_kwargs: dict[str, Any] | None = None,
+    # --- Benchmark Parameters ---
+    output_file: str = "outputs/predictions.jsonl",
+    questions_path: str | None = None,
+    tables_path: str | None = None,
+    workdir_path: str = "llmsql_workdir",
+    num_fewshots: int = 5,
+    batch_size: int = 8,
     seed: int = 42,
-    dtype: torch.dtype = torch.float16,
-    device_map: str | dict[str, int] | None = "auto",
-    generate_kwargs: dict[str, Any] | None = None,
 ) -> list[dict[str, str]]:
     """
     Inference a causal model (Transformers) on the LLMSQL benchmark.
 
     Args:
         model_or_model_name_or_path: Model object or HF model name/path.
         tokenizer_or_name: Tokenizer object or HF tokenizer name/path.
+
+        # Model Loading:
+        trust_remote_code: Whether to trust remote code (default: True).
+        dtype: Torch dtype for model (default: float16).
+        device_map: Device placement strategy (default: "auto").
+        hf_token: Hugging Face authentication token.
+        model_kwargs: Additional arguments for AutoModelForCausalLM.from_pretrained().
+            Note: 'dtype', 'device_map', 'trust_remote_code', 'token'
+            are handled separately and will override values here.
+
+        # Tokenizer Loading:
+        tokenizer_kwargs: Additional arguments for AutoTokenizer.from_pretrained(). 'padding_side' defaults to "left".
+            Note: 'trust_remote_code', 'token' are handled separately and will override values here.
+
+
+        # Prompt & Chat:
         chat_template: Optional chat template to apply before tokenization.
-        model_args: Optional kwargs passed to `from_pretrained` if needed.
-        hf_token: Hugging Face token (optional).
-        output_file: Output JSONL file for completions.
+
+        # Generation:
+        max_new_tokens: Maximum tokens to generate per sequence.
+        temperature: Sampling temperature (0.0 = greedy).
+        do_sample: Whether to use sampling vs greedy decoding.
+        top_p: Nucleus sampling parameter.
+        top_k: Top-k sampling parameter.
+        generation_kwargs: Additional arguments for model.generate().
+            Note: 'max_new_tokens', 'temperature', 'do_sample',
+            'top_p', 'top_k' are handled separately.
+
+        # Benchmark:
+        output_file: Output JSONL file path for completions.
         questions_path: Path to benchmark questions JSONL.
         tables_path: Path to benchmark tables JSONL.
-        workdir_path: Work directory (default: "llmsql_workdir").
-        num_fewshots: 0, 1, or 5 — prompt builder choice.
+        workdir_path: Working directory path.
+        num_fewshots: Number of few-shot examples (0, 1, or 5).
         batch_size: Batch size for inference.
-        max_new_tokens: Max tokens to generate.
-        temperature: Sampling temperature.
-        do_sample: Whether to sample or use greedy decoding.
-        top_p: Nucleus sampling parameter.
-        top_k: Top-k sampling parameter.
-        seed: Random seed.
-        dtype: Torch dtype (default: float16).
-        device_map: Device map ("auto" for multi-GPU).
-        **generate_kwargs: Extra arguments for `model.generate`.
+        seed: Random seed for reproducibility.
 
     Returns:
-        List[dict[str, str]]: Generated SQL results.
+        List of generated SQL results with metadata.
     """
     # --- Setup ---
     _setup_seed(seed=seed)
 
     workdir = Path(workdir_path)
     workdir.mkdir(parents=True, exist_ok=True)
 
-    if generate_kwargs is None:
-        generate_kwargs = {}
+    model_kwargs = model_kwargs or {}
+    tokenizer_kwargs = tokenizer_kwargs or {}
+    generation_kwargs = generation_kwargs or {}
 
-    model_args = model_args or {}
-    if "torch_dtype" in model_args:
-        dtype = model_args.pop("torch_dtype")
-    if "trust_remote_code" in model_args:
-        trust_remote_code = model_args.pop("trust_remote_code")
-
-    # --- Load model ---
+    # --- Load Model ---
     if isinstance(model_or_model_name_or_path, str):
-        model_args = model_args or {}
-        log.info(f"Loading model from: {model_or_model_name_or_path}")
+        load_args = {
+            "torch_dtype": dtype,
+            "device_map": device_map,
+            "trust_remote_code": trust_remote_code,
+            "token": hf_token,
+            **model_kwargs,
+        }
+
+        print(f"Loading model from: {model_or_model_name_or_path}")
         model = AutoModelForCausalLM.from_pretrained(
             model_or_model_name_or_path,
-            torch_dtype=dtype,
-            device_map=device_map,
-            token=hf_token,
-            trust_remote_code=trust_remote_code,
-            **model_args,
+            **load_args,
         )
     else:
         model = model_or_model_name_or_path
-        log.info(f"Using provided model object: {type(model)}")
+        print(f"Using provided model object: {type(model)}")
 
-    # --- Load tokenizer ---
+    # --- Load Tokenizer ---
     if tokenizer_or_name is None:
         if isinstance(model_or_model_name_or_path, str):
-            tokenizer = AutoTokenizer.from_pretrained(
-                model_or_model_name_or_path,
-                token=hf_token,
-                trust_remote_code=True,
-                padding_side="left"
-            )
+            tok_name = model_or_model_name_or_path
         else:
-            raise ValueError("Tokenizer must be provided if model is passed directly.")
+            raise ValueError(
+                "tokenizer_or_name must be provided when passing a model object directly."
+            )
     elif isinstance(tokenizer_or_name, str):
-        tokenizer = AutoTokenizer.from_pretrained(
-            tokenizer_or_name,
-            token=hf_token,
-            trust_remote_code=True,
-            padding_side="left"
-        )
+        tok_name = tokenizer_or_name
     else:
+        # Already a tokenizer object
         tokenizer = tokenizer_or_name
+        tok_name = None
+
+    if tok_name:
+        load_tok_args = {
+            "trust_remote_code": True,
+            "token": hf_token,
+            "padding_side": tokenizer_kwargs.get("padding_side", "left"),
+            **tokenizer_kwargs,
+        }
+        tokenizer = AutoTokenizer.from_pretrained(tok_name, **load_tok_args)
 
-    # ensure pad token exists
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
+    gen_params = {
+        "max_new_tokens": max_new_tokens,
+        "temperature": temperature,
+        "do_sample": do_sample,
+        "top_p": top_p,
+        "top_k": top_k,
+        "pad_token_id": tokenizer.pad_token_id,
+        **generation_kwargs,
+    }
+
     model.eval()
 
     # --- Load necessary files ---
@@ -185,13 +218,7 @@ def inference_transformers(
 
         outputs = model.generate(
            **inputs,
-            max_new_tokens=max_new_tokens,
-            temperature=temperature if do_sample else 0.0,
-            do_sample=do_sample,
-            top_p=top_p,
-            top_k=top_k,
-            pad_token_id=tokenizer.pad_token_id,
-            **generate_kwargs,
+            **gen_params,
         )
 
         input_lengths = [len(ids) for ids in inputs["input_ids"]]
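
The new *_kwargs dictionaries are merged into the explicit parameters with a plain dict spread, as in the gen_params block above. A small standalone sketch of that merge order (the values below are illustrative only, not taken from the commit):

# Later keys win in a dict literal, so a key duplicated in generation_kwargs
# takes precedence over the explicit keyword argument it is merged after.
explicit = {"max_new_tokens": 256, "temperature": 0.0, "do_sample": False}
generation_kwargs = {"repetition_penalty": 1.1, "max_new_tokens": 128}

gen_params = {**explicit, **generation_kwargs}
print(gen_params)
# {'max_new_tokens': 128, 'temperature': 0.0, 'do_sample': False, 'repetition_penalty': 1.1}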
