diff --git a/libs/langchain/langchain/chat_models/base.py b/libs/langchain/langchain/chat_models/base.py
index 4de96a1a26995..21e248d4180a7 100644
--- a/libs/langchain/langchain/chat_models/base.py
+++ b/libs/langchain/langchain/chat_models/base.py
@@ -415,10 +415,32 @@ def _init_chat_model_helper(
 
         return ChatMistralAI(model=model, **kwargs)  # type: ignore[call-arg,unused-ignore]
     if model_provider == "huggingface":
-        _check_pkg("langchain_huggingface")
-        from langchain_huggingface import ChatHuggingFace
+        try:
+            from langchain_huggingface.chat_models import ChatHuggingFace
+            from langchain_huggingface.llms import HuggingFacePipeline
+        except ImportError as e:
+            raise ImportError(
+                "Please install langchain-huggingface to use HuggingFace models."
+            ) from e
+
+        # The 'task' kwarg is required by from_model_id but not the base constructor.
+        # We pop it from kwargs to avoid the Pydantic 'extra_forbidden' error.
+        task = kwargs.pop("task", None)
+        if not task:
+            raise ValueError(
+                "The 'task' keyword argument is required for HuggingFace models. "
+                "For example: task='text-generation'."
+            )
+
+        # Initialize the base LLM pipeline with the model and arguments
+        llm = HuggingFacePipeline.from_model_id(
+            model_id=model,
+            task=task,
+            **kwargs,  # Pass remaining kwargs like `device`
+        )
 
-        return ChatHuggingFace(model_id=model, **kwargs)
+        # Pass the initialized LLM to the chat wrapper
+        return ChatHuggingFace(llm=llm)
     if model_provider == "groq":
         _check_pkg("langchain_groq")
         from langchain_groq import ChatGroq
diff --git a/libs/langchain/tests/unit_tests/chat_models/test_base.py b/libs/langchain/tests/unit_tests/chat_models/test_base.py
index 4a1c6b7c7fb3f..8210b94ff74d8 100644
--- a/libs/langchain/tests/unit_tests/chat_models/test_base.py
+++ b/libs/langchain/tests/unit_tests/chat_models/test_base.py
@@ -6,6 +6,7 @@
 from langchain_core.language_models import BaseChatModel
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.runnables import RunnableConfig, RunnableSequence
+from langchain_huggingface.chat_models import ChatHuggingFace
 from pydantic import SecretStr
 
 from langchain.chat_models.base import __all__, init_chat_model
@@ -289,3 +290,16 @@ def test_configurable_with_default() -> None:
     prompt = ChatPromptTemplate.from_messages([("system", "foo")])
     chain = prompt | model_with_config
     assert isinstance(chain, RunnableSequence)
+
+
+def test_init_chat_model_huggingface() -> None:
+    """Test that init_chat_model works with huggingface."""
+    model_name = "google-bert/bert-base-uncased"
+
+    llm = init_chat_model(
+        model=model_name,
+        model_provider="huggingface",
+        task="text-generation",
+    )
+    assert isinstance(llm, ChatHuggingFace)
+    assert llm.llm.model_id == model_name