@@ -31,7 +31,7 @@ def main() -> None:
3131 llmsql inference --method transformers \
3232 --model-or-model-name-or-path meta-llama/Llama-3-8b-instruct \
3333 --output-file outputs/llama_preds.jsonl \
34- --model-args '{"attn_implementation": "flash_attention_2", "torch_dtype": "bfloat16"}'
34+ --model-kwargs '{"attn_implementation": "flash_attention_2", "torch_dtype": "bfloat16"}'
3535
3636 # 4️⃣ Pass LLM init kwargs (for vLLM)
3737 llmsql inference --method vllm \
@@ -44,7 +44,7 @@ def main() -> None:
4444 --model-or-model-name-or-path Qwen/Qwen2.5-1.5B-Instruct \
4545 --output-file outputs/temp_0.9.jsonl \
4646 --temperature 0.9 \
47- --generate -kwargs '{"do_sample": true, "top_p": 0.9, "top_k": 40}'
47+ --generation-kwargs '{"do_sample": true, "top_p": 0.9, "top_k": 40}'
4848"""
4949
5050 inf_parser = subparsers .add_parser (
@@ -127,15 +127,17 @@ def main() -> None:
127127 except json .JSONDecodeError :
128128 print ("⚠️ Could not parse --llm-kwargs JSON, passing as string." )
129129
130- if fn_kwargs .get ("model_args " ) is not None :
130+ if fn_kwargs .get ("model_kwargs" ) is not None :
131131 try :
132- fn_kwargs ["model_args " ] = json .loads (fn_kwargs ["model_args " ])
132+ fn_kwargs ["model_kwargs" ] = json .loads (fn_kwargs ["model_kwargs" ])
133133 except json .JSONDecodeError :
134134 raise
135135
136- if fn_kwargs .get ("generate_kwargs " ) is not None :
136+ if fn_kwargs .get ("generation_kwargs" ) is not None :
137137 try :
138- fn_kwargs ["generate_kwargs" ] = json .loads (fn_kwargs ["generate_kwargs" ])
138+ fn_kwargs ["generation_kwargs" ] = json .loads (
139+ fn_kwargs ["generation_kwargs" ]
140+ )
139141 except json .JSONDecodeError :
140142 raise
141143
0 commit comments