118118
119119Other = TypeVar ("Other" )
120120
121+ _RUNNABLE_GENERIC_NUM_ARGS = 2 # Input and Output
122+
121123
122124class Runnable (ABC , Generic [Input , Output ]):
123125 """A unit of work that can be invoked, batched, streamed, transformed and composed.
@@ -309,15 +311,18 @@ def InputType(self) -> type[Input]: # noqa: N802
309311 for base in self .__class__ .mro ():
310312 if hasattr (base , "__pydantic_generic_metadata__" ):
311313 metadata = base .__pydantic_generic_metadata__
312- if "args" in metadata and len (metadata ["args" ]) == 2 :
314+ if (
315+ "args" in metadata
316+ and len (metadata ["args" ]) == _RUNNABLE_GENERIC_NUM_ARGS
317+ ):
313318 return metadata ["args" ][0 ]
314319
315320 # If we didn't find a Pydantic model in the parent classes,
316321 # then loop through __orig_bases__. This corresponds to
317322 # Runnables that are not pydantic models.
318323 for cls in self .__class__ .__orig_bases__ : # type: ignore[attr-defined]
319324 type_args = get_args (cls )
320- if type_args and len (type_args ) == 2 :
325+ if type_args and len (type_args ) == _RUNNABLE_GENERIC_NUM_ARGS :
321326 return type_args [0 ]
322327
323328 msg = (
@@ -340,12 +345,15 @@ def OutputType(self) -> type[Output]: # noqa: N802
340345 for base in self .__class__ .mro ():
341346 if hasattr (base , "__pydantic_generic_metadata__" ):
342347 metadata = base .__pydantic_generic_metadata__
343- if "args" in metadata and len (metadata ["args" ]) == 2 :
348+ if (
349+ "args" in metadata
350+ and len (metadata ["args" ]) == _RUNNABLE_GENERIC_NUM_ARGS
351+ ):
344352 return metadata ["args" ][1 ]
345353
346354 for cls in self .__class__ .__orig_bases__ : # type: ignore[attr-defined]
347355 type_args = get_args (cls )
348- if type_args and len (type_args ) == 2 :
356+ if type_args and len (type_args ) == _RUNNABLE_GENERIC_NUM_ARGS :
349357 return type_args [1 ]
350358
351359 msg = (
@@ -2750,6 +2758,9 @@ def _seq_output_schema(
27502758 return last .get_output_schema (config )
27512759
27522760
2761+ _RUNNABLE_SEQUENCE_MIN_STEPS = 2
2762+
2763+
27532764class RunnableSequence (RunnableSerializable [Input , Output ]):
27542765 """Sequence of `Runnable` objects, where the output of one is the input of the next.
27552766
@@ -2872,8 +2883,11 @@ def __init__(
28722883 steps_flat .extend (step .steps )
28732884 else :
28742885 steps_flat .append (coerce_to_runnable (step ))
2875- if len (steps_flat ) < 2 :
2876- msg = f"RunnableSequence must have at least 2 steps, got { len (steps_flat )} "
2886+ if len (steps_flat ) < _RUNNABLE_SEQUENCE_MIN_STEPS :
2887+ msg = (
2888+ f"RunnableSequence must have at least { _RUNNABLE_SEQUENCE_MIN_STEPS } "
2889+ f"steps, got { len (steps_flat )} "
2890+ )
28772891 raise ValueError (msg )
28782892 super ().__init__ (
28792893 first = steps_flat [0 ],
@@ -4477,7 +4491,7 @@ def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseMod
44774491 # on itemgetter objects, so we have to parse the repr
44784492 items = str (func ).replace ("operator.itemgetter(" , "" )[:- 1 ].split (", " )
44794493 if all (
4480- item [0 ] == "'" and item [- 1 ] == "'" and len ( item ) > 2 for item in items
4494+ item [0 ] == "'" and item [- 1 ] == "'" and item != "''" for item in items
44814495 ):
44824496 fields = {item [1 :- 1 ]: (Any , ...) for item in items }
44834497 # It's a dict, lol
0 commit comments