Skip to content

Commit 59b686e

Browse files
committed
fix(core): fix Pydantic v1 support in tools/runnable
1 parent 62769a0 commit 59b686e

File tree

13 files changed

+87
-88
lines changed

13 files changed

+87
-88
lines changed

libs/core/langchain_core/language_models/chat_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@
6969
convert_to_json_schema,
7070
convert_to_openai_tool,
7171
)
72-
from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass
72+
from langchain_core.utils.pydantic import is_basemodel_subclass
7373
from langchain_core.utils.utils import LC_ID_PREFIX, from_env
7474

7575
if TYPE_CHECKING:
@@ -1650,7 +1650,7 @@ class AnswerWithJustification(BaseModel):
16501650
)
16511651
if isinstance(schema, type) and is_basemodel_subclass(schema):
16521652
output_parser: OutputParserLike = PydanticToolsParser(
1653-
tools=[cast("TypeBaseModel", schema)], first_tool_only=True
1653+
tools=[schema], first_tool_only=True
16541654
)
16551655
else:
16561656
key_name = convert_to_openai_tool(schema)["function"]["name"]

libs/core/langchain_core/runnables/base.py

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@
9696
)
9797
from langchain_core.utils.aiter import aclosing, atee, py_anext
9898
from langchain_core.utils.iter import safetee
99-
from langchain_core.utils.pydantic import create_model_v2
99+
from langchain_core.utils.pydantic import TypeBaseModel, create_model_v2, get_fields
100100

101101
if TYPE_CHECKING:
102102
from langchain_core.callbacks.manager import (
@@ -355,14 +355,14 @@ def OutputType(self) -> type[Output]: # noqa: N802
355355
raise TypeError(msg)
356356

357357
@property
358-
def input_schema(self) -> type[BaseModel]:
358+
def input_schema(self) -> TypeBaseModel:
359359
"""The type of input this `Runnable` accepts specified as a Pydantic model."""
360360
return self.get_input_schema()
361361

362362
def get_input_schema(
363363
self,
364364
config: RunnableConfig | None = None, # noqa: ARG002
365-
) -> type[BaseModel]:
365+
) -> TypeBaseModel:
366366
"""Get a Pydantic model that can be used to validate input to the `Runnable`.
367367
368368
`Runnable` objects that leverage the `configurable_fields` and
@@ -427,10 +427,13 @@ def add_one(x: int) -> int:
427427
!!! version-added "Added in version 0.3.0"
428428
429429
"""
430-
return self.get_input_schema(config).model_json_schema()
430+
schema = self.get_input_schema(config)
431+
if issubclass(schema, BaseModel):
432+
return schema.model_json_schema()
433+
return schema.schema()
431434

432435
@property
433-
def output_schema(self) -> type[BaseModel]:
436+
def output_schema(self) -> TypeBaseModel:
434437
"""Output schema.
435438
436439
The type of output this `Runnable` produces specified as a Pydantic model.
@@ -440,7 +443,7 @@ def output_schema(self) -> type[BaseModel]:
440443
def get_output_schema(
441444
self,
442445
config: RunnableConfig | None = None, # noqa: ARG002
443-
) -> type[BaseModel]:
446+
) -> TypeBaseModel:
444447
"""Get a Pydantic model that can be used to validate output to the `Runnable`.
445448
446449
`Runnable` objects that leverage the `configurable_fields` and
@@ -505,7 +508,10 @@ def add_one(x: int) -> int:
505508
!!! version-added "Added in version 0.3.0"
506509
507510
"""
508-
return self.get_output_schema(config).model_json_schema()
511+
schema = self.get_output_schema(config)
512+
if issubclass(schema, BaseModel):
513+
return schema.model_json_schema()
514+
return schema.schema()
509515

510516
@property
511517
def config_specs(self) -> list[ConfigurableFieldSpec]:
@@ -2671,7 +2677,7 @@ def configurable_alternatives(
26712677

26722678
def _seq_input_schema(
26732679
steps: list[Runnable[Any, Any]], config: RunnableConfig | None
2674-
) -> type[BaseModel]:
2680+
) -> TypeBaseModel:
26752681
# Import locally to prevent circular import
26762682
from langchain_core.runnables.passthrough import ( # noqa: PLC0415
26772683
RunnableAssign,
@@ -2689,7 +2695,7 @@ def _seq_input_schema(
26892695
"RunnableSequenceInput",
26902696
field_definitions={
26912697
k: (v.annotation, v.default)
2692-
for k, v in next_input_schema.model_fields.items()
2698+
for k, v in get_fields(next_input_schema).items()
26932699
if k not in first.mapper.steps__
26942700
},
26952701
)
@@ -2701,7 +2707,7 @@ def _seq_input_schema(
27012707

27022708
def _seq_output_schema(
27032709
steps: list[Runnable[Any, Any]], config: RunnableConfig | None
2704-
) -> type[BaseModel]:
2710+
) -> TypeBaseModel:
27052711
# Import locally to prevent circular import
27062712
from langchain_core.runnables.passthrough import ( # noqa: PLC0415
27072713
RunnableAssign,
@@ -2721,7 +2727,7 @@ def _seq_output_schema(
27212727
field_definitions={
27222728
**{
27232729
k: (v.annotation, v.default)
2724-
for k, v in prev_output_schema.model_fields.items()
2730+
for k, v in get_fields(prev_output_schema).items()
27252731
},
27262732
**{
27272733
k: (v.annotation, v.default)
@@ -2738,11 +2744,11 @@ def _seq_output_schema(
27382744
"RunnableSequenceOutput",
27392745
field_definitions={
27402746
k: (v.annotation, v.default)
2741-
for k, v in prev_output_schema.model_fields.items()
2747+
for k, v in get_fields(prev_output_schema).items()
27422748
if k in last.keys
27432749
},
27442750
)
2745-
field = prev_output_schema.model_fields[last.keys]
2751+
field = get_fields(prev_output_schema)[last.keys]
27462752
return create_model_v2(
27472753
"RunnableSequenceOutput", root=(field.annotation, field.default)
27482754
)
@@ -2924,7 +2930,7 @@ def OutputType(self) -> type[Output]:
29242930
return self.last.OutputType
29252931

29262932
@override
2927-
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
2933+
def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
29282934
"""Get the input schema of the `Runnable`.
29292935
29302936
Args:
@@ -2937,9 +2943,7 @@ def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseMod
29372943
return _seq_input_schema(self.steps, config)
29382944

