Skip to content

Commit 832ab93

Browse files
authored
fix: make gpu-enabled simple function pickleable (#916)
fix: make gpu-enabled simple function picklable
1 parent ecc9377 commit 832ab93

File tree

2 files changed

+24
-30
lines changed

2 files changed

+24
-30
lines changed

python/cocoindex/flow.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -105,18 +105,26 @@ def _spec_kind(spec: Any) -> str:
105105

106106
def _transform_helper(
107107
flow_builder_state: _FlowBuilderState,
108-
fn_spec: FunctionSpec,
108+
fn_spec: FunctionSpec | Callable[..., Any],
109109
transform_args: list[tuple[Any, str | None]],
110110
name: str | None = None,
111111
) -> DataSlice[Any]:
112-
if not isinstance(fn_spec, FunctionSpec):
112+
if isinstance(fn_spec, FunctionSpec):
113+
kind = _spec_kind(fn_spec)
114+
spec = fn_spec
115+
elif callable(fn_spec) and (
116+
op_kind := getattr(fn_spec, "__cocoindex_op_kind__", None)
117+
):
118+
kind = op_kind
119+
spec = op.EmptyFunctionSpec()
120+
else:
113121
raise ValueError("transform() can only be called on a CocoIndex function")
114122

115123
return _create_data_slice(
116124
flow_builder_state,
117125
lambda target_scope, name: flow_builder_state.engine_flow_builder.transform(
118-
_spec_kind(fn_spec),
119-
dump_engine_object(fn_spec),
126+
kind,
127+
dump_engine_object(spec),
120128
transform_args,
121129
target_scope,
122130
flow_builder_state.field_name_builder.build_name(
@@ -245,7 +253,7 @@ def for_each(
245253
f(scope)
246254

247255
def transform(
248-
self, fn_spec: op.FunctionSpec, *args: Any, **kwargs: Any
256+
self, fn_spec: op.FunctionSpec | Callable[..., Any], *args: Any, **kwargs: Any
249257
) -> DataSlice[Any]:
250258
"""
251259
Apply a function to the data slice.
@@ -513,7 +521,7 @@ def add_source(
513521
)
514522

515523
def transform(
516-
self, fn_spec: FunctionSpec, *args: Any, **kwargs: Any
524+
self, fn_spec: FunctionSpec | Callable[..., Any], *args: Any, **kwargs: Any
517525
) -> DataSlice[Any]:
518526
"""
519527
Apply a function to inputs, returning a DataSlice.

python/cocoindex/op.py

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -388,53 +388,39 @@ def _inner(cls: type[Executor]) -> type:
388388
return _inner
389389

390390

391-
class _EmptyFunctionSpec(FunctionSpec):
391+
class EmptyFunctionSpec(FunctionSpec):
392392
pass
393393

394394

395395
class _SimpleFunctionExecutor:
396-
spec: Any
396+
spec: Callable[..., Any]
397397

398398
def prepare(self) -> None:
399-
self.__call__ = self.spec.__call__
399+
self.__call__ = staticmethod(self.spec)
400400

401401

402-
def function(**args: Any) -> Callable[[Callable[..., Any]], FunctionSpec]:
402+
def function(**args: Any) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
403403
"""
404404
Decorate a function to provide a function for an op.
405405
"""
406406
op_args = OpArgs(**args)
407407

408-
def _inner(fn: Callable[..., Any]) -> FunctionSpec:
408+
def _inner(fn: Callable[..., Any]) -> Callable[..., Any]:
409409
# Convert snake case to camel case.
410-
op_name = "".join(word.capitalize() for word in fn.__name__.split("_"))
410+
op_kind = "".join(word.capitalize() for word in fn.__name__.split("_"))
411411
sig = inspect.signature(fn)
412-
full_name = f"{fn.__module__}.{fn.__qualname__}"
413-
414-
# An object that is both callable and can act as a FunctionSpec.
415-
class _CallableSpec(_EmptyFunctionSpec):
416-
__call__ = staticmethod(fn)
417-
418-
def __reduce__(self) -> str | tuple[Any, ...]:
419-
return full_name
420-
421-
_CallableSpec.__name__ = op_name
422-
_CallableSpec.__doc__ = fn.__doc__
423-
_CallableSpec.__qualname__ = fn.__qualname__
424-
_CallableSpec.__module__ = fn.__module__
425-
callable_spec = _CallableSpec()
426-
412+
fn.__cocoindex_op_kind__ = op_kind # type: ignore
427413
_register_op_factory(
428414
category=OpCategory.FUNCTION,
429415
expected_args=list(sig.parameters.items()),
430416
expected_return=sig.return_annotation,
431417
executor_factory=_SimpleFunctionExecutor,
432-
spec_loader=lambda: callable_spec,
433-
op_kind=op_name,
418+
spec_loader=lambda: fn,
419+
op_kind=op_kind,
434420
op_args=op_args,
435421
)
436422

437-
return callable_spec
423+
return fn
438424

439425
return _inner
440426

0 commit comments

Comments
 (0)