66import dataclasses
77import inspect
88from enum import Enum
9- from typing import Any , Awaitable , Callable , Protocol , dataclass_transform
9+ from typing import Any , Awaitable , Callable , Protocol , dataclass_transform , Annotated
1010
1111from . import _engine # type: ignore
1212from .convert import encode_engine_value , make_engine_value_decoder
13- from .typing import encode_enriched_type , resolve_forward_ref
13+ from .typing import TypeAttr , encode_enriched_type , resolve_forward_ref
1414
1515
1616class OpCategory (Enum ):
@@ -85,18 +85,33 @@ def __call__(
8585_gpu_dispatch_lock = asyncio .Lock ()
8686
8787
88+ _COCOINDEX_ATTR_PREFIX = "cocoindex.io/"
89+
90+
91+ class RelatedFieldAttribute (Enum ):
92+ """The attribute of a field that is related to the op."""
93+
94+ VECTOR_ORIGIN_TEXT = _COCOINDEX_ATTR_PREFIX + "vector_origin_text"
95+ CHUNKS_BASE_TEXT = _COCOINDEX_ATTR_PREFIX + "chunk_base_text"
96+ RECTS_BASE_IMAGE = _COCOINDEX_ATTR_PREFIX + "rects_base_image"
97+
98+
8899@dataclasses .dataclass
89100class OpArgs :
90101 """
91102 - gpu: Whether the executor will be executed on GPU.
92103 - cache: Whether the executor will be cached.
93104 - behavior_version: The behavior version of the executor. Cache will be invalidated if it
94105 changes. Must be provided if `cache` is True.
106+ - related_arg_attr: It specifies the relationship between an input argument and the output,
107+ e.g. `(RelatedFieldAttribute.CHUNKS_BASE_TEXT, "content")` means the output is chunks for the
108+ input argument with name `content`.
95109 """
96110
97111 gpu : bool = False
98112 cache : bool = False
99113 behavior_version : int | None = None
114+ related_arg_attr : tuple [RelatedFieldAttribute , str ] | None = None
100115
101116
102117def _to_async_call (call : Callable [..., Any ]) -> Callable [..., Awaitable [Any ]]:
@@ -143,6 +158,15 @@ def analyze(
143158 """
144159 self ._args_decoders = []
145160 self ._kwargs_decoders = {}
161+ attributes = []
162+
163+ def process_attribute (arg_name : str , arg : _engine .OpArgSchema ) -> None :
164+ if op_args .related_arg_attr is not None :
165+ related_attr , related_arg_name = op_args .related_arg_attr
166+ if related_arg_name == arg_name :
167+ attributes .append (
168+ TypeAttr (related_attr .value , arg .analyzed_value )
169+ )
146170
147171 # Match arguments with parameters.
148172 next_param_idx = 0
@@ -164,6 +188,7 @@ def analyze(
164188 [arg_name ], arg .value_type ["type" ], arg_param .annotation
165189 )
166190 )
191+ process_attribute (arg_name , arg )
167192 if arg_param .kind != inspect .Parameter .VAR_POSITIONAL :
168193 next_param_idx += 1
169194
@@ -194,6 +219,7 @@ def analyze(
194219 self ._kwargs_decoders [kwarg_name ] = make_engine_value_decoder (
195220 [kwarg_name ], kwarg .value_type ["type" ], arg_param .annotation
196221 )
222+ process_attribute (kwarg_name , kwarg )
197223
198224 missing_args = [
199225 name
@@ -216,9 +242,12 @@ def analyze(
216242
217243 prepare_method = getattr (executor_cls , "analyze" , None )
218244 if prepare_method is not None :
219- return prepare_method (self , * args , ** kwargs )
245+ result = prepare_method (self , * args , ** kwargs )
220246 else :
221- return expected_return
247+ result = expected_return
248+ if len (attributes ) > 0 :
249+ result = Annotated [result , * attributes ]
250+ return result
222251
223252 async def prepare (self ) -> None :
224253 """
0 commit comments