29392945
@override
2940-
def get_output_schema(
2941-
self, config: RunnableConfig | None = None
2942-
) -> type[BaseModel]:
2946+
def get_output_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
29432947
"""Get the output schema of the `Runnable`.
29442948
29452949
Args:
@@ -3653,7 +3657,7 @@ def InputType(self) -> Any:
36533657
return Any
36543658

36553659
@override
3656-
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
3660+
def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
36573661
"""Get the input schema of the `Runnable`.
36583662
36593663
Args:
@@ -3664,8 +3668,7 @@ def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseMod
36643668
36653669
"""
36663670
if all(
3667-
s.get_input_schema(config).model_json_schema().get("type", "object")
3668-
== "object"
3671+
s.get_input_jsonschema(config).get("type", "object") == "object"
36693672
for s in self.steps__.values()
36703673
):
36713674
# This is correct, but pydantic typings/mypy don't think so.
@@ -3674,7 +3677,7 @@ def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseMod
36743677
field_definitions={
36753678
k: (v.annotation, v.default)
36763679
for step in self.steps__.values()
3677-
for k, v in step.get_input_schema(config).model_fields.items()
3680+
for k, v in get_fields(step.get_input_schema(config)).items()
36783681
if k != "__root__"
36793682
},
36803683
)
@@ -4460,7 +4463,7 @@ def InputType(self) -> Any:
44604463
return Any
44614464

44624465
@override
4463-
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
4466+
def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
44644467
"""The Pydantic schema for the input to this `Runnable`.
44654468
44664469
Args:
@@ -5437,15 +5440,13 @@ def OutputType(self) -> type[Output]:
54375440
)
54385441

54395442
@override
5440-
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
5443+
def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
54415444
if self.custom_input_type is not None:
54425445
return super().get_input_schema(config)
54435446
return self.bound.get_input_schema(merge_configs(self.config, config))
54445447

54455448
@override
5446-
def get_output_schema(
5447-
self, config: RunnableConfig | None = None
5448-
) -> type[BaseModel]:
5449+
def get_output_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
54495450
if self.custom_output_type is not None:
54505451
return super().get_output_schema(config)
54515452
return self.bound.get_output_schema(merge_configs(self.config, config))

libs/core/langchain_core/runnables/branch.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
cast,
1414
)
1515

16-
from pydantic import BaseModel, ConfigDict
16+
from pydantic import ConfigDict
1717
from typing_extensions import override
1818

1919
from langchain_core.runnables.base import (
@@ -35,6 +35,7 @@
3535
Output,
3636
get_unique_config_specs,
3737
)
38+
from langchain_core.utils.pydantic import TypeBaseModel
3839

3940

4041
class RunnableBranch(RunnableSerializable[Input, Output]):
@@ -154,18 +155,15 @@ def get_lc_namespace(cls) -> list[str]:
154155
return ["langchain", "schema", "runnable"]
155156

