 from tqdm import tqdm
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-from llmsql.config.config import DEFAULT_WORKDIR_PATH
 from llmsql.loggers.logging_config import log
 from llmsql.utils.inference_utils import _maybe_download, _setup_seed
 from llmsql.utils.utils import (
@@ -27,110 +26,144 @@ def inference_transformers(
     model_or_model_name_or_path: str | AutoModelForCausalLM,
     tokenizer_or_name: str | Any | None = None,
     *,
-    chat_template: str | None = None,
-    model_args: dict[str, Any] | None = None,
-    hf_token: str | None = None,
-    output_file: str = "outputs/predictions.jsonl",
-    questions_path: str | None = None,
-    tables_path: str | None = None,
-    workdir_path: str = DEFAULT_WORKDIR_PATH,
-    num_fewshots: int = 5,
+    # --- Model Loading Parameters ---
     trust_remote_code: bool = True,
-    batch_size: int = 8,
+    dtype: torch.dtype = torch.float16,
+    device_map: str | dict[str, int] | None = "auto",
+    hf_token: str | None = None,
+    model_kwargs: dict[str, Any] | None = None,
+    # --- Tokenizer Loading Parameters ---
+    tokenizer_kwargs: dict[str, Any] | None = None,
+    # --- Prompt & Chat Parameters ---
+    chat_template: str | None = None,
+    # --- Generation Parameters ---
     max_new_tokens: int = 256,
     temperature: float = 0.0,
     do_sample: bool = False,
     top_p: float = 1.0,
     top_k: int = 50,
+    generation_kwargs: dict[str, Any] | None = None,
+    # --- Benchmark Parameters ---
+    output_file: str = "outputs/predictions.jsonl",
+    questions_path: str | None = None,
+    tables_path: str | None = None,
+    workdir_path: str = "llmsql_workdir",
+    num_fewshots: int = 5,
+    batch_size: int = 8,
     seed: int = 42,
-    dtype: torch.dtype = torch.float16,
-    device_map: str | dict[str, int] | None = "auto",
-    generate_kwargs: dict[str, Any] | None = None,
 ) -> list[dict[str, str]]:
     """
     Run inference with a causal model (Transformers) on the LLMSQL benchmark.
 
     Args:
         model_or_model_name_or_path: Model object or HF model name/path.
         tokenizer_or_name: Tokenizer object or HF tokenizer name/path.
+
+        # Model Loading:
+        trust_remote_code: Whether to trust remote code (default: True).
+        dtype: Torch dtype for the model (default: float16).
+        device_map: Device placement strategy (default: "auto").
+        hf_token: Hugging Face authentication token.
+        model_kwargs: Additional arguments for AutoModelForCausalLM.from_pretrained().
+            Note: 'dtype', 'device_map', 'trust_remote_code', and 'token'
+            are handled separately and will override values here.
+
+        # Tokenizer Loading:
+        tokenizer_kwargs: Additional arguments for AutoTokenizer.from_pretrained().
+            'padding_side' defaults to "left".
+            Note: 'trust_remote_code' and 'token' are handled separately and
+            will override values here.
+
+        # Prompt & Chat:
         chat_template: Optional chat template to apply before tokenization.
-        model_args: Optional kwargs passed to `from_pretrained` if needed.
-        hf_token: Hugging Face token (optional).
-        output_file: Output JSONL file for completions.
+
+        # Generation:
+        max_new_tokens: Maximum tokens to generate per sequence.
+        temperature: Sampling temperature (0.0 = greedy).
+        do_sample: Whether to use sampling vs greedy decoding.
+        top_p: Nucleus sampling parameter.
+        top_k: Top-k sampling parameter.
+        generation_kwargs: Additional arguments for model.generate().
+            Note: 'max_new_tokens', 'temperature', 'do_sample',
+            'top_p', 'top_k' are handled separately.
+
+        # Benchmark:
+        output_file: Output JSONL file path for completions.
         questions_path: Path to benchmark questions JSONL.
         tables_path: Path to benchmark tables JSONL.
-        workdir_path: Work directory (default: "llmsql_workdir").
-        num_fewshots: 0, 1, or 5 — prompt builder choice.
+        workdir_path: Working directory path.
+        num_fewshots: Number of few-shot examples (0, 1, or 5).
         batch_size: Batch size for inference.
-        max_new_tokens: Max tokens to generate.
-        temperature: Sampling temperature.
-        do_sample: Whether to sample or use greedy decoding.
-        top_p: Nucleus sampling parameter.
-        top_k: Top-k sampling parameter.
-        seed: Random seed.
-        dtype: Torch dtype (default: float16).
-        device_map: Device map ("auto" for multi-GPU).
-        **generate_kwargs: Extra arguments for `model.generate`.
+        seed: Random seed for reproducibility.
 
     Returns:
-        List[dict[str, str]]: Generated SQL results.
+        List of generated SQL results with metadata.
     """
     # --- Setup ---
     _setup_seed(seed=seed)
 
     workdir = Path(workdir_path)
     workdir.mkdir(parents=True, exist_ok=True)
 
-    if generate_kwargs is None:
-        generate_kwargs = {}
+    model_kwargs = model_kwargs or {}
+    tokenizer_kwargs = tokenizer_kwargs or {}
+    generation_kwargs = generation_kwargs or {}
 
-    model_args = model_args or {}
-    if "torch_dtype" in model_args:
-        dtype = model_args.pop("torch_dtype")
-    if "trust_remote_code" in model_args:
-        trust_remote_code = model_args.pop("trust_remote_code")
-
-    # --- Load model ---
+    # --- Load Model ---
     if isinstance(model_or_model_name_or_path, str):
-        model_args = model_args or {}
-        log.info(f"Loading model from: {model_or_model_name_or_path}")
+        load_args = {
+            **model_kwargs,
+            # Explicit arguments take precedence over any duplicate keys in
+            # model_kwargs, matching the docstring above.
+            "torch_dtype": dtype,
+            "device_map": device_map,
+            "trust_remote_code": trust_remote_code,
+            "token": hf_token,
+        }
+
+        print(f"Loading model from: {model_or_model_name_or_path}")
         model = AutoModelForCausalLM.from_pretrained(
             model_or_model_name_or_path,
-            torch_dtype=dtype,
-            device_map=device_map,
-            token=hf_token,
-            trust_remote_code=trust_remote_code,
-            **model_args,
+            **load_args,
         )
     else:
         model = model_or_model_name_or_path
-        log.info(f"Using provided model object: {type(model)}")
+        print(f"Using provided model object: {type(model)}")
 
-    # --- Load tokenizer ---
+    # --- Load Tokenizer ---
     if tokenizer_or_name is None:
         if isinstance(model_or_model_name_or_path, str):
-            tokenizer = AutoTokenizer.from_pretrained(
-                model_or_model_name_or_path,
-                token=hf_token,
-                trust_remote_code=True,
-                padding_side="left"
-            )
+            tok_name = model_or_model_name_or_path
         else:
-            raise ValueError("Tokenizer must be provided if model is passed directly.")
+            raise ValueError(
+                "tokenizer_or_name must be provided when passing a model object directly."
+            )
     elif isinstance(tokenizer_or_name, str):
-        tokenizer = AutoTokenizer.from_pretrained(
-            tokenizer_or_name,
-            token=hf_token,
-            trust_remote_code=True,
-            padding_side="left"
-        )
+        tok_name = tokenizer_or_name
     else:
+        # Already a tokenizer object
         tokenizer = tokenizer_or_name
+        tok_name = None
+
+    if tok_name:
+        load_tok_args = {
+            "padding_side": "left",
+            **tokenizer_kwargs,
+            # trust_remote_code and token are handled separately and override
+            # any duplicates in tokenizer_kwargs, matching the docstring above.
+            "trust_remote_code": True,
+            "token": hf_token,
+        }
+        tokenizer = AutoTokenizer.from_pretrained(tok_name, **load_tok_args)
 
-    # ensure pad token exists
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
+    gen_params = {
+        "max_new_tokens": max_new_tokens,
+        "temperature": temperature,
+        "do_sample": do_sample,
+        "top_p": top_p,
+        "top_k": top_k,
+        "pad_token_id": tokenizer.pad_token_id,
+        **generation_kwargs,
+    }
+
     model.eval()
 
     # --- Load necessary files ---
@@ -185,13 +218,7 @@ def inference_transformers(
 
         outputs = model.generate(
             **inputs,
-            max_new_tokens=max_new_tokens,
-            temperature=temperature if do_sample else 0.0,
-            do_sample=do_sample,
-            top_p=top_p,
-            top_k=top_k,
-            pad_token_id=tokenizer.pad_token_id,
-            **generate_kwargs,
+            **gen_params,
         )
 
         input_lengths = [len(ids) for ids in inputs["input_ids"]]
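For reference, a minimal usage sketch of the refactored signature above. The model name, import path, and extra kwargs are placeholder assumptions for illustration, not values taken from this commit.

import torch

# Import path is an assumption; adjust to wherever inference_transformers is exposed.
from llmsql import inference_transformers

results = inference_transformers(
    "Qwen/Qwen2.5-Coder-1.5B-Instruct",  # placeholder HF model name
    # --- model loading ---
    dtype=torch.bfloat16,
    device_map="auto",
    model_kwargs={"attn_implementation": "sdpa"},
    # --- tokenizer loading ---
    tokenizer_kwargs={"use_fast": True},
    # --- generation ---
    max_new_tokens=128,
    do_sample=False,
    generation_kwargs={"repetition_penalty": 1.0},
    # --- benchmark ---
    output_file="outputs/predictions.jsonl",
    num_fewshots=5,
    batch_size=8,
    seed=42,
)
print(f"Generated {len(results)} predictions")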