diff --git a/libs/core/langchain_core/language_models/base.py b/libs/core/langchain_core/language_models/base.py
index ff1daed4b35fc..88516b379141c 100644
--- a/libs/core/langchain_core/language_models/base.py
+++ b/libs/core/langchain_core/language_models/base.py
@@ -12,6 +12,7 @@
     Literal,
     TypeAlias,
     TypeVar,
+    overload,
 )
 
 from pydantic import BaseModel, ConfigDict, Field, field_validator
@@ -107,6 +108,8 @@ def _get_token_ids_default_method(text: str) -> list[int]:
 LanguageModelOutputVar = TypeVar("LanguageModelOutputVar", AIMessage, str)
 """Type variable for the output of a language model."""
 
+_ModelT = TypeVar("_ModelT", bound=BaseModel | Mapping)
+
 
 def _get_verbosity() -> bool:
     return get_verbose()
@@ -267,9 +270,40 @@ async def agenerate_prompt(
 
         """
 
+    @overload
+    def with_structured_output(
+        self,
+        schema: Mapping[str, Any],
+        *,
+        include_raw: Literal[False] = False,
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, dict]: ...
+
+    @overload
     def with_structured_output(
-        self, schema: dict | type, **kwargs: Any
-    ) -> Runnable[LanguageModelInput, dict | BaseModel]:
+        self,
+        schema: type[_ModelT],
+        *,
+        include_raw: Literal[False] = False,
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, _ModelT]: ...
+
+    @overload
+    def with_structured_output(
+        self,
+        schema: Mapping[str, Any] | type[_ModelT],
+        *,
+        include_raw: Literal[True],
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, dict]: ...
+
+    def with_structured_output(
+        self,
+        schema: Mapping | type,
+        *,
+        include_raw: bool = False,
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, Any]:
         """Not implemented on this class."""
         # Implement this on child class if there is a way of steering the model to
         # generate responses that match a given schema.
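Reviewer note: a minimal sketch of what the new overloads buy at type-check time. `Movie` and `check_inference` are made up for illustration; any model that actually implements `with_structured_output` behaves the same way. `reveal_type` is the runtime-safe one from `typing_extensions`:

```python
from pydantic import BaseModel
from typing_extensions import reveal_type

from langchain_core.language_models import BaseChatModel


class Movie(BaseModel):  # hypothetical schema, for illustration only
    title: str
    year: int


def check_inference(model: BaseChatModel) -> None:
    # A Pydantic class hits the type[_ModelT] overload, so a checker narrows
    # the output to Movie instead of the old dict | BaseModel union.
    movie = model.with_structured_output(Movie).invoke("a film from 1994")
    reveal_type(movie)  # a checker reports: Movie

    # A JSON-schema mapping hits the Mapping[str, Any] overload: dict out.
    blob = model.with_structured_output(
        {"title": "Movie", "type": "object"}
    ).invoke("a film from 1994")
    reveal_type(blob)  # a checker reports: builtins.dict[Any, Any]

    # include_raw=True selects the Literal[True] overload, which always
    # yields a dict, whichever form the schema takes.
    raw = model.with_structured_output(Movie, include_raw=True).invoke("hi")
    reveal_type(raw)  # a checker reports: builtins.dict[Any, Any]
```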
diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py
index bfd37ea58835a..cd1647c703891 100644
--- a/libs/core/langchain_core/language_models/chat_models.py
+++ b/libs/core/langchain_core/language_models/chat_models.py
@@ -3,14 +3,14 @@
 from __future__ import annotations
 
 import asyncio
+import builtins
 import inspect
 import json
-import typing
 from abc import ABC, abstractmethod
-from collections.abc import AsyncIterator, Callable, Iterator, Sequence
+from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence
 from functools import cached_property
 from operator import itemgetter
-from typing import TYPE_CHECKING, Any, Literal, cast
+from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast, overload
 
 from pydantic import BaseModel, ConfigDict, Field
 from typing_extensions import override
@@ -73,6 +73,8 @@
 from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass
 from langchain_core.utils.utils import LC_ID_PREFIX, from_env
 
+_ModelT = TypeVar("_ModelT", bound=BaseModel | Mapping)
+
 if TYPE_CHECKING:
     import uuid
 
@@ -226,7 +228,7 @@ async def agenerate_from_stream(
     return await run_in_executor(None, generate_from_stream, iter(chunks))
 
 
-def _format_ls_structured_output(ls_structured_output_format: dict | None) -> dict:
+def _format_ls_structured_output(ls_structured_output_format: Mapping | None) -> dict:
     if ls_structured_output_format:
         try:
             ls_structured_output_format_dict = {
@@ -717,7 +719,7 @@ async def astream(
 
     # --- Custom methods ---
 
-    def _combine_llm_outputs(self, llm_outputs: list[dict | None]) -> dict:  # noqa: ARG002
+    def _combine_llm_outputs(self, llm_outputs: list[Mapping | None]) -> builtins.dict:  # noqa: ARG002
         return {}
 
     def _convert_cached_generations(self, cache_val: list) -> list[ChatGeneration]:
@@ -763,7 +765,7 @@ def _get_invocation_params(
         self,
         stop: list[str] | None = None,
         **kwargs: Any,
-    ) -> dict:
+    ) -> builtins.dict:
         params = self.dict()
         params["stop"] = stop
         return {**params, **kwargs}
@@ -1479,7 +1481,7 @@ def _llm_type(self) -> str:
         """Return type of chat model."""
 
     @override
-    def dict(self, **kwargs: Any) -> dict:
+    def dict(self, **kwargs: Any) -> builtins.dict:
         """Return a dictionary of the LLM."""
         starter_dict = dict(self._identifying_params)
         starter_dict["_type"] = self._llm_type
@@ -1487,9 +1489,7 @@ def dict(self, **kwargs: Any) -> dict:
 
     def bind_tools(
         self,
-        tools: Sequence[
-            typing.Dict[str, Any] | type | Callable | BaseTool  # noqa: UP006
-        ],
+        tools: Sequence[Mapping[str, Any] | type | Callable | BaseTool],
         *,
         tool_choice: str | None = None,
         **kwargs: Any,
@@ -1506,13 +1506,40 @@ def bind_tools(
         """
         raise NotImplementedError
 
+    @overload
+    def with_structured_output(
+        self,
+        schema: Mapping[str, Any],
+        *,
+        include_raw: Literal[False] = False,
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, builtins.dict]: ...
+
+    @overload
+    def with_structured_output(
+        self,
+        schema: type[_ModelT],
+        *,
+        include_raw: Literal[False] = False,
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, _ModelT]: ...
+
+    @overload
+    def with_structured_output(
+        self,
+        schema: Mapping[str, Any] | type[_ModelT],
+        *,
+        include_raw: Literal[True],
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, builtins.dict]: ...
+
     def with_structured_output(
         self,
-        schema: typing.Dict | type,  # noqa: UP006
+        schema: Mapping | type,
         *,
         include_raw: bool = False,
         **kwargs: Any,
-    ) -> Runnable[LanguageModelInput, typing.Dict | BaseModel]:  # noqa: UP006
+    ) -> Runnable[LanguageModelInput, Any]:
         """Model wrapper that returns outputs formatted to match the given schema.
 
         Args:
diff --git a/libs/core/langchain_core/utils/function_calling.py b/libs/core/langchain_core/utils/function_calling.py
index 6fe3e35a61039..73dc37641f449 100644
--- a/libs/core/langchain_core/utils/function_calling.py
+++ b/libs/core/langchain_core/utils/function_calling.py
@@ -8,6 +8,7 @@
 import types
 import typing
 import uuid
+from collections.abc import Mapping
 from typing import (
     TYPE_CHECKING,
     Annotated,
@@ -327,7 +328,7 @@ def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
 
 
 def convert_to_openai_function(
-    function: dict[str, Any] | type | Callable | BaseTool,
+    function: Mapping[str, Any] | type | Callable | BaseTool,
     *,
     strict: bool | None = None,
 ) -> dict[str, Any]:
@@ -357,7 +358,7 @@
             required and guaranteed to be part of the output.
 
     """
     # an Anthropic format tool
-    if isinstance(function, dict) and all(
+    if isinstance(function, Mapping) and all(
         k in function for k in ("name", "input_schema")
     ):
         oai_function = {
@@ -367,7 +368,7 @@
         if "description" in function:
             oai_function["description"] = function["description"]
     # an Amazon Bedrock Converse format tool
-    elif isinstance(function, dict) and "toolSpec" in function:
+    elif isinstance(function, Mapping) and "toolSpec" in function:
         oai_function = {
             "name": function["toolSpec"]["name"],
             "parameters": function["toolSpec"]["inputSchema"]["json"],
@@ -375,15 +376,15 @@
         if "description" in function["toolSpec"]:
             oai_function["description"] = function["toolSpec"]["description"]
     # already in OpenAI function format
-    elif isinstance(function, dict) and "name" in function:
+    elif isinstance(function, Mapping) and "name" in function:
         oai_function = {
             k: v
             for k, v in function.items()
             if k in {"name", "description", "parameters", "strict"}
         }
     # a JSON schema with title and description
-    elif isinstance(function, dict) and "title" in function:
-        function_copy = function.copy()
+    elif isinstance(function, Mapping) and "title" in function:
+        function_copy = dict(function)
         oai_function = {"name": function_copy.pop("title")}
         if "description" in function_copy:
             oai_function["description"] = function_copy.pop("description")
@@ -453,7 +454,7 @@ def convert_to_openai_function(
 
 
 def convert_to_openai_tool(
-    tool: dict[str, Any] | type[BaseModel] | Callable | BaseTool,
+    tool: Mapping[str, Any] | type[BaseModel] | Callable | BaseTool,
     *,
     strict: bool | None = None,
 ) -> dict[str, Any]:
@@ -491,12 +492,12 @@
     # Import locally to prevent circular import
     from langchain_core.tools import Tool  # noqa: PLC0415
 
-    if isinstance(tool, dict):
+    if isinstance(tool, Mapping):
         if tool.get("type") in _WellKnownOpenAITools:
-            return tool
+            return dict(tool)
         # As of 03.12.25 can be "web_search_preview" or "web_search_preview_2025_03_11"
         if (tool.get("type") or "").startswith("web_search_preview"):
-            return tool
+            return dict(tool)
     if isinstance(tool, Tool) and (tool.metadata or {}).get("type") == "custom_tool":
         oai_tool = {
             "type": "custom",
@@ -511,7 +512,7 @@
 
 
 def convert_to_json_schema(
-    schema: dict[str, Any] | type[BaseModel] | Callable | BaseTool,
+    schema: Mapping[str, Any] | type[BaseModel] | Callable | BaseTool,
     *,
     strict: bool | None = None,
 ) -> dict[str, Any]:
diff --git a/libs/core/tests/unit_tests/prompts/test_structured.py b/libs/core/tests/unit_tests/prompts/test_structured.py
index a3568bd380f2f..48f8679d3ccb2 100644
--- a/libs/core/tests/unit_tests/prompts/test_structured.py
+++ b/libs/core/tests/unit_tests/prompts/test_structured.py
@@ -1,3 +1,4 @@
+from collections.abc import Mapping
 from functools import partial
 from inspect import isclass
 from typing import Any, cast
@@ -17,8 +18,8 @@
 
 
 def _fake_runnable(
-    _: Any, *, schema: dict | type[BaseModel], value: Any = 42, **_kwargs: Any
-) -> BaseModel | dict:
+    _: Any, *, schema: Mapping | type, value: Any = 42, **_kwargs: Any
+) -> Any:
     if isclass(schema) and is_basemodel_subclass(schema):
         return schema(name="yo", value=value)
     params = cast("dict", schema)["parameters"]
@@ -29,9 +30,7 @@ class FakeStructuredChatModel(FakeListChatModel):
     """Fake chat model for testing purposes."""
 
     @override
-    def with_structured_output(
-        self, schema: dict | type[BaseModel], **kwargs: Any
-    ) -> Runnable:
+    def with_structured_output(self, schema: Mapping | type, **kwargs: Any) -> Runnable:
         return RunnableLambda(partial(_fake_runnable, schema=schema, **kwargs))
 
     @property
diff --git a/libs/core/tests/unit_tests/runnables/test_fallbacks.py b/libs/core/tests/unit_tests/runnables/test_fallbacks.py
index 1d10887c72525..001805314861e 100644
--- a/libs/core/tests/unit_tests/runnables/test_fallbacks.py
+++ b/libs/core/tests/unit_tests/runnables/test_fallbacks.py
@@ -1,10 +1,9 @@
-from collections.abc import AsyncIterator, Callable, Iterator, Sequence
+from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence
 from typing import (
     Any,
 )
 
 import pytest
-from pydantic import BaseModel
 from syrupy.assertion import SnapshotAssertion
 from typing_extensions import override
 
@@ -335,15 +334,15 @@ def _generate(
     @override
     def bind_tools(
         self,
-        tools: Sequence[dict[str, Any] | type[BaseModel] | Callable | BaseTool],
+        tools: Sequence[Mapping[str, Any] | type | Callable | BaseTool],
         **kwargs: Any,
     ) -> Runnable[LanguageModelInput, AIMessage]:
         return self.bind(tools=tools)
 
     @override
     def with_structured_output(
-        self, schema: dict | type[BaseModel], **kwargs: Any
-    ) -> Runnable[LanguageModelInput, dict | BaseModel]:
+        self, schema: Mapping | type, **kwargs: Any
+    ) -> Runnable[LanguageModelInput, Any]:
         return RunnableLambda(lambda _: {"foo": self.foo})
 
     @property
@@ -368,7 +367,7 @@ def _generate(
     @override
     def bind_tools(
         self,
-        tools: Sequence[dict[str, Any] | type[BaseModel] | Callable | BaseTool],
+        tools: Sequence[Mapping[str, Any] | type | Callable | BaseTool],
         **kwargs: Any,
     ) -> Runnable[LanguageModelInput, AIMessage]:
         return self.bind(tools=tools)
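Two usage notes for reviewers. First, because the converters in function_calling.py now check against `collections.abc.Mapping` and copy defensively with `dict(...)`, read-only mappings should pass through cleanly. A sketch, assuming a made-up Anthropic-format tool (the `name`/`input_schema` branch above):

```python
from types import MappingProxyType

from langchain_core.utils.function_calling import convert_to_openai_function

# A read-only view over an Anthropic-format tool definition; the old
# isinstance(function, dict) check would have fallen through to the
# unsupported-input error path for this value.
weather_tool = MappingProxyType(
    {
        "name": "get_weather",  # hypothetical tool, illustration only
        "description": "Look up the current weather for a city.",
        "input_schema": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    }
)

oai = convert_to_openai_function(weather_tool)
# Expect a plain, mutable dict in OpenAI function format, roughly:
# {"name": "get_weather", "parameters": {...}, "description": "..."}
```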
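Second, on the `-> builtins.dict` annotations in chat_models.py: these are necessary, not stylistic. `BaseChatModel` defines a method named `dict`, which shadows the builtin inside the class body, so a bare `dict` annotation on a sibling method would resolve to the method rather than the type. A self-contained sketch of the same pitfall, with made-up names:

```python
import builtins


class FakeModel:
    def dict(self) -> builtins.dict:
        """Shadow the builtin: `dict` now names this method in class scope."""
        return {"_type": "fake"}

    # Annotated as plain `dict`, a type checker would resolve the name to the
    # method above and reject it as not valid as a type; builtins.dict refers
    # unambiguously to the builtin type.
    def invocation_params(self) -> builtins.dict:
        return {**self.dict(), "stop": None}
```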