Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions libs/langchain/langchain/chat_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,10 +415,32 @@

return ChatMistralAI(model=model, **kwargs) # type: ignore[call-arg,unused-ignore]
if model_provider == "huggingface":
_check_pkg("langchain_huggingface")
from langchain_huggingface import ChatHuggingFace
try:
from langchain_huggingface.chat_models import ChatHuggingFace
from langchain_huggingface.llms import HuggingFacePipeline
except ImportError as e:
raise ImportError(
"Please install langchain-huggingface to use HuggingFace models."

Check failure on line 423 in libs/langchain/langchain/chat_models/base.py

View workflow job for this annotation

GitHub Actions / lint (libs/langchain, 3.9) / Python 3.9

Ruff (EM101)

langchain/chat_models/base.py:423:17: EM101 Exception must not use a string literal, assign to variable first

Check failure on line 423 in libs/langchain/langchain/chat_models/base.py

View workflow job for this annotation

GitHub Actions / lint (libs/langchain, 3.13) / Python 3.13

Ruff (EM101)

langchain/chat_models/base.py:423:17: EM101 Exception must not use a string literal, assign to variable first
) from e

Check failure on line 424 in libs/langchain/langchain/chat_models/base.py

View workflow job for this annotation

GitHub Actions / lint (libs/langchain, 3.9) / Python 3.9

Ruff (TRY003)

langchain/chat_models/base.py:422:19: TRY003 Avoid specifying long messages outside the exception class

Check failure on line 424 in libs/langchain/langchain/chat_models/base.py

View workflow job for this annotation

GitHub Actions / lint (libs/langchain, 3.13) / Python 3.13

Ruff (TRY003)

langchain/chat_models/base.py:422:19: TRY003 Avoid specifying long messages outside the exception class

# The 'task' kwarg is required by from_model_id but not the base constructor.
# We pop it from kwargs to avoid the Pydantic 'extra_forbidden' error.
task = kwargs.pop("task", None)
if not task:
raise ValueError(
"The 'task' keyword argument is required for HuggingFace models. "
"For example: task='text-generation'."

Check failure on line 432 in libs/langchain/langchain/chat_models/base.py

View workflow job for this annotation

GitHub Actions / lint (libs/langchain, 3.9) / Python 3.9

Ruff (EM101)

langchain/chat_models/base.py:431:17: EM101 Exception must not use a string literal, assign to variable first

Check failure on line 432 in libs/langchain/langchain/chat_models/base.py

View workflow job for this annotation

GitHub Actions / lint (libs/langchain, 3.13) / Python 3.13

Ruff (EM101)

langchain/chat_models/base.py:431:17: EM101 Exception must not use a string literal, assign to variable first
)

Check failure on line 433 in libs/langchain/langchain/chat_models/base.py

View workflow job for this annotation

GitHub Actions / lint (libs/langchain, 3.9) / Python 3.9

Ruff (TRY003)

langchain/chat_models/base.py:430:19: TRY003 Avoid specifying long messages outside the exception class

Check failure on line 433 in libs/langchain/langchain/chat_models/base.py

View workflow job for this annotation

GitHub Actions / lint (libs/langchain, 3.13) / Python 3.13

Ruff (TRY003)

langchain/chat_models/base.py:430:19: TRY003 Avoid specifying long messages outside the exception class

# Initialize the base LLM pipeline with the model and arguments
llm = HuggingFacePipeline.from_model_id(
model_id=model,
task=task,
**kwargs, # Pass remaining kwargs like `device`
)

return ChatHuggingFace(model_id=model, **kwargs)
# Pass the initialized LLM to the chat wrapper
return ChatHuggingFace(llm=llm)
if model_provider == "groq":
_check_pkg("langchain_groq")
from langchain_groq import ChatGroq
Expand Down
14 changes: 14 additions & 0 deletions libs/langchain/tests/unit_tests/chat_models/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from langchain_core.language_models import BaseChatModel
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig, RunnableSequence
from langchain_huggingface.chat_models import ChatHuggingFace
from pydantic import SecretStr

from langchain.chat_models.base import __all__, init_chat_model
Expand Down Expand Up @@ -289,3 +290,16 @@ def test_configurable_with_default() -> None:
prompt = ChatPromptTemplate.from_messages([("system", "foo")])
chain = prompt | model_with_config
assert isinstance(chain, RunnableSequence)


def test_init_chat_model_huggingface() -> None:
"""Test that init_chat_model works with huggingface."""
model_name = "google-bert/bert-base-uncased"

llm = init_chat_model(
model=model_name,
model_provider="huggingface",
task="text-generation",
)
assert isinstance(llm, ChatHuggingFace)
assert llm.llm.model_id == model_name
Loading