38 changes: 36 additions & 2 deletions libs/core/langchain_core/language_models/base.py
@@ -12,6 +12,7 @@
Literal,
TypeAlias,
TypeVar,
overload,
)

from pydantic import BaseModel, ConfigDict, Field, field_validator
@@ -107,6 +108,8 @@ def _get_token_ids_default_method(text: str) -> list[int]:
LanguageModelOutputVar = TypeVar("LanguageModelOutputVar", AIMessage, str)
"""Type variable for the output of a language model."""

_ModelT = TypeVar("_ModelT", bound=BaseModel | Mapping)


def _get_verbosity() -> bool:
return get_verbose()
@@ -267,9 +270,40 @@ async def agenerate_prompt(

"""

@overload
def with_structured_output(
self,
schema: Mapping[str, Any],
*,
include_raw: Literal[False] = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, dict]: ...

    def with_structured_output(
        self, schema: dict | type, **kwargs: Any
    ) -> Runnable[LanguageModelInput, dict | BaseModel]:

    @overload
    def with_structured_output(
        self,
        schema: type[_ModelT],
        *,
        include_raw: Literal[False] = False,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, _ModelT]: ...

@overload
def with_structured_output(
self,
schema: Mapping[str, Any] | type[_ModelT],
*,
include_raw: Literal[True],
**kwargs: Any,
) -> Runnable[LanguageModelInput, dict]: ...

def with_structured_output(
self,
schema: Mapping | type,
*,
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, Any]:
"""Not implemented on this class."""
# Implement this on child class if there is a way of steering the model to
# generate responses that match a given schema.
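A minimal usage sketch of what these overloads buy at type-check time; `model` stands in for any concrete BaseChatModel implementation and `Person` is a hypothetical Pydantic schema, both assumptions for illustration only:

from pydantic import BaseModel

class Person(BaseModel):
    name: str
    age: int

# A class schema matches the type[_ModelT] overload, so the checker infers
# Runnable[LanguageModelInput, Person] rather than dict | BaseModel.
person = model.with_structured_output(Person).invoke("Alice is 30")

# A Mapping schema (e.g. a JSON schema dict) matches the dict-returning overload.
as_dict = model.with_structured_output(
    {"title": "Person", "type": "object", "properties": {}}
).invoke("Alice is 30")

# include_raw=True matches the Literal[True] overload, which also returns dict.
raw = model.with_structured_output(Person, include_raw=True).invoke("Alice is 30")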
51 changes: 39 additions & 12 deletions libs/core/langchain_core/language_models/chat_models.py
@@ -3,14 +3,14 @@
from __future__ import annotations

import asyncio
import builtins
import inspect
import json
import typing
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Callable, Iterator, Sequence
from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence
from functools import cached_property
from operator import itemgetter
from typing import TYPE_CHECKING, Any, Literal, cast
from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast, overload

from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import override
@@ -73,6 +73,8 @@
from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass
from langchain_core.utils.utils import LC_ID_PREFIX, from_env

_ModelT = TypeVar("_ModelT", bound=BaseModel | Mapping)

if TYPE_CHECKING:
import uuid

@@ -226,7 +228,7 @@ async def agenerate_from_stream(
return await run_in_executor(None, generate_from_stream, iter(chunks))


def _format_ls_structured_output(ls_structured_output_format: dict | None) -> dict:
def _format_ls_structured_output(ls_structured_output_format: Mapping | None) -> dict:
if ls_structured_output_format:
try:
ls_structured_output_format_dict = {
@@ -717,7 +719,7 @@ async def astream(

# --- Custom methods ---

def _combine_llm_outputs(self, llm_outputs: list[dict | None]) -> dict: # noqa: ARG002
def _combine_llm_outputs(self, llm_outputs: list[Mapping | None]) -> builtins.dict: # noqa: ARG002
return {}

def _convert_cached_generations(self, cache_val: list) -> list[ChatGeneration]:
@@ -763,7 +765,7 @@ def _get_invocation_params(
self,
stop: list[str] | None = None,
**kwargs: Any,
) -> dict:
) -> builtins.dict:
params = self.dict()
params["stop"] = stop
return {**params, **kwargs}
@@ -1479,17 +1481,15 @@ def _llm_type(self) -> str:
"""Return type of chat model."""

@override
def dict(self, **kwargs: Any) -> dict:
def dict(self, **kwargs: Any) -> builtins.dict:
"""Return a dictionary of the LLM."""
starter_dict = dict(self._identifying_params)
starter_dict["_type"] = self._llm_type
return starter_dict

def bind_tools(
self,
tools: Sequence[
typing.Dict[str, Any] | type | Callable | BaseTool # noqa: UP006
],
tools: Sequence[Mapping[str, Any] | type | Callable | BaseTool],
*,
tool_choice: str | None = None,
**kwargs: Any,
@@ -1506,13 +1506,40 @@ def bind_tools(
"""
raise NotImplementedError

@overload
def with_structured_output(
self,
schema: Mapping[str, Any],
*,
include_raw: Literal[False] = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, builtins.dict]: ...

@overload
def with_structured_output(
self,
schema: type[_ModelT],
*,
include_raw: Literal[False] = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, _ModelT]: ...

@overload
def with_structured_output(
self,
schema: Mapping[str, Any] | type[_ModelT],
*,
include_raw: Literal[True],
**kwargs: Any,
) -> Runnable[LanguageModelInput, builtins.dict]: ...

def with_structured_output(
self,
schema: typing.Dict | type, # noqa: UP006
schema: Mapping | type,
*,
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, typing.Dict | BaseModel]: # noqa: UP006
) -> Runnable[LanguageModelInput, Any]:
"""Model wrapper that returns outputs formatted to match the given schema.

Args:
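The Literal[True] overload returns dict because, per the method's existing docstring, include_raw=True yields a dict with "raw", "parsed", and "parsing_error" keys rather than a parsed model instance. A rough sketch, with `model` and `Person` again as stand-ins:

structured = model.with_structured_output(Person, include_raw=True)
out = structured.invoke("Alice is 30")
# out resembles {"raw": AIMessage(...), "parsed": Person(name="Alice", age=30),
#                "parsing_error": None}, so no single _ModelT return type applies.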
23 changes: 12 additions & 11 deletions libs/core/langchain_core/utils/function_calling.py
@@ -8,6 +8,7 @@
import types
import typing
import uuid
from collections.abc import Mapping
from typing import (
TYPE_CHECKING,
Annotated,
@@ -327,7 +328,7 @@ def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:


def convert_to_openai_function(
function: dict[str, Any] | type | Callable | BaseTool,
function: Mapping[str, Any] | type | Callable | BaseTool,
*,
strict: bool | None = None,
) -> dict[str, Any]:
@@ -357,7 +358,7 @@ def convert_to_openai_function(
required and guaranteed to be part of the output.
"""
# an Anthropic format tool
if isinstance(function, dict) and all(
if isinstance(function, Mapping) and all(
k in function for k in ("name", "input_schema")
):
oai_function = {
@@ -367,23 +368,23 @@
if "description" in function:
oai_function["description"] = function["description"]
# an Amazon Bedrock Converse format tool
elif isinstance(function, dict) and "toolSpec" in function:
elif isinstance(function, Mapping) and "toolSpec" in function:
oai_function = {
"name": function["toolSpec"]["name"],
"parameters": function["toolSpec"]["inputSchema"]["json"],
}
if "description" in function["toolSpec"]:
oai_function["description"] = function["toolSpec"]["description"]
# already in OpenAI function format
elif isinstance(function, dict) and "name" in function:
elif isinstance(function, Mapping) and "name" in function:
oai_function = {
k: v
for k, v in function.items()
if k in {"name", "description", "parameters", "strict"}
}
# a JSON schema with title and description
elif isinstance(function, dict) and "title" in function:
function_copy = function.copy()
elif isinstance(function, Mapping) and "title" in function:
function_copy = dict(function)
oai_function = {"name": function_copy.pop("title")}
if "description" in function_copy:
oai_function["description"] = function_copy.pop("description")
@@ -453,7 +454,7 @@


def convert_to_openai_tool(
tool: dict[str, Any] | type[BaseModel] | Callable | BaseTool,
tool: Mapping[str, Any] | type[BaseModel] | Callable | BaseTool,
*,
strict: bool | None = None,
) -> dict[str, Any]:
@@ -491,12 +492,12 @@
# Import locally to prevent circular import
from langchain_core.tools import Tool # noqa: PLC0415

if isinstance(tool, dict):
if isinstance(tool, Mapping):
if tool.get("type") in _WellKnownOpenAITools:
return tool
return dict(tool)
# As of 03.12.25 can be "web_search_preview" or "web_search_preview_2025_03_11"
if (tool.get("type") or "").startswith("web_search_preview"):
return tool
return dict(tool)
if isinstance(tool, Tool) and (tool.metadata or {}).get("type") == "custom_tool":
oai_tool = {
"type": "custom",
@@ -511,7 +512,7 @@


def convert_to_json_schema(
schema: dict[str, Any] | type[BaseModel] | Callable | BaseTool,
schema: Mapping[str, Any] | type[BaseModel] | Callable | BaseTool,
*,
strict: bool | None = None,
) -> dict[str, Any]:
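With the parameter widened from dict to Mapping, any read-only mapping now qualifies as a tool definition. A small sketch of the Anthropic-format branch above (name maps to name, input_schema to parameters, and description is copied when present):

from langchain_core.utils.function_calling import convert_to_openai_function

anthropic_tool = {
    "name": "get_weather",
    "description": "Get the weather for a city.",
    "input_schema": {
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
}
oai = convert_to_openai_function(anthropic_tool)
# {"name": "get_weather", "parameters": {...}, "description": "Get the weather for a city."}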
9 changes: 4 additions & 5 deletions libs/core/tests/unit_tests/prompts/test_structured.py
@@ -1,3 +1,4 @@
from collections.abc import Mapping
from functools import partial
from inspect import isclass
from typing import Any, cast
@@ -17,8 +18,8 @@


def _fake_runnable(
_: Any, *, schema: dict | type[BaseModel], value: Any = 42, **_kwargs: Any
) -> BaseModel | dict:
_: Any, *, schema: Mapping | type, value: Any = 42, **_kwargs: Any
) -> Any:
if isclass(schema) and is_basemodel_subclass(schema):
return schema(name="yo", value=value)
params = cast("dict", schema)["parameters"]
@@ -29,9 +30,7 @@ class FakeStructuredChatModel(FakeListChatModel):
"""Fake chat model for testing purposes."""

@override
def with_structured_output(
self, schema: dict | type[BaseModel], **kwargs: Any
) -> Runnable:
def with_structured_output(self, schema: Mapping | type, **kwargs: Any) -> Runnable:
return RunnableLambda(partial(_fake_runnable, schema=schema, **kwargs))

@property
11 changes: 5 additions & 6 deletions libs/core/tests/unit_tests/runnables/test_fallbacks.py
@@ -1,10 +1,9 @@
from collections.abc import AsyncIterator, Callable, Iterator, Sequence
from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence
from typing import (
Any,
)

import pytest
from pydantic import BaseModel
from syrupy.assertion import SnapshotAssertion
from typing_extensions import override

@@ -335,15 +334,15 @@ def _generate(
@override
def bind_tools(
self,
tools: Sequence[dict[str, Any] | type[BaseModel] | Callable | BaseTool],
tools: Sequence[Mapping[str, Any] | type | Callable | BaseTool],
**kwargs: Any,
) -> Runnable[LanguageModelInput, AIMessage]:
return self.bind(tools=tools)

@override
def with_structured_output(
self, schema: dict | type[BaseModel], **kwargs: Any
) -> Runnable[LanguageModelInput, dict | BaseModel]:
self, schema: Mapping | type, **kwargs: Any
) -> Runnable[LanguageModelInput, Any]:
return RunnableLambda(lambda _: {"foo": self.foo})

@property
@@ -368,7 +367,7 @@ def _generate(
@override
def bind_tools(
self,
tools: Sequence[dict[str, Any] | type[BaseModel] | Callable | BaseTool],
tools: Sequence[Mapping[str, Any] | type | Callable | BaseTool],
**kwargs: Any,
) -> Runnable[LanguageModelInput, AIMessage]:
return self.bind(tools=tools)