156157
@override
157-
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
158+
def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
158159
runnables = (
159160
[self.default]
160161
+ [r for _, r in self.branches]
161162
+ [r for r, _ in self.branches]
162163
)
163164

164165
for runnable in runnables:
165-
if (
166-
runnable.get_input_schema(config).model_json_schema().get("type")
167-
is not None
168-
):
166+
if runnable.get_input_jsonschema(config).get("type") is not None:
169167
return runnable.get_input_schema(config)
170168

171169
return super().get_input_schema(config)

libs/core/langchain_core/runnables/configurable.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
)
2020
from weakref import WeakValueDictionary
2121

22-
from pydantic import BaseModel, ConfigDict
22+
from pydantic import ConfigDict
2323
from typing_extensions import override
2424

2525
from langchain_core.runnables.base import Runnable, RunnableSerializable
@@ -41,6 +41,7 @@
4141
gather_with_concurrency,
4242
get_unique_config_specs,
4343
)
44+
from langchain_core.utils.pydantic import TypeBaseModel
4445

4546
if TYPE_CHECKING:
4647
from langchain_core.runnables.graph import Graph
@@ -90,14 +91,12 @@ def OutputType(self) -> type[Output]:
9091
return self.default.OutputType
9192

9293
@override
93-
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
94+
def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
9495
runnable, config = self.prepare(config)
9596
return runnable.get_input_schema(config)
9697

9798
@override
98-
def get_output_schema(
99-
self, config: RunnableConfig | None = None
100-
) -> type[BaseModel]:
99+
def get_output_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
101100
runnable, config = self.prepare(config)
102101
return runnable.get_output_schema(config)
103102

libs/core/langchain_core/runnables/fallbacks.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from functools import wraps
88
from typing import TYPE_CHECKING, Any, cast
99

10-
from pydantic import BaseModel, ConfigDict
10+
from pydantic import ConfigDict
1111
from typing_extensions import override
1212

1313
from langchain_core.callbacks.manager import AsyncCallbackManager, CallbackManager
@@ -29,6 +29,7 @@
2929
get_unique_config_specs,
3030
)
3131
from langchain_core.utils.aiter import py_anext
32+
from langchain_core.utils.pydantic import TypeBaseModel
3233

3334
if TYPE_CHECKING:
3435
from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
@@ -116,13 +117,11 @@ def OutputType(self) -> type[Output]:
116117
return self.runnable.OutputType
117118

118119
@override
119-
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
120+
def get_input_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
120121
return self.runnable.get_input_schema(config)
121122

122123
@override
123-
def get_output_schema(
124-
self, config: RunnableConfig | None = None
125-
) -> type[BaseModel]:
124+
def get_output_schema(self, config: RunnableConfig | None = None) -> TypeBaseModel:
126125
return self.runnable.get_output_schema(config)
127126

128127
@property

libs/core/langchain_core/runnables/graph.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,15 @@
1919

2020
from langchain_core.load.serializable import to_json_not_implemented
2121
from langchain_core.runnables.base import Runnable, RunnableSerializable
22-
from langchain_core.utils.pydantic import _IgnoreUnserializable, is_basemodel_subclass
22+
from langchain_core.utils.pydantic import (
23+
TypeBaseModel,
24+
_IgnoreUnserializable,
25+
is_basemodel_subclass,
26+
)
2327

2428
if TYPE_CHECKING:
2529
from collections.abc import Sequence
2630

27-
from pydantic import BaseModel
28-
2931
from langchain_core.runnables.base import Runnable as RunnableType
3032

3133

@@ -98,7 +100,7 @@ class Node(NamedTuple):
98100
"""The unique identifier of the node."""
99101
name: str
100102
"""The name of the node."""
101-
data: type[BaseModel] | RunnableType | None
103+
data: TypeBaseModel | RunnableType | None
102104
"""The data of the node."""
103105
metadata: dict[str, Any] | None
104106
"""Optional metadata for the node. """
@@ -178,7 +180,7 @@ class MermaidDrawMethod(Enum):
178180

179181
def node_data_str(
180182
id: str,
181-
data: type[BaseModel] | RunnableType | None,
183+
data: TypeBaseModel | RunnableType | None,
182184
) -> str:
183185
"""Convert the data of a node to a string.
184186
@@ -312,7 +314,7 @@ def next_id(self) -> str:
312314

313315
def add_node(
314316
self,
315-
data: type[BaseModel] | RunnableType | None,
317+
data: TypeBaseModel | RunnableType | None,
316318
id: str | None = None,
317319
*,
318320
metadata: dict[str, Any] | None = None,

0 commit comments

Comments
 (0)