3232from . import index
3333from . import op
3434from . import setting
35- from .convert import dump_engine_object , encode_engine_value , make_engine_value_decoder
35+ from .convert import (
36+ dump_engine_object ,
37+ make_engine_value_decoder ,
38+ make_engine_value_encoder ,
39+ )
3640from .op import FunctionSpec
3741from .runtime import execution_context
3842from .setup import SetupChangeBundle
@@ -974,33 +978,60 @@ class TransformFlowInfo(NamedTuple):
974978 result_decoder : Callable [[Any ], T ]
975979
976980
class FlowArgInfo(NamedTuple):
    """Per-parameter metadata collected for a transform flow function."""

    # Parameter name as declared in the flow function's signature.
    name: str
    # Value type taken from the parameter's annotation.
    type_hint: Any
    # Callable applied to a Python value before it is handed to the engine.
    encoder: Callable[[Any], Any]
985+
986+
977987class TransformFlow (Generic [T ]):
978988 """
979989 A transient transformation flow that transforms in-memory data.
980990 """
981991
982992 _flow_fn : Callable [..., DataSlice [T ]]
983993 _flow_name : str
984- _flow_arg_types : list [Any ]
985- _param_names : list [str ]
994+ _args_info : list [FlowArgInfo ]
986995
987996 _lazy_lock : asyncio .Lock
988997 _lazy_flow_info : TransformFlowInfo | None = None
989998
def __init__(
    self,
    flow_fn: Callable[..., DataSlice[T]],
    /,
    name: str | None = None,
):
    """
    Create a transform flow around `flow_fn`.

    The signature of `flow_fn` is inspected once here, and one
    `FlowArgInfo` (name, value type, encoder) is recorded per parameter
    in declaration order.

    Args:
        flow_fn: The flow function (positional-only).
        name: Optional flow name, passed to the name builder together
            with the `_transform_flow_` prefix.

    Raises:
        ValueError: If a parameter is neither positional-or-keyword nor
            keyword-only, or if no value type could be obtained from its
            annotation.
    """
    self._flow_fn = flow_fn
    self._flow_name = _transform_flow_name_builder.build_name(
        name, prefix="_transform_flow_"
    )
    self._lazy_lock = asyncio.Lock()

    # Only these parameter kinds can be supplied by name.
    name_passable_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )

    collected = []
    for arg_name, parameter in inspect.signature(flow_fn).parameters.items():
        if parameter.kind not in name_passable_kinds:
            raise ValueError(
                f"Parameter `{arg_name}` is not a parameter can be passed by name"
            )

        value_type: type | None = _get_data_slice_annotation_type(
            parameter.annotation
        )
        if value_type is None:
            raise ValueError(
                f"Parameter `{arg_name}` for {flow_fn} has no value type annotation. "
                "Please use `cocoindex.DataSlice[T]` where T is the type of the value."
            )

        # Build the encoder once per parameter so evaluation can reuse it.
        value_encoder = make_engine_value_encoder(analyze_type_info(value_type))
        collected.append(
            FlowArgInfo(name=arg_name, type_hint=value_type, encoder=value_encoder)
        )

    self._args_info = collected
1034+
def __call__(self, *args: Any, **kwargs: Any) -> DataSlice[T]:
    """Call the wrapped flow function directly, forwarding all arguments unchanged."""
    flow_fn = self._flow_fn
    return flow_fn(*args, **kwargs)
