11"""
22Facilities for defining cocoindex operations.
33"""
4+ import asyncio
45import dataclasses
56import inspect
67
@@ -78,6 +79,7 @@ def _register_op_factory(
7879 category : OpCategory ,
7980 expected_args : list [tuple [str , inspect .Parameter ]],
8081 expected_return ,
82+ is_async : bool ,
8183 executor_cls : type ,
8284 spec_cls : type ,
8385 op_args : OpArgs ,
@@ -168,6 +170,19 @@ def __call__(self, *args, **kwargs):
168170 converted_args = (converter (arg ) for converter , arg in zip (self ._args_converters , args ))
169171 converted_kwargs = {arg_name : self ._kwargs_converters [arg_name ](arg )
170172 for arg_name , arg in kwargs .items ()}
173+ if is_async :
174+ async def _inner ():
175+ if op_args .gpu :
176+ await asyncio .to_thread (_gpu_dispatch_lock .acquire )
177+ try :
178+ output = await super (_WrappedClass , self ).__call__ (
179+ * converted_args , ** converted_kwargs )
180+ finally :
181+ if op_args .gpu :
182+ _gpu_dispatch_lock .release ()
183+ return to_engine_value (output )
184+ return _inner ()
185+
171186 if op_args .gpu :
172187 # For GPU executions, data-level parallelism is applied, so we don't want to
173188 # execute different tasks in parallel.
@@ -189,7 +204,8 @@ def __call__(self, *args, **kwargs):
189204 if category == OpCategory .FUNCTION :
190205 _engine .register_function_factory (
191206 spec_cls .__name__ ,
192- _FunctionExecutorFactory (spec_cls , _WrappedClass ))
207+ _FunctionExecutorFactory (spec_cls , _WrappedClass ),
208+ is_async )
193209 else :
194210 raise ValueError (f"Unsupported executor type { category } " )
195211
@@ -214,6 +230,7 @@ def _inner(cls: type[Executor]) -> type:
214230 category = spec_cls ._op_category ,
215231 expected_args = list (sig .parameters .items ())[1 :], # First argument is `self`
216232 expected_return = sig .return_annotation ,
233+ is_async = inspect .iscoroutinefunction (cls .__call__ ),
217234 executor_cls = cls ,
218235 spec_cls = spec_cls ,
219236 op_args = op_args )
@@ -249,6 +266,7 @@ class _Spec(FunctionSpec):
249266 category = OpCategory .FUNCTION ,
250267 expected_args = list (sig .parameters .items ()),
251268 expected_return = sig .return_annotation ,
269+ is_async = inspect .iscoroutinefunction (fn ),
252270 executor_cls = _Executor ,
253271 spec_cls = _Spec ,
254272 op_args = op_args )
0 commit comments