Skip to content

Commit 36b6b3c

Browse files
authored
HuggingFacePipeline: Forward model_kwargs. (#696)
Since the tokenizer and model are constructed manually, model_kwargs needs to be passed to their constructors. Additionally, the pipeline has a specific named parameter to pass these with, which can provide forward compatibility if they are used for something other than tokenizer or model construction.
1 parent 3a30e6d commit 36b6b3c

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

langchain/llms/huggingface_pipeline.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,19 +68,19 @@ def from_model_id(
6868
)
6969
from transformers import pipeline as hf_pipeline
7070

71-
tokenizer = AutoTokenizer.from_pretrained(model_id)
71+
_model_kwargs = model_kwargs or {}
72+
tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)
7273
if task == "text-generation":
73-
model = AutoModelForCausalLM.from_pretrained(model_id)
74+
model = AutoModelForCausalLM.from_pretrained(model_id, **_model_kwargs)
7475
elif task == "text2text-generation":
75-
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
76+
model = AutoModelForSeq2SeqLM.from_pretrained(model_id, **_model_kwargs)
7677
else:
7778
raise ValueError(
7879
f"Got invalid task {task}, "
7980
f"currently only {VALID_TASKS} are supported"
8081
)
81-
_model_kwargs = model_kwargs or {}
8282
pipeline = hf_pipeline(
83-
task=task, model=model, tokenizer=tokenizer, **_model_kwargs
83+
task=task, model=model, tokenizer=tokenizer, model_kwargs=_model_kwargs
8484
)
8585
if pipeline.task not in VALID_TASKS:
8686
raise ValueError(

0 commit comments

Comments
 (0)