Skip to content

Commit 46aeda9

Browse files
committed
feat(custom-target): support compatibility and custom setup key type
1 parent 1e815b2 commit 46aeda9

File tree

3 files changed

+73
-19
lines changed

3 files changed

+73
-19
lines changed

python/cocoindex/op.py

Lines changed: 55 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -458,13 +458,22 @@ def _build_args(
458458
return [v for _, v in zip(signature.parameters, kwargs.values())]
459459

460460

461+
class TargetStateCompatibility(Enum):
462+
"""The compatibility of the target state."""
463+
464+
COMPATIBLE = "Compatible"
465+
PARTIALLY_COMPATIBLE = "PartialCompatible"
466+
NOT_COMPATIBLE = "NotCompatible"
467+
468+
461469
class _TargetConnector:
462470
"""
463471
The connector class passed to the engine.
464472
"""
465473

466474
_spec_cls: type[Any]
467-
_state_cls: type[Any]
475+
_setup_key_type: Any
476+
_setup_state_cls: type[Any]
468477
_connector_cls: type[Any]
469478

470479
_get_persistent_key_fn: Callable[[_TargetConnectorContext, str], Any]
@@ -475,10 +484,15 @@ class _TargetConnector:
475484
_mutatation_type: AnalyzedDictType | None
476485

477486
def __init__(
478-
self, spec_cls: type[Any], state_cls: type[Any], connector_cls: type[Any]
487+
self,
488+
spec_cls: type[Any],
489+
setup_key_type: Any,
490+
setup_state_cls: type[Any],
491+
connector_cls: type[Any],
479492
):
480493
self._spec_cls = spec_cls
481-
self._state_cls = state_cls
494+
self._setup_key_type = setup_key_type
495+
self._setup_state_cls = setup_state_cls
482496
self._connector_cls = connector_cls
483497

484498
self._get_persistent_key_fn = _get_required_method(
@@ -591,9 +605,9 @@ def get_setup_state(self, export_context: _TargetConnectorContext) -> Any:
591605
get_persistent_state_fn = getattr(self._connector_cls, "get_setup_state", None)
592606
if get_persistent_state_fn is None:
593607
state = export_context.spec
594-
if not isinstance(state, self._state_cls):
608+
if not isinstance(state, self._setup_state_cls):
595609
raise ValueError(
596-
f"Expect a get_setup_state() method for {self._connector_cls} that returns an instance of {self._state_cls}"
610+
f"Expect a get_setup_state() method for {self._connector_cls} that returns an instance of {self._setup_state_cls}"
597611
)
598612
else:
599613
args = _build_args(
@@ -605,12 +619,31 @@ def get_setup_state(self, export_context: _TargetConnectorContext) -> Any:
605619
index_options=export_context.index_options,
606620
)
607621
state = get_persistent_state_fn(*args)
608-
if not isinstance(state, self._state_cls):
622+
if not isinstance(state, self._setup_state_cls):
609623
raise ValueError(
610-
f"Method {get_persistent_state_fn.__name__} must return an instance of {self._state_cls}, got {type(state)}"
624+
f"Method {get_persistent_state_fn.__name__} must return an instance of {self._setup_state_cls}, got {type(state)}"
611625
)
612626
return dump_engine_object(state)
613627

628+
def check_state_compatibility(
629+
self, raw_desired_state: Any, raw_existing_state: Any
630+
) -> Any:
631+
check_state_compatibility_fn = getattr(
632+
self._connector_cls, "check_state_compatibility", None
633+
)
634+
if check_state_compatibility_fn is not None:
635+
compatibility = check_state_compatibility_fn(
636+
load_engine_object(self._setup_state_cls, raw_desired_state),
637+
load_engine_object(self._setup_state_cls, raw_existing_state),
638+
)
639+
else:
640+
compatibility = (
641+
TargetStateCompatibility.COMPATIBLE
642+
if raw_desired_state == raw_existing_state
643+
else TargetStateCompatibility.PARTIALLY_COMPATIBLE
644+
)
645+
return dump_engine_object(compatibility)
646+
614647
async def prepare_async(self, export_context: _TargetConnectorContext) -> None:
615648
prepare_fn = getattr(self._connector_cls, "prepare", None)
616649
if prepare_fn is None:
@@ -626,7 +659,8 @@ async def prepare_async(self, export_context: _TargetConnectorContext) -> None:
626659
async_prepare_fn = to_async_call(prepare_fn)
627660
export_context.prepared_spec = await async_prepare_fn(*args)
628661

629-
def describe_resource(self, key: Any) -> str:
662+
def describe_resource(self, raw_key: Any) -> str:
663+
key = load_engine_object(self._setup_key_type, raw_key)
630664
describe_fn = getattr(self._connector_cls, "describe", None)
631665
if describe_fn is None:
632666
return str(key)
@@ -636,13 +670,16 @@ async def apply_setup_changes_async(
636670
self,
637671
changes: list[tuple[Any, list[dict[str, Any] | None], dict[str, Any] | None]],
638672
) -> None:
639-
for key, previous, current in changes:
673+
for raw_key, previous, current in changes:
674+
key = load_engine_object(self._setup_key_type, raw_key)
640675
prev_specs = [
641-
load_engine_object(self._state_cls, spec) if spec is not None else None
676+
load_engine_object(self._setup_state_cls, spec)
677+
if spec is not None
678+
else None
642679
for spec in previous
643680
]
644681
curr_spec = (
645-
load_engine_object(self._state_cls, current)
682+
load_engine_object(self._setup_state_cls, current)
646683
if current is not None
647684
else None
648685
)
@@ -676,7 +713,10 @@ async def mutate_async(
676713

677714

678715
def target_connector(
679-
spec_cls: type[Any], state_cls: type[Any] | None = None
716+
*,
717+
spec_cls: type[Any],
718+
setup_key_type: Any = Any,
719+
setup_state_cls: type[Any] | None = None,
680720
) -> Callable[[type], type]:
681721
"""
682722
Decorate a class to provide a target connector for an op.
@@ -688,7 +728,9 @@ def target_connector(
688728

689729
# Register the target connector.
690730
def _inner(connector_cls: type) -> type:
691-
connector = _TargetConnector(spec_cls, state_cls or spec_cls, connector_cls)
731+
connector = _TargetConnector(
732+
spec_cls, setup_key_type, setup_state_cls or spec_cls, connector_cls
733+
)
692734
_engine.register_target_connector(spec_cls.__name__, connector)
693735
return connector_cls
694736

src/ops/interface.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ pub struct ResourceSetupChangeItem<'a> {
236236
pub setup_change: &'a dyn setup::ResourceSetupChange,
237237
}
238238

239-
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
239+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
240240
pub enum SetupStateCompatibility {
241241
/// The resource is fully compatible with the desired state.
242242
/// This means the resource can be updated to the desired state without any loss of data.

src/ops/py_factory.rs

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -382,11 +382,23 @@ impl interface::TargetFactory for PyExportTargetFactory {
382382
desired_state: &serde_json::Value,
383383
existing_state: &serde_json::Value,
384384
) -> Result<SetupStateCompatibility> {
385-
Ok(if desired_state == existing_state {
386-
SetupStateCompatibility::Compatible
387-
} else {
388-
SetupStateCompatibility::PartialCompatible
389-
})
385+
let compatibility = Python::with_gil(|py| -> Result<_> {
386+
let result = self
387+
.py_target_connector
388+
.call_method(
389+
py,
390+
"check_state_compatibility",
391+
(
392+
pythonize(py, desired_state)?,
393+
pythonize(py, existing_state)?,
394+
),
395+
None,
396+
)
397+
.to_result_with_py_trace(py)?;
398+
let compatibility: SetupStateCompatibility = depythonize(&result.into_bound(py))?;
399+
Ok(compatibility)
400+
})?;
401+
Ok(compatibility)
390402
}
391403

392404
fn describe_resource(&self, key: &serde_json::Value) -> Result<String> {

0 commit comments

Comments
 (0)