10061037
@@ -1020,31 +1051,15 @@ async def _flow_info_async(self) -> TransformFlowInfo:
10201051
10211052 async def _build_flow_info_async (self ) -> TransformFlowInfo :
10221053 flow_builder_state = _FlowBuilderState (self ._flow_name )
1023- sig = inspect .signature (self ._flow_fn )
1024- if len (sig .parameters ) != len (self ._flow_arg_types ):
1025- raise ValueError (
1026- f"Number of parameters in the flow function ({ len (sig .parameters )} ) "
1027- f"does not match the number of argument types ({ len (self ._flow_arg_types )} )"
1028- )
1029-
10301054 kwargs : dict [str , DataSlice [T ]] = {}
1031- for (param_name , param ), param_type in zip (
1032- sig .parameters .items (), self ._flow_arg_types
1033- ):
1034- if param .kind not in (
1035- inspect .Parameter .POSITIONAL_OR_KEYWORD ,
1036- inspect .Parameter .KEYWORD_ONLY ,
1037- ):
1038- raise ValueError (
1039- f"Parameter `{ param_name } ` is not a parameter can be passed by name"
1040- )
1041- encoded_type = encode_enriched_type (param_type )
1055+ for arg_info in self ._args_info :
1056+ encoded_type = encode_enriched_type (arg_info .type_hint )
10421057 if encoded_type is None :
1043- raise ValueError (f"Parameter `{ param_name } ` has no type annotation" )
1058+ raise ValueError (f"Parameter `{ arg_info . name } ` has no type annotation" )
10441059 engine_ds = flow_builder_state .engine_flow_builder .add_direct_input (
1045- param_name , encoded_type
1060+ arg_info . name , encoded_type
10461061 )
1047- kwargs [param_name ] = DataSlice (
1062+ kwargs [arg_info . name ] = DataSlice (
10481063 _DataSliceState (flow_builder_state , engine_ds )
10491064 )
10501065
@@ -1057,13 +1072,12 @@ async def _build_flow_info_async(self) -> TransformFlowInfo:
10571072 execution_context .event_loop
10581073 )
10591074 )
1060- self ._param_names = list (sig .parameters .keys ())
10611075
10621076 engine_return_type = (
10631077 _data_slice_state (output ).engine_data_slice .data_type ().schema ()
10641078 )
10651079 python_return_type : type [T ] | None = _get_data_slice_annotation_type (
1066- sig .return_annotation
1080+ inspect . signature ( self . _flow_fn ) .return_annotation
10671081 )
10681082 result_decoder = make_engine_value_decoder (
10691083 [], engine_return_type ["type" ], analyze_type_info (python_return_type )
async def eval_async(self, *args: Any, **kwargs: Any) -> T:
    """
    Evaluate the transform flow on the given input values.

    Inputs are matched against the flow function's parameters in
    declaration order: the i-th positional argument feeds the i-th
    parameter, and each remaining parameter is looked up by name in
    `kwargs`. Every value goes through that parameter's encoder before
    the engine flow is evaluated, and the engine result goes through the
    flow's result decoder.

    Raises:
        ValueError: If a parameter is supplied neither positionally nor
            by keyword.
    """
    flow_info = await self._flow_info_async()
    params = []
    for i, arg_info in enumerate(self._args_info):
        if i < len(args):
            value = args[i]
        # Look the parameter up by its declared name. The previous code
        # used a stale `arg` variable here, which was unbound (or held the
        # prior argument's value) once the loop variable became `arg_info`.
        elif arg_info.name in kwargs:
            value = kwargs[arg_info.name]
        else:
            raise ValueError(f"Parameter {arg_info.name} is not provided")
        params.append(arg_info.encoder(value))
    # NOTE(review): surplus positional arguments and unknown keyword
    # arguments are silently ignored, as before; confirm this is intended.
    engine_result = await flow_info.engine_flow.evaluate_async(params)
    return flow_info.result_decoder(engine_result)
11121122
@@ -1117,27 +1127,7 @@ def transform_flow() -> Callable[[Callable[..., DataSlice[T]]], TransformFlow[T]
11171127 """
11181128
def _transform_flow_wrapper(fn: Callable[..., DataSlice[T]]) -> TransformFlow[T]:
    """Wrap `fn` in a `TransformFlow` and copy `fn`'s metadata onto the wrapper."""
    # Signature validation and per-argument encoder setup happen inside
    # the TransformFlow constructor, so only the function is passed.
    wrapped_flow = TransformFlow(fn)
    functools.update_wrapper(wrapped_flow, fn)
    return wrapped_flow
11431133
0 commit comments