Skip to content

Commit f1dfc5e

Browse files
authored
feat(custom-target): support customized target in engine and Python SDK (#812)
* feat(custom-target): support custom target from the engine side * chore: update method name to call from engine * feat(custom-target): support from Python SDK - setup logic * feat(custom-target): support from Python SDK - mutate logic
1 parent 3b7eee9 commit f1dfc5e

File tree

5 files changed

+537
-19
lines changed

5 files changed

+537
-19
lines changed

python/cocoindex/op.py

Lines changed: 237 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,30 @@
66
import dataclasses
77
import inspect
88
from enum import Enum
9-
from typing import Any, Awaitable, Callable, Protocol, dataclass_transform, Annotated
9+
from typing import (
10+
Any,
11+
Awaitable,
12+
Callable,
13+
Protocol,
14+
dataclass_transform,
15+
Annotated,
16+
get_args,
17+
)
1018

1119
from . import _engine # type: ignore
12-
from .convert import encode_engine_value, make_engine_value_decoder
13-
from .typing import TypeAttr, encode_enriched_type, resolve_forward_ref
20+
from .convert import (
21+
encode_engine_value,
22+
make_engine_value_decoder,
23+
make_engine_struct_decoder,
24+
)
25+
from .typing import (
26+
TypeAttr,
27+
encode_enriched_type,
28+
resolve_forward_ref,
29+
analyze_type_info,
30+
AnalyzedAnyType,
31+
AnalyzedDictType,
32+
)
1433

1534

1635
class OpCategory(Enum):
@@ -65,6 +84,22 @@ class Executor(Protocol):
6584
op_category: OpCategory
6685

6786

87+
def _load_spec_from_engine(spec_cls: type, spec: dict[str, Any]) -> Any:
    """
    Build a spec object from the dict form handed over by the engine.

    Each entry of `spec` is forwarded to `spec_cls` as a keyword argument.
    """
    loaded = spec_cls(**spec)
    return loaded
92+
93+
94+
def _get_required_method(cls: type, name: str) -> Callable[..., Any]:
    """
    Look up the attribute `name` on `cls` and return it if it is a plain function.

    Raises:
        ValueError: if the attribute is absent (or is None), or if it exists
            but `inspect.isfunction` rejects it.
    """
    candidate = getattr(cls, name, None)
    if candidate is not None and inspect.isfunction(candidate):
        return candidate
    if candidate is None:
        raise ValueError(f"Method {name}() is required for {cls.__name__}")
    raise ValueError(f"Method {cls.__name__}.{name}() is not a function")
101+
102+
68103
class _FunctionExecutorFactory:
69104
_spec_cls: type
70105
_executor_cls: type
@@ -76,7 +111,7 @@ def __init__(self, spec_cls: type, executor_cls: type):
76111
def __call__(
77112
self, spec: dict[str, Any], *args: Any, **kwargs: Any
78113
) -> tuple[dict[str, Any], Executor]:
79-
spec = self._spec_cls(**spec)
114+
spec = _load_spec_from_engine(self._spec_cls, spec)
80115
executor = self._executor_cls(spec)
81116
result_type = executor.analyze(*args, **kwargs)
82117
return (encode_enriched_type(result_type), executor)
@@ -359,3 +394,201 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
359394
return _Spec()
360395

361396
return _inner
397+
398+
399+
########################################################
400+
# Customized target connector
401+
########################################################
402+
403+
404+
@dataclasses.dataclass
class _TargetConnectorContext:
    """
    Per-target state bundled together for a custom target connector:
    the target's name, its loaded spec, and the decoders for keys and values.
    """

    # Name of the export target.
    target_name: str
    # Spec object for the target.
    spec: Any
    # Converts a key received from the engine into its Python form.
    key_decoder: Callable[[Any], Any]
    # Converts a value received from the engine into its Python form.
    value_decoder: Callable[[Any], Any]
410+
411+
412+
class _TargetConnector:
    """
    The connector class passed to the engine.

    Wraps a user-defined connector class. The user class must provide
    `get_persistent_key`, `apply_setup_change` and `mutate` as plain functions
    reachable on the class, and may provide an optional `describe`.
    """

    _spec_cls: type
    _connector_cls: type

    # Called with (spec, target_name); see get_persistent_key().
    _get_persistent_key_fn: Callable[[Any, str], Any]
    _apply_setup_change_async_fn: Callable[
        [Any, dict[str, Any] | None, dict[str, Any] | None], Awaitable[None]
    ]
    _mutate_async_fn: Callable[..., Awaitable[None]]
    # Analyzed `dict[...]` type taken from the `mutate(*args)` annotation,
    # or None when no usable annotation is available.
    _mutatation_type: AnalyzedDictType | None

    def __init__(self, spec_cls: type, connector_cls: type):
        """
        Resolve the required methods of `connector_cls` and analyze `mutate`.

        Raises:
            ValueError: if a required method is missing or not a function, or
                if `mutate` does not have the expected `*args` signature.
        """
        self._spec_cls = spec_cls
        self._connector_cls = connector_cls

        self._get_persistent_key_fn = _get_required_method(
            connector_cls, "get_persistent_key"
        )
        self._apply_setup_change_async_fn = _to_async_call(
            _get_required_method(connector_cls, "apply_setup_change")
        )

        mutate_fn = _get_required_method(connector_cls, "mutate")
        self._mutate_async_fn = _to_async_call(mutate_fn)

        # Store the type annotation for later use
        self._mutatation_type = self._analyze_mutate_mutation_type(
            connector_cls, mutate_fn
        )

    @staticmethod
    def _analyze_mutate_mutation_type(
        connector_cls: type, mutate_fn: Callable[..., Any]
    ) -> AnalyzedDictType | None:
        """
        Validate the signature of `mutate` and extract its mutation dict type.

        `mutate` must take exactly one parameter, declared as `*args`. Its
        annotation is expected to be `tuple[SpecType, dict[...]]`.

        Returns:
            The analyzed dict type of the tuple's second element, or None when
            the annotation (or that second element) is analyzed as "any", or
            the tuple carries no type arguments.

        Raises:
            ValueError: if the signature or the annotation has another shape.
        """
        # Validate mutate_fn signature and extract type annotation
        mutate_sig = inspect.signature(mutate_fn)
        params = list(mutate_sig.parameters.values())

        if len(params) != 1:
            raise ValueError(
                f"Method {connector_cls.__name__}.mutate(*args) must have exactly one parameter, "
                f"got {len(params)}"
            )

        param = params[0]
        if param.kind != inspect.Parameter.VAR_POSITIONAL:
            raise ValueError(
                f"Method {connector_cls.__name__}.mutate(*args) parameter must be *args format, "
                f"got {param.kind.name}"
            )

        # Extract type annotation
        analyzed_args_type = analyze_type_info(param.annotation)
        if isinstance(analyzed_args_type.variant, AnalyzedAnyType):
            return None

        if analyzed_args_type.base_type is tuple:
            args = get_args(analyzed_args_type.core_type)
            if not args:
                return None
            if len(args) == 2:
                mutation_type = analyze_type_info(args[1])
                if isinstance(mutation_type.variant, AnalyzedAnyType):
                    return None
                if isinstance(mutation_type.variant, AnalyzedDictType):
                    return mutation_type.variant

        # Every accepted shape has returned above. Report the annotation the
        # user actually wrote so the message is actionable.
        raise ValueError(
            f"Method {connector_cls.__name__}.mutate(*args) parameter must be a tuple with "
            "2 elements (tuple[SpecType, dict[str, ValueStruct]], spec and mutation in dict), "
            f"got {param.annotation}"
        )

    def create_export_context(
        self,
        name: str,
        spec: dict[str, Any],
        key_fields_schema: list[Any],
        value_fields_schema: list[Any],
    ) -> _TargetConnectorContext:
        """
        Build the context for one export target.

        Loads the spec and creates key/value decoders from the field schemas,
        guided by the key/value types of the `mutate` annotation when one was
        found (otherwise `None` is used as the annotation).
        """
        key_annotation, value_annotation = (
            (
                self._mutatation_type.key_type,
                self._mutatation_type.value_type,
            )
            if self._mutatation_type is not None
            else (None, None)
        )

        # A single key field is decoded as a plain value; any other number of
        # key fields is decoded as a struct.
        if len(key_fields_schema) == 1:
            key_decoder = make_engine_value_decoder(
                ["(key)"], key_fields_schema[0]["type"], key_annotation
            )
        else:
            key_decoder = make_engine_struct_decoder(
                ["(key)"], key_fields_schema, analyze_type_info(key_annotation)
            )

        value_decoder = make_engine_struct_decoder(
            ["(value)"], value_fields_schema, analyze_type_info(value_annotation)
        )

        return _TargetConnectorContext(
            target_name=name,
            spec=_load_spec_from_engine(self._spec_cls, spec),
            key_decoder=key_decoder,
            value_decoder=value_decoder,
        )

    def get_persistent_key(self, export_context: _TargetConnectorContext) -> Any:
        """Return the connector's `get_persistent_key(spec, target_name)` result."""
        return self._get_persistent_key_fn(
            export_context.spec, export_context.target_name
        )

    def describe_resource(self, key: Any) -> str:
        """
        Return a string description of the resource identified by `key`.

        Uses the connector's optional `describe` when present, else `str(key)`.
        """
        describe_fn = getattr(self._connector_cls, "describe", None)
        if describe_fn is None:
            return str(key)
        return str(describe_fn(key))

    async def apply_setup_changes_async(
        self,
        changes: list[tuple[Any, list[dict[str, Any] | None], dict[str, Any] | None]],
    ) -> None:
        """
        Apply setup changes through the connector's `apply_setup_change`.

        Each change is `(key, previous, current)`. The connector is awaited
        once per entry of `previous` with `(key, prev_spec, curr_spec)`, where
        non-None specs are loaded into `spec_cls` instances.

        NOTE: a change whose `previous` list is empty triggers no call.
        """
        for key, previous, current in changes:
            prev_specs = [
                _load_spec_from_engine(self._spec_cls, spec)
                if spec is not None
                else None
                for spec in previous
            ]
            curr_spec = (
                _load_spec_from_engine(self._spec_cls, current)
                if current is not None
                else None
            )
            for prev_spec in prev_specs:
                await self._apply_setup_change_async_fn(key, prev_spec, curr_spec)

    @staticmethod
    def _decode_mutation(
        context: _TargetConnectorContext, mutation: list[tuple[Any, Any | None]]
    ) -> tuple[Any, dict[Any, Any | None]]:
        """
        Decode one target's mutation into `(spec, {decoded_key: decoded_value})`.

        A value of None is passed through as None instead of being handed to
        the value decoder.
        """
        # NOTE(review): None presumably marks a row with no value (e.g. a
        # deletion) - confirm against the engine side.
        return (
            context.spec,
            {
                context.key_decoder(key): (
                    context.value_decoder(value) if value is not None else None
                )
                for key, value in mutation
            },
        )

    async def mutate_async(
        self,
        mutations: list[tuple[_TargetConnectorContext, list[tuple[Any, Any | None]]]],
    ) -> None:
        """
        Decode every `(context, mutation)` pair and await the connector's
        `mutate`, passing the decoded pairs as positional arguments.
        """
        await self._mutate_async_fn(
            *(
                self._decode_mutation(context, mutation)
                for context, mutation in mutations
            )
        )
577+
578+
579+
def target_connector(spec_cls: type) -> Callable[[type], type]:
    """
    Class decorator factory that registers a connector class for `spec_cls`.

    `spec_cls` must be a subclass of `TargetSpec`, otherwise ValueError is
    raised. The decorated class is wrapped in a `_TargetConnector`, registered
    with the engine under `spec_cls.__name__`, and returned unchanged.
    """

    # Reject anything that is not a TargetSpec before building the decorator.
    is_target_spec = issubclass(spec_cls, TargetSpec)
    if not is_target_spec:
        raise ValueError(f"Expect a TargetSpec, got {spec_cls}")

    def _inner(connector_cls: type) -> type:
        # Wrap first, so a validation failure happens before registration.
        wrapped = _TargetConnector(spec_cls, connector_cls)
        _engine.register_target_connector(spec_cls.__name__, wrapped)
        return connector_cls

    return _inner

src/ops/interface.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ pub trait ExportTargetFactory: Send + Sync {
272272
key: &serde_json::Value,
273273
desired_state: Option<serde_json::Value>,
274274
existing_states: setup::CombinedState<serde_json::Value>,
275-
context: Arc<FlowInstanceContext>,
275+
context: Arc<interface::FlowInstanceContext>,
276276
) -> Result<Box<dyn setup::ResourceSetupStatus>>;
277277

278278
/// Normalize the key. e.g. the JSON format may change (after code change, e.g. new optional field or field ordering), even if the underlying value is not changed.

0 commit comments

Comments
 (0)