9696)
9797from langchain_core .utils .aiter import aclosing , atee , py_anext
9898from langchain_core .utils .iter import safetee
99- from langchain_core .utils .pydantic import create_model_v2
99+ from langchain_core .utils .pydantic import TypeBaseModel , create_model_v2 , get_fields
100100
101101if TYPE_CHECKING :
102102 from langchain_core .callbacks .manager import (
@@ -355,14 +355,14 @@ def OutputType(self) -> type[Output]: # noqa: N802
355355 raise TypeError (msg )
356356
357357 @property
358- def input_schema (self ) -> type [ BaseModel ] :
358+ def input_schema (self ) -> TypeBaseModel :
359359 """The type of input this `Runnable` accepts specified as a Pydantic model."""
360360 return self .get_input_schema ()
361361
362362 def get_input_schema (
363363 self ,
364364 config : RunnableConfig | None = None , # noqa: ARG002
365- ) -> type [ BaseModel ] :
365+ ) -> TypeBaseModel :
366366 """Get a Pydantic model that can be used to validate input to the `Runnable`.
367367
368368 `Runnable` objects that leverage the `configurable_fields` and
@@ -427,10 +427,13 @@ def add_one(x: int) -> int:
427427 !!! version-added "Added in version 0.3.0"
428428
429429 """
430- return self .get_input_schema (config ).model_json_schema ()
430+ schema = self .get_input_schema (config )
431+ if issubclass (schema , BaseModel ):
432+ return schema .model_json_schema ()
433+ return schema .schema ()
431434
432435 @property
433- def output_schema (self ) -> type [ BaseModel ] :
436+ def output_schema (self ) -> TypeBaseModel :
434437 """Output schema.
435438
436439 The type of output this `Runnable` produces specified as a Pydantic model.
@@ -440,7 +443,7 @@ def output_schema(self) -> type[BaseModel]:
440443 def get_output_schema (
441444 self ,
442445 config : RunnableConfig | None = None , # noqa: ARG002
443- ) -> type [ BaseModel ] :
446+ ) -> TypeBaseModel :
444447 """Get a Pydantic model that can be used to validate output to the `Runnable`.
445448
446449 `Runnable` objects that leverage the `configurable_fields` and
@@ -505,7 +508,10 @@ def add_one(x: int) -> int:
505508 !!! version-added "Added in version 0.3.0"
506509
507510 """
508- return self .get_output_schema (config ).model_json_schema ()
511+ schema = self .get_output_schema (config )
512+ if issubclass (schema , BaseModel ):
513+ return schema .model_json_schema ()
514+ return schema .schema ()
509515
510516 @property
511517 def config_specs (self ) -> list [ConfigurableFieldSpec ]:
@@ -2671,7 +2677,7 @@ def configurable_alternatives(
26712677
26722678def _seq_input_schema (
26732679 steps : list [Runnable [Any , Any ]], config : RunnableConfig | None
2674- ) -> type [ BaseModel ] :
2680+ ) -> TypeBaseModel :
26752681 # Import locally to prevent circular import
26762682 from langchain_core .runnables .passthrough import ( # noqa: PLC0415
26772683 RunnableAssign ,
@@ -2689,7 +2695,7 @@ def _seq_input_schema(
26892695 "RunnableSequenceInput" ,
26902696 field_definitions = {
26912697 k : (v .annotation , v .default )
2692- for k , v in next_input_schema . model_fields .items ()
2698+ for k , v in get_fields ( next_input_schema ) .items ()
26932699 if k not in first .mapper .steps__
26942700 },
26952701 )
@@ -2701,7 +2707,7 @@ def _seq_input_schema(
27012707
27022708def _seq_output_schema (
27032709 steps : list [Runnable [Any , Any ]], config : RunnableConfig | None
2704- ) -> type [ BaseModel ] :
2710+ ) -> TypeBaseModel :
27052711 # Import locally to prevent circular import
27062712 from langchain_core .runnables .passthrough import ( # noqa: PLC0415
27072713 RunnableAssign ,
@@ -2721,7 +2727,7 @@ def _seq_output_schema(
27212727 field_definitions = {
27222728 ** {
27232729 k : (v .annotation , v .default )
2724- for k , v in prev_output_schema . model_fields .items ()
2730+ for k , v in get_fields ( prev_output_schema ) .items ()
27252731 },
27262732 ** {
27272733 k : (v .annotation , v .default )
@@ -2738,11 +2744,11 @@ def _seq_output_schema(
27382744 "RunnableSequenceOutput" ,
27392745 field_definitions = {
27402746 k : (v .annotation , v .default )
2741- for k , v in prev_output_schema . model_fields .items ()
2747+ for k , v in get_fields ( prev_output_schema ) .items ()
27422748 if k in last .keys
27432749 },
27442750 )
2745- field = prev_output_schema . model_fields [last .keys ]
2751+ field = get_fields ( prev_output_schema ) [last .keys ]
27462752 return create_model_v2 (
27472753 "RunnableSequenceOutput" , root = (field .annotation , field .default )
27482754 )
@@ -2924,7 +2930,7 @@ def OutputType(self) -> type[Output]:
29242930 return self .last .OutputType
29252931
29262932 @override
2927- def get_input_schema (self , config : RunnableConfig | None = None ) -> type [ BaseModel ] :
2933+ def get_input_schema (self , config : RunnableConfig | None = None ) -> TypeBaseModel :
29282934 """Get the input schema of the `Runnable`.
29292935
29302936 Args:
@@ -2937,9 +2943,7 @@ def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseMod
29372943 return _seq_input_schema (self .steps , config )
29382944
29392945 @override
2940- def get_output_schema (
2941- self , config : RunnableConfig | None = None
2942- ) -> type [BaseModel ]:
2946+ def get_output_schema (self , config : RunnableConfig | None = None ) -> TypeBaseModel :
29432947 """Get the output schema of the `Runnable`.
29442948
29452949 Args:
@@ -3653,7 +3657,7 @@ def InputType(self) -> Any:
36533657 return Any
36543658
36553659 @override
3656- def get_input_schema (self , config : RunnableConfig | None = None ) -> type [ BaseModel ] :
3660+ def get_input_schema (self , config : RunnableConfig | None = None ) -> TypeBaseModel :
36573661 """Get the input schema of the `Runnable`.
36583662
36593663 Args:
@@ -3664,8 +3668,7 @@ def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseMod
36643668
36653669 """
36663670 if all (
3667- s .get_input_schema (config ).model_json_schema ().get ("type" , "object" )
3668- == "object"
3671+ s .get_input_jsonschema (config ).get ("type" , "object" ) == "object"
36693672 for s in self .steps__ .values ()
36703673 ):
36713674 # This is correct, but pydantic typings/mypy don't think so.
@@ -3674,7 +3677,7 @@ def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseMod
36743677 field_definitions = {
36753678 k : (v .annotation , v .default )
36763679 for step in self .steps__ .values ()
3677- for k , v in step .get_input_schema (config ). model_fields .items ()
3680+ for k , v in get_fields ( step .get_input_schema (config )) .items ()
36783681 if k != "__root__"
36793682 },
36803683 )
@@ -4460,7 +4463,7 @@ def InputType(self) -> Any:
44604463 return Any
44614464
44624465 @override
4463- def get_input_schema (self , config : RunnableConfig | None = None ) -> type [ BaseModel ] :
4466+ def get_input_schema (self , config : RunnableConfig | None = None ) -> TypeBaseModel :
44644467 """The Pydantic schema for the input to this `Runnable`.
44654468
44664469 Args:
@@ -5437,15 +5440,13 @@ def OutputType(self) -> type[Output]:
54375440 )
54385441
54395442 @override
5440- def get_input_schema (self , config : RunnableConfig | None = None ) -> type [ BaseModel ] :
5443+ def get_input_schema (self , config : RunnableConfig | None = None ) -> TypeBaseModel :
54415444 if self .custom_input_type is not None :
54425445 return super ().get_input_schema (config )
54435446 return self .bound .get_input_schema (merge_configs (self .config , config ))
54445447
54455448 @override
5446- def get_output_schema (
5447- self , config : RunnableConfig | None = None
5448- ) -> type [BaseModel ]:
5449+ def get_output_schema (self , config : RunnableConfig | None = None ) -> TypeBaseModel :
54495450 if self .custom_output_type is not None :
54505451 return super ().get_output_schema (config )
54515452 return self .bound .get_output_schema (merge_configs (self .config , config ))
0 commit comments