Skip to content

Commit 74b9345

Browse files
authored
feat(subprocessing): implement subprocessing for GPU workloads (#889)
* refactor: make sure spec classes are picklable * refactor: simplify wrapper class in `_register_op_factory` * feat(subprocessing): implement subprocessing for GPU workloads
1 parent 579aa96 commit 74b9345

File tree

6 files changed

+330
-98
lines changed

6 files changed

+330
-98
lines changed

python/cocoindex/cli.py

Lines changed: 5 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,7 @@
44
import importlib.util
55
import os
66
import signal
7-
import sys
87
import threading
9-
import types
108
from types import FrameType
119
from typing import Any, Iterable
1210

@@ -20,6 +18,8 @@
2018
from . import flow, lib, setting
2119
from .setup import flow_names_with_setup
2220
from .runtime import execution_context
21+
from .subprocess_exec import add_user_app
22+
from .user_app_loader import load_user_app
2323

2424
# Create ServerSettings lazily upon first call, as environment variables may be loaded from files, etc.
2525
COCOINDEX_HOST = "https://cocoindex.io"
@@ -76,50 +76,9 @@ def _get_app_ref_from_specifier(
7676
return app_ref
7777

7878

79-
def _load_user_app(app_target: str) -> types.ModuleType:
80-
"""
81-
Loads the user's application, which can be a file path or an installed module name.
82-
Exits on failure.
83-
"""
84-
if not app_target:
85-
raise click.ClickException("Application target not provided.")
86-
87-
looks_like_path = os.sep in app_target or app_target.lower().endswith(".py")
88-
89-
if looks_like_path:
90-
if not os.path.isfile(app_target):
91-
raise click.ClickException(f"Application file path not found: {app_target}")
92-
app_path = os.path.abspath(app_target)
93-
app_dir = os.path.dirname(app_path)
94-
module_name = os.path.splitext(os.path.basename(app_path))[0]
95-
96-
if app_dir not in sys.path:
97-
sys.path.insert(0, app_dir)
98-
try:
99-
spec = importlib.util.spec_from_file_location(module_name, app_path)
100-
if spec is None:
101-
raise ImportError(f"Could not create spec for file: {app_path}")
102-
module = importlib.util.module_from_spec(spec)
103-
sys.modules[spec.name] = module
104-
if spec.loader is None:
105-
raise ImportError(f"Could not create loader for file: {app_path}")
106-
spec.loader.exec_module(module)
107-
return module
108-
except (ImportError, FileNotFoundError, PermissionError) as e:
109-
raise click.ClickException(f"Failed importing file '{app_path}': {e}")
110-
finally:
111-
if app_dir in sys.path and sys.path[0] == app_dir:
112-
sys.path.pop(0)
113-
114-
# Try as module
115-
try:
116-
return importlib.import_module(app_target)
117-
except ImportError as e:
118-
raise click.ClickException(f"Failed to load module '{app_target}': {e}")
119-
except Exception as e:
120-
raise click.ClickException(
121-
f"Unexpected error importing module '{app_target}': {e}"
122-
)
79+
def _load_user_app(app_target: str) -> None:
80+
load_user_app(app_target)
81+
add_user_app(app_target)
12382

12483

12584
def _initialize_cocoindex_in_process() -> None:

python/cocoindex/functions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ class SentenceTransformerEmbedExecutor:
8989
spec: SentenceTransformerEmbed
9090
_model: Any | None = None
9191

92-
def analyze(self, _text: Any) -> type:
92+
def analyze(self) -> type:
9393
try:
9494
# Only import sentence_transformers locally when it's needed, as its import is very slow.
9595
import sentence_transformers # pylint: disable=import-outside-toplevel
@@ -245,7 +245,7 @@ class ColPaliEmbedImageExecutor:
245245
spec: ColPaliEmbedImage
246246
_model_info: ColPaliModelInfo
247247

248-
def analyze(self, _img_bytes: Any) -> type:
248+
def analyze(self) -> type:
249249
# Get shared model and dimension
250250
self._model_info = _get_colpali_model_and_processor(self.spec.model)
251251

@@ -321,7 +321,7 @@ class ColPaliEmbedQueryExecutor:
321321
spec: ColPaliEmbedQuery
322322
_model_info: ColPaliModelInfo
323323

324-
def analyze(self, _query: Any) -> type:
324+
def analyze(self) -> type:
325325
# Get shared model and dimension
326326
self._model_info = _get_colpali_model_and_processor(self.spec.model)
327327

python/cocoindex/op.py

Lines changed: 72 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,17 @@
1111
Awaitable,
1212
Callable,
1313
Protocol,
14+
ParamSpec,
15+
TypeVar,
16+
Type,
17+
cast,
1418
dataclass_transform,
1519
Annotated,
1620
get_args,
1721
)
1822

1923
from . import _engine # type: ignore
24+
from .subprocess_exec import executor_stub
2025
from .convert import (
2126
make_engine_value_encoder,
2227
make_engine_value_decoder,
@@ -85,11 +90,13 @@ class Executor(Protocol):
8590
op_category: OpCategory
8691

8792

88-
def _load_spec_from_engine(spec_cls: type, spec: dict[str, Any]) -> Any:
93+
def _load_spec_from_engine(
94+
spec_loader: Callable[..., Any], spec: dict[str, Any]
95+
) -> Any:
8996
"""
9097
Load a spec from the engine.
9198
"""
92-
return spec_cls(**spec)
99+
return spec_loader(**spec)
93100

94101

95102
def _get_required_method(cls: type, name: str) -> Callable[..., Any]:
@@ -101,18 +108,18 @@ def _get_required_method(cls: type, name: str) -> Callable[..., Any]:
101108
return method
102109

103110

104-
class _FunctionExecutorFactory:
105-
_spec_cls: type
111+
class _EngineFunctionExecutorFactory:
112+
_spec_loader: Callable[..., Any]
106113
_executor_cls: type
107114

108-
def __init__(self, spec_cls: type, executor_cls: type):
109-
self._spec_cls = spec_cls
115+
def __init__(self, spec_loader: Callable[..., Any], executor_cls: type):
116+
self._spec_loader = spec_loader
110117
self._executor_cls = executor_cls
111118

112119
def __call__(
113120
self, spec: dict[str, Any], *args: Any, **kwargs: Any
114121
) -> tuple[dict[str, Any], Executor]:
115-
spec = _load_spec_from_engine(self._spec_cls, spec)
122+
spec = _load_spec_from_engine(self._spec_loader, spec)
116123
executor = self._executor_cls(spec)
117124
result_type = executor.analyze_schema(*args, **kwargs)
118125
return (result_type, executor)
@@ -166,31 +173,32 @@ def _register_op_factory(
166173
category: OpCategory,
167174
expected_args: list[tuple[str, inspect.Parameter]],
168175
expected_return: Any,
169-
executor_cls: type,
170-
spec_cls: type,
176+
executor_factory: Any,
177+
spec_loader: Callable[..., Any],
178+
op_kind: str,
171179
op_args: OpArgs,
172-
) -> type:
180+
) -> None:
173181
"""
174182
Register an op factory.
175183
"""
176184

177-
class _Fallback:
178-
def enable_cache(self) -> bool:
179-
return op_args.cache
180-
181-
def behavior_version(self) -> int | None:
182-
return op_args.behavior_version
183-
184-
class _WrappedClass(executor_cls, _Fallback): # type: ignore[misc]
185+
class _WrappedExecutor:
186+
_executor: Any
185187
_args_info: list[_ArgInfo]
186188
_kwargs_info: dict[str, _ArgInfo]
187-
_acall: Callable[..., Awaitable[Any]]
188189
_result_encoder: Callable[[Any], Any]
190+
_acall: Callable[..., Awaitable[Any]] | None = None
189191

190192
def __init__(self, spec: Any) -> None:
191-
super().__init__()
192-
self.spec = spec
193-
self._acall = _to_async_call(super().__call__)
193+
executor: Any
194+
195+
if op_args.gpu:
196+
executor = executor_stub(executor_factory, spec)
197+
else:
198+
executor = executor_factory()
199+
executor.spec = spec
200+
201+
self._executor = executor
194202

195203
def analyze_schema(
196204
self, *args: _engine.OpArgSchema, **kwargs: _engine.OpArgSchema
@@ -294,9 +302,9 @@ def process_arg(
294302
if len(missing_args) > 0:
295303
raise ValueError(f"Missing arguments: {', '.join(missing_args)}")
296304

297-
base_analyze_method = getattr(self, "analyze", None)
305+
base_analyze_method = getattr(self._executor, "analyze", None)
298306
if base_analyze_method is not None:
299-
result_type = base_analyze_method(*args, **kwargs)
307+
result_type = base_analyze_method()
300308
else:
301309
result_type = expected_return
302310
if len(attributes) > 0:
@@ -316,9 +324,10 @@ async def prepare(self) -> None:
316324
Prepare for execution.
317325
It's executed after `analyze` and before any `__call__` execution.
318326
"""
319-
prepare_method = getattr(super(), "prepare", None)
327+
prepare_method = getattr(self._executor, "prepare", None)
320328
if prepare_method is not None:
321329
await _to_async_call(prepare_method)()
330+
self._acall = _to_async_call(self._executor.__call__)
322331

323332
async def __call__(self, *args: Any, **kwargs: Any) -> Any:
324333
decoded_args = []
@@ -338,6 +347,7 @@ async def __call__(self, *args: Any, **kwargs: Any) -> Any:
338347
return None
339348
decoded_kwargs[kwarg_name] = kwarg_info.decoder(arg)
340349

350+
assert self._acall is not None
341351
if op_args.gpu:
342352
# For GPU executions, data-level parallelism is applied, so we don't want to
343353
# execute different tasks in parallel.
@@ -350,21 +360,19 @@ async def __call__(self, *args: Any, **kwargs: Any) -> Any:
350360
output = await self._acall(*decoded_args, **decoded_kwargs)
351361
return self._result_encoder(output)
352362

353-
_WrappedClass.__name__ = executor_cls.__name__
354-
_WrappedClass.__doc__ = executor_cls.__doc__
355-
_WrappedClass.__module__ = executor_cls.__module__
356-
_WrappedClass.__qualname__ = executor_cls.__qualname__
357-
_WrappedClass.__wrapped__ = executor_cls
363+
def enable_cache(self) -> bool:
364+
return op_args.cache
365+
366+
def behavior_version(self) -> int | None:
367+
return op_args.behavior_version
358368

359369
if category == OpCategory.FUNCTION:
360370
_engine.register_function_factory(
361-
spec_cls.__name__, _FunctionExecutorFactory(spec_cls, _WrappedClass)
371+
op_kind, _EngineFunctionExecutorFactory(spec_loader, _WrappedExecutor)
362372
)
363373
else:
364374
raise ValueError(f"Unsupported executor type {category}")
365375

366-
return _WrappedClass
367-
368376

369377
def executor_class(**args: Any) -> Callable[[type], type]:
370378
"""
@@ -382,18 +390,31 @@ def _inner(cls: type[Executor]) -> type:
382390
raise TypeError("Expect a `spec` field with type hint")
383391
spec_cls = resolve_forward_ref(type_hints["spec"])
384392
sig = inspect.signature(cls.__call__)
385-
return _register_op_factory(
393+
_register_op_factory(
386394
category=spec_cls._op_category,
387395
expected_args=list(sig.parameters.items())[1:], # First argument is `self`
388396
expected_return=sig.return_annotation,
389-
executor_cls=cls,
390-
spec_cls=spec_cls,
397+
executor_factory=cls,
398+
spec_loader=spec_cls,
399+
op_kind=spec_cls.__name__,
391400
op_args=op_args,
392401
)
402+
return cls
393403

394404
return _inner
395405

396406

407+
class _EmptyFunctionSpec(FunctionSpec):
408+
pass
409+
410+
411+
class _SimpleFunctionExecutor:
412+
spec: Any
413+
414+
def prepare(self) -> None:
415+
self.__call__ = self.spec.__call__
416+
417+
397418
def function(**args: Any) -> Callable[[Callable[..., Any]], FunctionSpec]:
398419
"""
399420
Decorate a function to provide a function for an op.
@@ -404,30 +425,32 @@ def _inner(fn: Callable[..., Any]) -> FunctionSpec:
404425
# Convert snake case to camel case.
405426
op_name = "".join(word.capitalize() for word in fn.__name__.split("_"))
406427
sig = inspect.signature(fn)
428+
full_name = f"{fn.__module__}.{fn.__qualname__}"
407429

408-
class _Executor:
409-
def __call__(self, *args: Any, **kwargs: Any) -> Any:
410-
return fn(*args, **kwargs)
430+
# An object that is both callable and can act as a FunctionSpec.
431+
class _CallableSpec(_EmptyFunctionSpec):
432+
__call__ = staticmethod(fn)
411433

412-
class _Spec(FunctionSpec):
413-
def __call__(self, *args: Any, **kwargs: Any) -> Any:
414-
return fn(*args, **kwargs)
434+
def __reduce__(self) -> str | tuple[Any, ...]:
435+
return full_name
415436

416-
_Spec.__name__ = op_name
417-
_Spec.__doc__ = fn.__doc__
418-
_Spec.__module__ = fn.__module__
419-
_Spec.__qualname__ = fn.__qualname__
437+
_CallableSpec.__name__ = op_name
438+
_CallableSpec.__doc__ = fn.__doc__
439+
_CallableSpec.__qualname__ = fn.__qualname__
440+
_CallableSpec.__module__ = fn.__module__
441+
callable_spec = _CallableSpec()
420442

421443
_register_op_factory(
422444
category=OpCategory.FUNCTION,
423445
expected_args=list(sig.parameters.items()),
424446
expected_return=sig.return_annotation,
425-
executor_cls=_Executor,
426-
spec_cls=_Spec,
447+
executor_factory=_SimpleFunctionExecutor,
448+
spec_loader=lambda: callable_spec,
449+
op_kind=op_name,
427450
op_args=op_args,
428451
)
429452

430-
return _Spec()
453+
return callable_spec
431454

432455
return _inner
433456

0 commit comments

Comments
 (0)