@@ -121,6 +121,7 @@ struct PyBatchedFunctionExecutor {
121121
122122 enable_cache : bool ,
123123 behavior_version : Option < u32 > ,
124+ batching_options : batching:: BatchingOptions ,
124125}
125126
126127#[ async_trait]
@@ -168,11 +169,13 @@ impl BatchedFunctionExecutor for PyBatchedFunctionExecutor {
168169 fn behavior_version ( & self ) -> Option < u32 > {
169170 self . behavior_version
170171 }
172+ fn batching_options ( & self ) -> batching:: BatchingOptions {
173+ self . batching_options . clone ( )
174+ }
171175}
172176
173177pub ( crate ) struct PyFunctionFactory {
174178 pub py_function_factory : Py < PyAny > ,
175- pub batching : bool ,
176179}
177180
178181#[ async_trait]
@@ -237,7 +240,7 @@ impl interface::SimpleFunctionFactory for PyFunctionFactory {
237240 . as_ref ( )
238241 . ok_or_else ( || anyhow ! ( "Python execution context is missing" ) ) ?
239242 . clone ( ) ;
240- let ( prepare_fut, enable_cache, behavior_version) =
243+ let ( prepare_fut, enable_cache, behavior_version, batching_options ) =
241244 Python :: with_gil ( |py| -> anyhow:: Result < _ > {
242245 let prepare_coro = executor
243246 . call_method ( py, "prepare" , ( ) , None )
@@ -257,31 +260,45 @@ impl interface::SimpleFunctionFactory for PyFunctionFactory {
257260 . call_method ( py, "behavior_version" , ( ) , None )
258261 . to_result_with_py_trace ( py) ?
259262 . extract :: < Option < u32 > > ( py) ?;
260- Ok ( ( prepare_fut, enable_cache, behavior_version) )
263+ let batching_options = executor
264+ . call_method ( py, "batching_options" , ( ) , None )
265+ . to_result_with_py_trace ( py) ?
266+ . extract :: < crate :: py:: Pythonized < Option < batching:: BatchingOptions > > > (
267+ py,
268+ ) ?
269+ . into_inner ( ) ;
270+ Ok ( (
271+ prepare_fut,
272+ enable_cache,
273+ behavior_version,
274+ batching_options,
275+ ) )
261276 } ) ?;
262277 prepare_fut. await ?;
263- let executor: Box < dyn interface:: SimpleFunctionExecutor > = if self . batching {
264- Box :: new (
265- PyBatchedFunctionExecutor {
278+ let executor: Box < dyn interface:: SimpleFunctionExecutor > =
279+ if let Some ( batching_options) = batching_options {
280+ Box :: new (
281+ PyBatchedFunctionExecutor {
282+ py_function_executor : executor,
283+ py_exec_ctx,
284+ result_type,
285+ enable_cache,
286+ behavior_version,
287+ batching_options,
288+ }
289+ . into_fn_executor ( ) ,
290+ )
291+ } else {
292+ Box :: new ( Arc :: new ( PyFunctionExecutor {
266293 py_function_executor : executor,
267294 py_exec_ctx,
295+ num_positional_args,
296+ kw_args_names,
268297 result_type,
269298 enable_cache,
270299 behavior_version,
271- }
272- . into_fn_executor ( ) ,
273- )
274- } else {
275- Box :: new ( Arc :: new ( PyFunctionExecutor {
276- py_function_executor : executor,
277- py_exec_ctx,
278- num_positional_args,
279- kw_args_names,
280- result_type,
281- enable_cache,
282- behavior_version,
283- } ) )
284- } ;
300+ } ) )
301+ } ;
285302 Ok ( executor)
286303 }
287304 } ;
0 commit comments