Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions libs/core/langchain_core/language_models/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
convert_to_json_schema,
convert_to_openai_tool,
)
from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_core.utils.utils import LC_ID_PREFIX, from_env

if TYPE_CHECKING:
Expand Down Expand Up @@ -1650,7 +1650,7 @@ class AnswerWithJustification(BaseModel):
)
if isinstance(schema, type) and is_basemodel_subclass(schema):
output_parser: OutputParserLike = PydanticToolsParser(
tools=[cast("TypeBaseModel", schema)], first_tool_only=True
tools=[schema], first_tool_only=True
)
else:
key_name = convert_to_openai_tool(schema)["function"]["name"]
Expand Down
53 changes: 27 additions & 26 deletions libs/core/langchain_core/runnables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@
)
from langchain_core.utils.aiter import aclosing, atee, py_anext
from langchain_core.utils.iter import safetee
from langchain_core.utils.pydantic import create_model_v2
from langchain_core.utils.pydantic import TypeBaseModel, create_model_v2, get_fields

if TYPE_CHECKING:
from langchain_core.callbacks.manager import (
Expand Down Expand Up @@ -355,14 +355,14 @@ def OutputType(self) -> type[Output]: # noqa: N802
raise TypeError(msg)

@property
def input_schema(self) -> type[BaseModel]:
def input_schema(self) -> TypeBaseModel:
"""The type of input this `Runnable` accepts specified as a Pydantic model."""
return self.get_input_schema()

def get_input_schema(
self,
config: RunnableConfig | None = None, # noqa: ARG002
) -> type[BaseModel]:
) -> TypeBaseModel:
"""Get a Pydantic model that can be used to validate input to the `Runnable`.

`Runnable` objects that leverage the `configurable_fields` and
Expand Down Expand Up @@ -427,10 +427,13 @@ def add_one(x: int) -> int:
!!! version-added "Added in version 0.3.0"

"""
return self.get_input_schema(config).model_json_schema()
schema = self.get_input_schema(config)
if issubclass(schema, BaseModel):
return schema.model_json_schema()
return schema.schema()

@property
def output_schema(self) -> type[BaseModel]:
def output_schema(self) -> TypeBaseModel:
"""Output schema.

The type of output this `Runnable` produces specified as a Pydantic model.
Expand All @@ -440,7 +443,7 @@ def output_schema(self) -> type[BaseModel]:
def get_output_schema(
self,
config: RunnableConfig | None = None, # noqa: ARG002
) -> type[BaseModel]:
) -> TypeBaseModel:
"""Get a Pydantic model that can be used to validate output to the `Runnable`.

`Runnable` objects that leverage the `configurable_fields` and
Expand Down Expand Up @@ -505,7 +508,10 @@ def add_one(x: int) -> int:
!!! version-added "Added in version 0.3.0"

"""
return self.get_output_schema(config).model_json_schema()
schema = self.get_output_schema(config)
if issubclass(schema, BaseModel):
return schema.model_json_schema()
return schema.schema()

@property
def config_specs(self) -> list[ConfigurableFieldSpec]:
Expand Down Expand Up @@ -2671,7 +2677,7 @@ def configurable_alternatives(

def _seq_input_schema(
steps: list[Runnable[Any, Any]], config: RunnableConfig | None
) -> type[BaseModel]:
) -> TypeBaseModel:
# Import locally to prevent circular import
from langchain_core.runnables.passthrough import ( # noqa: PLC0415
RunnableAssign,
Expand All @@ -2689,7 +2695,7 @@ def _seq_input_schema(
"RunnableSequenceInput",
field_definitions={
k: (v.annotation, v.default)
for k, v in next_input_schema.model_fields.items()
for k, v in get_fields(next_input_schema).items()
if k not in first.mapper.steps__
},
)
Expand All @@ -2701,7 +2707,7 @@ def _seq_input_schema(

def _seq_output_schema(
steps: list[Runnable[Any, Any]], config: RunnableConfig | None
) -> type[BaseModel]:
) -> TypeBaseModel:
# Import locally to prevent circular import
from langchain_core.runnables.passthrough import ( # noqa: PLC0415
RunnableAssign,
Expand All @@ -2721,7 +2727,7 @@ def _seq_output_schema(
field_definitions={
**{
k: (v.annotation, v.default)
for k, v in prev_output_schema.model_fields.items()
for k, v in get_fields(prev_output_schema).items()
},
**{
k: (v.annotation, v.default)
Expand All @@ -2738,11 +2744,11 @@ def _seq_output_schema(
"RunnableSequenceOutput",
field_definitions={
k: (v.annotation, v.default)
for k, v in prev_output_schema.model_fields.items()
for k, v in get_fields(prev_output_schema).items()
if k in last.keys
},
)
field = prev_output_schema.model_fields[last.keys]
field = get_fields(prev_output_schema)[last.keys]
return create_model_v2(
"RunnableSequenceOutput", root=(field.annotation, field.default)
)
Expand Down Expand Up @@ -2924,7 +2930,7 @@ def OutputType(self) -> type[Output]:
return self.last.OutputType

@override
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
"""Get the input schema of the `Runnable`.

Args:
Expand All @@ -2937,9 +2943,7 @@ def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseMod
return _seq_input_schema(self.steps, config)

@override
def get_output_schema(
self, config: RunnableConfig | None = None
) -> type[BaseModel]:
def get_output_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
"""Get the output schema of the `Runnable`.

Args:
Expand Down Expand Up @@ -3653,7 +3657,7 @@ def InputType(self) -> Any:
return Any

@override
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
"""Get the input schema of the `Runnable`.

Args:
Expand All @@ -3664,8 +3668,7 @@ def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseMod

"""
if all(
s.get_input_schema(config).model_json_schema().get("type", "object")
== "object"
s.get_input_jsonschema(config).get("type", "object") == "object"
for s in self.steps__.values()
):
# This is correct, but pydantic typings/mypy don't think so.
Expand All @@ -3674,7 +3677,7 @@ def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseMod
field_definitions={
k: (v.annotation, v.default)
for step in self.steps__.values()
for k, v in step.get_input_schema(config).model_fields.items()
for k, v in get_fields(step.get_input_schema(config)).items()
if k != "__root__"
},
)
Expand Down Expand Up @@ -4460,7 +4463,7 @@ def InputType(self) -> Any:
return Any

@override
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
"""The Pydantic schema for the input to this `Runnable`.

Args:
Expand Down Expand Up @@ -5437,15 +5440,13 @@ def OutputType(self) -> type[Output]:
)

@override
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
if self.custom_input_type is not None:
return super().get_input_schema(config)
return self.bound.get_input_schema(merge_configs(self.config, config))

@override
def get_output_schema(
self, config: RunnableConfig | None = None
) -> type[BaseModel]:
def get_output_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
if self.custom_output_type is not None:
return super().get_output_schema(config)
return self.bound.get_output_schema(merge_configs(self.config, config))
Expand Down
10 changes: 4 additions & 6 deletions libs/core/langchain_core/runnables/branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
cast,
)

from pydantic import BaseModel, ConfigDict
from pydantic import ConfigDict
from typing_extensions import override

from langchain_core.runnables.base import (
Expand All @@ -35,6 +35,7 @@
Output,
get_unique_config_specs,
)
from langchain_core.utils.pydantic import TypeBaseModel


class RunnableBranch(RunnableSerializable[Input, Output]):
Expand Down Expand Up @@ -154,18 +155,15 @@ def get_lc_namespace(cls) -> list[str]:
return ["langchain", "schema", "runnable"]

@override
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
runnables = (
[self.default]
+ [r for _, r in self.branches]
+ [r for r, _ in self.branches]
)

for runnable in runnables:
if (
runnable.get_input_schema(config).model_json_schema().get("type")
is not None
):
if runnable.get_input_jsonschema(config).get("type") is not None:
return runnable.get_input_schema(config)

return super().get_input_schema(config)
Expand Down
9 changes: 4 additions & 5 deletions libs/core/langchain_core/runnables/configurable.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
)
from weakref import WeakValueDictionary

from pydantic import BaseModel, ConfigDict
from pydantic import ConfigDict
from typing_extensions import override

from langchain_core.runnables.base import Runnable, RunnableSerializable
Expand All @@ -41,6 +41,7 @@
gather_with_concurrency,
get_unique_config_specs,
)
from langchain_core.utils.pydantic import TypeBaseModel

if TYPE_CHECKING:
from langchain_core.runnables.graph import Graph
Expand Down Expand Up @@ -90,14 +91,12 @@ def OutputType(self) -> type[Output]:
return self.default.OutputType

@override
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
runnable, config = self.prepare(config)
return runnable.get_input_schema(config)

@override
def get_output_schema(
self, config: RunnableConfig | None = None
) -> type[BaseModel]:
def get_output_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
runnable, config = self.prepare(config)
return runnable.get_output_schema(config)

Expand Down
9 changes: 4 additions & 5 deletions libs/core/langchain_core/runnables/fallbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from functools import wraps
from typing import TYPE_CHECKING, Any, cast

from pydantic import BaseModel, ConfigDict
from pydantic import ConfigDict
from typing_extensions import override

from langchain_core.callbacks.manager import AsyncCallbackManager, CallbackManager
Expand All @@ -29,6 +29,7 @@
get_unique_config_specs,
)
from langchain_core.utils.aiter import py_anext
from langchain_core.utils.pydantic import TypeBaseModel

if TYPE_CHECKING:
from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
Expand Down Expand Up @@ -116,13 +117,11 @@ def OutputType(self) -> type[Output]:
return self.runnable.OutputType

@override
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
return self.runnable.get_input_schema(config)

@override
def get_output_schema(
self, config: RunnableConfig | None = None
) -> type[BaseModel]:
def get_output_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
return self.runnable.get_output_schema(config)

@property
Expand Down
14 changes: 8 additions & 6 deletions libs/core/langchain_core/runnables/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@

from langchain_core.load.serializable import to_json_not_implemented
from langchain_core.runnables.base import Runnable, RunnableSerializable
from langchain_core.utils.pydantic import _IgnoreUnserializable, is_basemodel_subclass
from langchain_core.utils.pydantic import (
TypeBaseModel,
_IgnoreUnserializable,
is_basemodel_subclass,
)

if TYPE_CHECKING:
from collections.abc import Sequence

from pydantic import BaseModel

from langchain_core.runnables.base import Runnable as RunnableType


Expand Down Expand Up @@ -98,7 +100,7 @@ class Node(NamedTuple):
"""The unique identifier of the node."""
name: str
"""The name of the node."""
data: type[BaseModel] | RunnableType | None
data: TypeBaseModel | RunnableType | None
"""The data of the node."""
metadata: dict[str, Any] | None
"""Optional metadata for the node. """
Expand Down Expand Up @@ -178,7 +180,7 @@ class MermaidDrawMethod(Enum):

def node_data_str(
id: str,
data: type[BaseModel] | RunnableType | None,
data: TypeBaseModel | RunnableType | None,
) -> str:
"""Convert the data of a node to a string.

Expand Down Expand Up @@ -312,7 +314,7 @@ def next_id(self) -> str:

def add_node(
self,
data: type[BaseModel] | RunnableType | None,
data: TypeBaseModel | RunnableType | None,
id: str | None = None,
*,
metadata: dict[str, Any] | None = None,
Expand Down
Loading