Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 44 additions & 3 deletions libs/langchain/langchain/chat_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,20 @@
*,
model_provider: Optional[str] = None,
configurable_fields: None = None,
config_prefix: Optional[str] = None,
**kwargs: Any,
) -> BaseChatModel: ...


@overload
def init_chat_model(
model: None = None,
*,
model_provider: Optional[str] = None,
configurable_fields: None = None,
config_prefix: Optional[str] = None,
**kwargs: Any,
) -> _ConfigurableModel: ...
Expand Down Expand Up @@ -128,6 +131,16 @@
- ``deepseek...`` -> ``deepseek``
- ``grok...`` -> ``xai``
- ``sonar...`` -> ``perplexity``
configurable_fields: Which model parameters are configurable:

- None: No configurable fields.
Expand Down Expand Up @@ -277,6 +290,10 @@
GetWeather,
GetPopulation,
]
)
configurable_model_with_tools.invoke(
"Which city is hotter today and which is bigger: LA or NY?"
Expand Down Expand Up @@ -415,10 +432,32 @@

return ChatMistralAI(model=model, **kwargs) # type: ignore[call-arg,unused-ignore]
if model_provider == "huggingface":
_check_pkg("langchain_huggingface")
from langchain_huggingface import ChatHuggingFace
try:
from langchain_huggingface.chat_models import ChatHuggingFace
from langchain_huggingface.llms import HuggingFacePipeline
except ImportError as e:
raise ImportError(
"Please install langchain-huggingface to use HuggingFace models."
) from e

# The 'task' kwarg is required by from_model_id but not the base constructor.
# We pop it from kwargs to avoid the Pydantic 'extra_forbidden' error.
task = kwargs.pop("task", None)
if not task:
raise ValueError(
"The 'task' keyword argument is required for HuggingFace models. "
"For example: task='text-generation'."
)

# Initialize the base LLM pipeline with the model and arguments
llm = HuggingFacePipeline.from_model_id(
model_id=model,
task=task,
**kwargs, # Pass remaining kwargs like `device`
)

return ChatHuggingFace(model_id=model, **kwargs)
# Pass the initialized LLM to the chat wrapper
return ChatHuggingFace(llm=llm)
if model_provider == "groq":
_check_pkg("langchain_groq")
from langchain_groq import ChatGroq
Expand Down Expand Up @@ -497,7 +536,8 @@


def _attempt_infer_model_provider(model_name: str) -> Optional[str]:
if any(model_name.startswith(pre) for pre in ("gpt-", "o1", "o3")):

Check failure on line 540 in libs/langchain/langchain/chat_models/base.py

View workflow job for this annotation

GitHub Actions / lint (libs/langchain, 3.9) / Python 3.9

Ruff (invalid-syntax)

langchain/chat_models/base.py:540:5: invalid-syntax: Expected an indented block after `if` statement

Check failure on line 540 in libs/langchain/langchain/chat_models/base.py

View workflow job for this annotation

GitHub Actions / lint (libs/langchain, 3.13) / Python 3.13

Ruff (invalid-syntax)

langchain/chat_models/base.py:540:5: invalid-syntax: Expected an indented block after `if` statement
return "openai"
if model_name.startswith("claude"):
return "anthropic"
Expand Down Expand Up @@ -958,3 +998,4 @@
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
return self.__getattr__("with_structured_output")(schema, **kwargs)

14 changes: 14 additions & 0 deletions libs/langchain/tests/unit_tests/chat_models/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from langchain_core.language_models import BaseChatModel
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig, RunnableSequence
from langchain_huggingface.chat_models import ChatHuggingFace
from pydantic import SecretStr

from langchain.chat_models.base import __all__, init_chat_model
Expand Down Expand Up @@ -289,3 +290,16 @@ def test_configurable_with_default() -> None:
prompt = ChatPromptTemplate.from_messages([("system", "foo")])
chain = prompt | model_with_config
assert isinstance(chain, RunnableSequence)


def test_init_chat_model_huggingface() -> None:
    """Test that init_chat_model works with huggingface.

    Exercises the ``model_provider="huggingface"`` branch, which requires the
    extra ``task`` keyword and should return a ``ChatHuggingFace`` wrapping a
    pipeline built for the requested model id.
    """
    # NOTE(review): this calls HuggingFacePipeline.from_model_id under the
    # hood, which downloads the model weights from the Hugging Face Hub — a
    # network dependency that does not belong in a unit test. Consider
    # mocking from_model_id (and marking this as an integration test) before
    # merging — TODO confirm CI policy.
    model_name = "google-bert/bert-base-uncased"

    llm = init_chat_model(
        model=model_name,
        model_provider="huggingface",
        task="text-generation",
    )
    # The chat wrapper should hold the pipeline LLM, and that pipeline should
    # carry the model id we asked for.
    assert isinstance(llm, ChatHuggingFace)
    assert llm.llm.model_id == model_name
Loading