
Commit 62a8497

Merge pull request #61 from LLMSQL/60-bug-change-names-of-the-variables
text fixed
2 parents fb4a5f4 + 7fde1d6 commit 62a8497
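
In short, this commit renames two keyword arguments throughout the API, docs, and CLI: `model_args` → `model_kwargs` and `generate_kwargs` → `generation_kwargs`. Below is a minimal sketch of a call site after the rename, assembled from the diffs that follow; the comments on what each dict is forwarded to are an assumption based on the Hugging Face-style keys involved, not a statement from the repo:

```python
from llmsql import inference_transformers

results = inference_transformers(
    model_or_model_name_or_path="Qwen/Qwen2.5-1.5B-Instruct",
    output_file="outputs/preds_transformers.jsonl",
    # Renamed from `model_args`; keys resemble HF from_pretrained kwargs (assumption).
    model_kwargs={"torch_dtype": "bfloat16"},
    # Renamed from `generate_kwargs`; keys resemble HF generate() kwargs (assumption).
    generation_kwargs={"do_sample": False},
)
```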

File tree

5 files changed: +28 additions, −26 deletions

README.md

Lines changed: 1 addition & 1 deletion
@@ -92,7 +92,7 @@ results = inference_transformers(
     batch_size=8,
     max_new_tokens=256,
     do_sample=False,
-    model_args={
+    model_kwargs={
         "torch_dtype": "bfloat16",
     }
 )

docs/docs/usage.rst

Lines changed: 2 additions & 2 deletions
@@ -33,11 +33,11 @@ Using transformers backend.
     batch_size=8,
     max_new_tokens=256,
     temperature=0.7,
-    model_args={
+    model_kwargs={
         "attn_implementation": "flash_attention_2",
         "torch_dtype": "bfloat16",
     },
-    generate_kwargs={
+    generation_kwargs={
         "do_sample": False,
     },
 )

llmsql/__main__.py

Lines changed: 8 additions & 6 deletions
@@ -31,7 +31,7 @@ def main() -> None:
     llmsql inference --method transformers \
         --model-or-model-name-or-path meta-llama/Llama-3-8b-instruct \
         --output-file outputs/llama_preds.jsonl \
-        --model-args '{"attn_implementation": "flash_attention_2", "torch_dtype": "bfloat16"}'
+        --model-kwargs '{"attn_implementation": "flash_attention_2", "torch_dtype": "bfloat16"}'

     # 4️⃣ Pass LLM init kwargs (for vLLM)
     llmsql inference --method vllm \
@@ -44,7 +44,7 @@ def main() -> None:
         --model-or-model-name-or-path Qwen/Qwen2.5-1.5B-Instruct \
         --output-file outputs/temp_0.9.jsonl \
         --temperature 0.9 \
-        --generate-kwargs '{"do_sample": true, "top_p": 0.9, "top_k": 40}'
+        --generation-kwargs '{"do_sample": true, "top_p": 0.9, "top_k": 40}'
     """

     inf_parser = subparsers.add_parser(
@@ -127,15 +127,17 @@ def main() -> None:
         except json.JSONDecodeError:
             print("⚠️ Could not parse --llm-kwargs JSON, passing as string.")

-        if fn_kwargs.get("model_args") is not None:
+        if fn_kwargs.get("model_kwargs") is not None:
             try:
-                fn_kwargs["model_args"] = json.loads(fn_kwargs["model_args"])
+                fn_kwargs["model_kwargs"] = json.loads(fn_kwargs["model_kwargs"])
             except json.JSONDecodeError:
                 raise

-        if fn_kwargs.get("generate_kwargs") is not None:
+        if fn_kwargs.get("generation_kwargs") is not None:
             try:
-                fn_kwargs["generate_kwargs"] = json.loads(fn_kwargs["generate_kwargs"])
+                fn_kwargs["generation_kwargs"] = json.loads(
+                    fn_kwargs["generation_kwargs"]
+                )
             except json.JSONDecodeError:
                 raise
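
For context on the hunk above: the CLI receives these kwargs as JSON strings and decodes them with `json.loads` before forwarding them to the inference function. A self-contained sketch of that decoding step, using the example values from the docstring in this diff:

```python
import json

# CLI flag values arrive as JSON strings (examples taken from the docstring above).
fn_kwargs = {
    "model_kwargs": '{"attn_implementation": "flash_attention_2", "torch_dtype": "bfloat16"}',
    "generation_kwargs": '{"do_sample": true, "top_p": 0.9, "top_k": 40}',
}

# Decode each kwargs string into a dict, mirroring the renamed keys in the hunk;
# json.loads raises json.JSONDecodeError on malformed input, which the CLI re-raises.
for key in ("model_kwargs", "generation_kwargs"):
    if fn_kwargs.get(key) is not None:
        fn_kwargs[key] = json.loads(fn_kwargs[key])

print(fn_kwargs["generation_kwargs"]["top_p"])  # 0.9
```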

llmsql/inference/README.md

Lines changed: 15 additions & 15 deletions
@@ -33,21 +33,21 @@ pip install llmsql[vllm]
 from llmsql import inference_transformers
 
 results = inference_transformers(
-model_or_model_name_or_path="Qwen/Qwen2.5-1.5B-Instruct",
-output_file="outputs/preds_transformers.jsonl",
-questions_path="data/questions.jsonl",
-tables_path="data/tables.jsonl",
-num_fewshots=5,
-batch_size=8,
-max_new_tokens=256,
-temperature=0.7,
-model_args={
-"torch_dtype": "bfloat16",
-},
-generate_kwargs={
-"do_sample": False,
-},
-)
+    model_or_model_name_or_path="Qwen/Qwen2.5-1.5B-Instruct",
+    output_file="outputs/preds_transformers.jsonl",
+    questions_path="data/questions.jsonl",
+    tables_path="data/tables.jsonl",
+    num_fewshots=5,
+    batch_size=8,
+    max_new_tokens=256,
+    temperature=0.7,
+    model_kwargs={
+        "torch_dtype": "bfloat16",
+    },
+    generation_kwargs={
+        "do_sample": False,
+    },
+)
 ```
 
 ---

llmsql/inference/inference_transformers.py

Lines changed: 2 additions & 2 deletions
@@ -21,10 +21,10 @@
     batch_size=8,
     max_new_tokens=256,
     temperature=0.7,
-    model_args={
+    model_kwargs={
         "torch_dtype": "bfloat16",
     },
-    generate_kwargs={
+    generation_kwargs={
         "do_sample": False,
     },
 )
