Skip to content

Commit eed25ec

Browse files
authored
feat(attr): expose Python API to configure arg relation attributes (#784)
1 parent 0a36a8f commit eed25ec

File tree

3 files changed

+42
-11
lines changed

3 files changed

+42
-11
lines changed

python/cocoindex/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
"targets",
3939
"storages",
4040
"cli",
41+
"op",
4142
"utils",
4243
# Auth registry
4344
"AuthEntryReference",

python/cocoindex/functions.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,14 +66,19 @@ class SentenceTransformerEmbed(op.FunctionSpec):
6666
args: dict[str, Any] | None = None
6767

6868

69-
@op.executor_class(gpu=True, cache=True, behavior_version=1)
69+
@op.executor_class(
70+
gpu=True,
71+
cache=True,
72+
behavior_version=1,
73+
related_arg_attr=(op.RelatedFieldAttribute.VECTOR_ORIGIN_TEXT, "text"),
74+
)
7075
class SentenceTransformerEmbedExecutor:
7176
"""Executor for SentenceTransformerEmbed."""
7277

7378
spec: SentenceTransformerEmbed
7479
_model: Any | None = None
7580

76-
def analyze(self, text: Any) -> type:
81+
def analyze(self, _text: Any) -> type:
7782
try:
7883
# Only import sentence_transformers locally when it's needed, as its import is very slow.
7984
import sentence_transformers # pylint: disable=import-outside-toplevel
@@ -88,11 +93,7 @@ def analyze(self, text: Any) -> type:
8893
args = self.spec.args or {}
8994
self._model = sentence_transformers.SentenceTransformer(self.spec.model, **args)
9095
dim = self._model.get_sentence_embedding_dimension()
91-
result: type = Annotated[
92-
Vector[np.float32, Literal[dim]], # type: ignore
93-
TypeAttr("cocoindex.io/vector_origin_text", text.analyzed_value),
94-
]
95-
return result
96+
return Vector[np.float32, Literal[dim]] # type: ignore
9697

9798
def __call__(self, text: str) -> NDArray[np.float32]:
9899
assert self._model is not None

python/cocoindex/op.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66
import dataclasses
77
import inspect
88
from enum import Enum
9-
from typing import Any, Awaitable, Callable, Protocol, dataclass_transform
9+
from typing import Any, Awaitable, Callable, Protocol, dataclass_transform, Annotated
1010

1111
from . import _engine # type: ignore
1212
from .convert import encode_engine_value, make_engine_value_decoder
13-
from .typing import encode_enriched_type, resolve_forward_ref
13+
from .typing import TypeAttr, encode_enriched_type, resolve_forward_ref
1414

1515

1616
class OpCategory(Enum):
@@ -85,18 +85,33 @@ def __call__(
8585
_gpu_dispatch_lock = asyncio.Lock()
8686

8787

88+
_COCOINDEX_ATTR_PREFIX = "cocoindex.io/"
89+
90+
91+
class RelatedFieldAttribute(Enum):
92+
"""Attribute keys that relate an op's output to one of its input arguments."""
93+
94+
VECTOR_ORIGIN_TEXT = _COCOINDEX_ATTR_PREFIX + "vector_origin_text"
95+
CHUNKS_BASE_TEXT = _COCOINDEX_ATTR_PREFIX + "chunk_base_text"
96+
RECTS_BASE_IMAGE = _COCOINDEX_ATTR_PREFIX + "rects_base_image"
97+
98+
8899
@dataclasses.dataclass
89100
class OpArgs:
90101
"""
91102
- gpu: Whether the executor will be executed on GPU.
92103
- cache: Whether the executor will be cached.
93104
- behavior_version: The behavior version of the executor. Cache will be invalidated if it
94105
changes. Must be provided if `cache` is True.
106+
- related_arg_attr: Specifies the relationship between an input argument and the output,
107+
e.g. `(RelatedFieldAttribute.CHUNKS_BASE_TEXT, "content")` means the output consists of chunks of
108+
the input argument named `content`.
95109
"""
96110

97111
gpu: bool = False
98112
cache: bool = False
99113
behavior_version: int | None = None
114+
related_arg_attr: tuple[RelatedFieldAttribute, str] | None = None
100115

101116

102117
def _to_async_call(call: Callable[..., Any]) -> Callable[..., Awaitable[Any]]:
@@ -143,6 +158,15 @@ def analyze(
143158
"""
144159
self._args_decoders = []
145160
self._kwargs_decoders = {}
161+
attributes = []
162+
163+
def process_attribute(arg_name: str, arg: _engine.OpArgSchema) -> None:
164+
if op_args.related_arg_attr is not None:
165+
related_attr, related_arg_name = op_args.related_arg_attr
166+
if related_arg_name == arg_name:
167+
attributes.append(
168+
TypeAttr(related_attr.value, arg.analyzed_value)
169+
)
146170

147171
# Match arguments with parameters.
148172
next_param_idx = 0
@@ -164,6 +188,7 @@ def analyze(
164188
[arg_name], arg.value_type["type"], arg_param.annotation
165189
)
166190
)
191+
process_attribute(arg_name, arg)
167192
if arg_param.kind != inspect.Parameter.VAR_POSITIONAL:
168193
next_param_idx += 1
169194

@@ -194,6 +219,7 @@ def analyze(
194219
self._kwargs_decoders[kwarg_name] = make_engine_value_decoder(
195220
[kwarg_name], kwarg.value_type["type"], arg_param.annotation
196221
)
222+
process_attribute(kwarg_name, kwarg)
197223

198224
missing_args = [
199225
name
@@ -216,9 +242,12 @@ def analyze(
216242

217243
prepare_method = getattr(executor_cls, "analyze", None)
218244
if prepare_method is not None:
219-
return prepare_method(self, *args, **kwargs)
245+
result = prepare_method(self, *args, **kwargs)
220246
else:
221-
return expected_return
247+
result = expected_return
248+
if len(attributes) > 0:
249+
result = Annotated[result, *attributes]
250+
return result
222251

223252
async def prepare(self) -> None:
224253
"""

0 commit comments

Comments
 (0)