Skip to content

Commit 75e0f74

Browse files
committed
feat(py-target): pass optional fields schema to target connector
1 parent 5a29038 commit 75e0f74

File tree

2 files changed

+140
-50
lines changed

2 files changed

+140
-50
lines changed

python/cocoindex/op.py

Lines changed: 90 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from . import _engine # type: ignore
1919
from .subprocess_exec import executor_stub
2020
from .convert import (
21+
dump_engine_object,
2122
make_engine_value_encoder,
2223
make_engine_value_decoder,
2324
make_engine_key_decoder,
@@ -32,6 +33,7 @@
3233
AnalyzedDictType,
3334
EnrichedValueType,
3435
decode_engine_field_schemas,
36+
FieldSchema,
3537
)
3638
from .runtime import to_async_call
3739

@@ -432,16 +434,43 @@ class _TargetConnectorContext:
432434
target_name: str
433435
spec: Any
434436
prepared_spec: Any
437+
key_fields_schema: list[FieldSchema]
435438
key_decoder: Callable[[Any], Any]
439+
value_fields_schema: list[FieldSchema]
436440
value_decoder: Callable[[Any], Any]
437441

438442

443+
def _build_args(
444+
method: Callable[..., Any], num_required_args: int, **kwargs: Any
445+
) -> list[Any]:
446+
signature = inspect.signature(method)
447+
for param in signature.parameters.values():
448+
if param.kind not in (
449+
inspect.Parameter.POSITIONAL_ONLY,
450+
inspect.Parameter.POSITIONAL_OR_KEYWORD,
451+
):
452+
raise ValueError(
453+
f"Method {method.__name__} should only have positional arguments, got {param.kind.name}"
454+
)
455+
if len(signature.parameters) < num_required_args:
456+
raise ValueError(
457+
f"Method {method.__name__} must have at least {num_required_args} required arguments: "
458+
f"{', '.join(list(kwargs.keys())[:num_required_args])}"
459+
)
460+
if len(kwargs) > len(kwargs):
461+
raise ValueError(
462+
f"Method {method.__name__} can only have at most {num_required_args} arguments: {', '.join(kwargs.keys())}"
463+
)
464+
return [v for _, v in zip(signature.parameters, kwargs.values())]
465+
466+
439467
class _TargetConnector:
440468
"""
441469
The connector class passed to the engine.
442470
"""
443471

444472
_spec_cls: type
473+
_state_cls: type
445474
_connector_cls: type
446475

447476
_get_persistent_key_fn: Callable[[_TargetConnectorContext, str], Any]
@@ -451,8 +480,9 @@ class _TargetConnector:
451480
_mutate_async_fn: Callable[..., Awaitable[None]]
452481
_mutatation_type: AnalyzedDictType | None
453482

454-
def __init__(self, spec_cls: type, connector_cls: type):
483+
def __init__(self, spec_cls: type, state_cls: type, connector_cls: type):
455484
self._spec_cls = spec_cls
485+
self._state_cls = state_cls
456486
self._connector_cls = connector_cls
457487

458488
self._get_persistent_key_fn = _get_required_method(
@@ -517,8 +547,8 @@ def create_export_context(
517547
self,
518548
name: str,
519549
spec: dict[str, Any],
520-
key_fields_schema: list[Any],
521-
value_fields_schema: list[Any],
550+
raw_key_fields_schema: list[Any],
551+
raw_value_fields_schema: list[Any],
522552
) -> _TargetConnectorContext:
523553
key_annotation, value_annotation = (
524554
(
@@ -529,36 +559,72 @@ def create_export_context(
529559
else (Any, Any)
530560
)
531561

562+
key_fields_schema = decode_engine_field_schemas(raw_key_fields_schema)
532563
key_decoder = make_engine_key_decoder(
533-
["(key)"],
534-
decode_engine_field_schemas(key_fields_schema),
535-
analyze_type_info(key_annotation),
564+
["<key>"], key_fields_schema, analyze_type_info(key_annotation)
536565
)
566+
value_fields_schema = decode_engine_field_schemas(raw_value_fields_schema)
537567
value_decoder = make_engine_struct_decoder(
538-
["(value)"],
539-
decode_engine_field_schemas(value_fields_schema),
540-
analyze_type_info(value_annotation),
568+
["<value>"], value_fields_schema, analyze_type_info(value_annotation)
541569
)
542570

543571
loaded_spec = _load_spec_from_engine(self._spec_cls, spec)
544-
prepare_method = getattr(self._connector_cls, "prepare", None)
545-
if prepare_method is None:
546-
prepared_spec = loaded_spec
547-
else:
548-
prepared_spec = prepare_method(loaded_spec)
549-
550572
return _TargetConnectorContext(
551573
target_name=name,
552574
spec=loaded_spec,
553-
prepared_spec=prepared_spec,
575+
prepared_spec=None,
576+
key_fields_schema=key_fields_schema,
554577
key_decoder=key_decoder,
578+
value_fields_schema=value_fields_schema,
555579
value_decoder=value_decoder,
556580
)
557581

558582
def get_persistent_key(self, export_context: _TargetConnectorContext) -> Any:
559-
return self._get_persistent_key_fn(
560-
export_context.spec, export_context.target_name
583+
args = _build_args(
584+
self._get_persistent_key_fn,
585+
1,
586+
spec=export_context.spec,
587+
target_name=export_context.target_name,
588+
)
589+
return dump_engine_object(self._get_persistent_key_fn(*args))
590+
591+
def get_setup_state(self, export_context: _TargetConnectorContext) -> Any:
592+
get_persistent_state_fn = getattr(self._connector_cls, "get_setup_state", None)
593+
if get_persistent_state_fn is None:
594+
state = export_context.spec
595+
if not isinstance(state, self._state_cls):
596+
raise ValueError(
597+
f"Expect a get_setup_state() method for {self._connector_cls} that returns an instance of {self._state_cls}"
598+
)
599+
else:
600+
args = _build_args(
601+
get_persistent_state_fn,
602+
1,
603+
spec=export_context.spec,
604+
key_fields_schema=export_context.key_fields_schema,
605+
value_fields_schema=export_context.value_fields_schema,
606+
)
607+
state = get_persistent_state_fn(*args)
608+
if not isinstance(state, self._state_cls):
609+
raise ValueError(
610+
f"Method {get_persistent_state_fn.__name__} must return an instance of {self._state_cls}, got {type(state)}"
611+
)
612+
return dump_engine_object(state)
613+
614+
async def prepare_async(self, export_context: _TargetConnectorContext) -> None:
615+
prepare_fn = getattr(self._connector_cls, "prepare", None)
616+
if prepare_fn is None:
617+
export_context.prepared_spec = export_context.spec
618+
return
619+
args = _build_args(
620+
prepare_fn,
621+
1,
622+
spec=export_context.spec,
623+
key_fields_schema=export_context.key_fields_schema,
624+
value_fields_schema=export_context.value_fields_schema,
561625
)
626+
async_prepare_fn = to_async_call(prepare_fn)
627+
export_context.prepared_spec = await async_prepare_fn(*args)
562628

563629
def describe_resource(self, key: Any) -> str:
564630
describe_fn = getattr(self._connector_cls, "describe", None)
@@ -572,13 +638,13 @@ async def apply_setup_changes_async(
572638
) -> None:
573639
for key, previous, current in changes:
574640
prev_specs = [
575-
_load_spec_from_engine(self._spec_cls, spec)
641+
_load_spec_from_engine(self._state_cls, spec)
576642
if spec is not None
577643
else None
578644
for spec in previous
579645
]
580646
curr_spec = (
581-
_load_spec_from_engine(self._spec_cls, current)
647+
_load_spec_from_engine(self._state_cls, current)
582648
if current is not None
583649
else None
584650
)
@@ -611,7 +677,9 @@ async def mutate_async(
611677
)
612678

613679

614-
def target_connector(spec_cls: type) -> Callable[[type], type]:
680+
def target_connector(
681+
spec_cls: type, state_cls: type | None = None
682+
) -> Callable[[type], type]:
615683
"""
616684
Decorate a class to provide a target connector for an op.
617685
"""
@@ -622,7 +690,7 @@ def target_connector(spec_cls: type) -> Callable[[type], type]:
622690

623691
# Register the target connector.
624692
def _inner(connector_cls: type) -> type:
625-
connector = _TargetConnector(spec_cls, connector_cls)
693+
connector = _TargetConnector(spec_cls, state_cls or spec_cls, connector_cls)
626694
_engine.register_target_connector(spec_cls.__name__, connector)
627695
return connector_cls
628696

src/ops/py_factory.rs

Lines changed: 50 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -283,43 +283,65 @@ impl interface::TargetFactory for PyExportTargetFactory {
283283
.ok_or_else(|| anyhow!("Python execution context is missing"))?
284284
.clone();
285285
for data_collection in data_collections.into_iter() {
286-
let (py_export_ctx, persistent_key) =
287-
Python::with_gil(|py| -> Result<(Py<PyAny>, serde_json::Value)> {
288-
// Deserialize the spec to Python object.
289-
let py_export_ctx = self
290-
.py_target_connector
291-
.call_method(
292-
py,
293-
"create_export_context",
294-
(
295-
&data_collection.name,
296-
pythonize(py, &data_collection.spec)?,
297-
pythonize(py, &data_collection.key_fields_schema)?,
298-
pythonize(py, &data_collection.value_fields_schema)?,
299-
),
300-
None,
301-
)
302-
.to_result_with_py_trace(py)?;
303-
304-
// Call the `get_persistent_key` method to get the persistent key.
305-
let persistent_key = self
306-
.py_target_connector
307-
.call_method(py, "get_persistent_key", (&py_export_ctx,), None)
308-
.to_result_with_py_trace(py)?;
309-
let persistent_key = depythonize(&persistent_key.into_bound(py))?;
310-
Ok((py_export_ctx, persistent_key))
311-
})?;
286+
let (py_export_ctx, persistent_key, setup_state) = Python::with_gil(|py| {
287+
// Deserialize the spec to Python object.
288+
let py_export_ctx = self
289+
.py_target_connector
290+
.call_method(
291+
py,
292+
"create_export_context",
293+
(
294+
&data_collection.name,
295+
pythonize(py, &data_collection.spec)?,
296+
pythonize(py, &data_collection.key_fields_schema)?,
297+
pythonize(py, &data_collection.value_fields_schema)?,
298+
),
299+
None,
300+
)
301+
.to_result_with_py_trace(py)?;
302+
303+
// Call the `get_persistent_key` method to get the persistent key.
304+
let persistent_key = self
305+
.py_target_connector
306+
.call_method(py, "get_persistent_key", (&py_export_ctx,), None)
307+
.to_result_with_py_trace(py)?;
308+
let persistent_key: serde_json::Value =
309+
depythonize(&persistent_key.into_bound(py))?;
312310

311+
let setup_state = self
312+
.py_target_connector
313+
.call_method(py, "get_setup_state", (&py_export_ctx,), None)
314+
.to_result_with_py_trace(py)?;
315+
let setup_state: serde_json::Value = depythonize(&setup_state.into_bound(py))?;
316+
317+
anyhow::Ok((py_export_ctx, persistent_key, setup_state))
318+
})?;
319+
320+
let factory = self.clone();
313321
let py_exec_ctx = py_exec_ctx.clone();
314322
let build_output = interface::ExportDataCollectionBuildOutput {
315323
export_context: Box::pin(async move {
316-
Ok(Arc::new(PyTargetExecutorContext {
324+
Python::with_gil(|py| {
325+
let prepare_coro = factory
326+
.py_target_connector
327+
.call_method(py, "prepare_async", (&py_export_ctx,), None)
328+
.to_result_with_py_trace(py)?;
329+
let task_locals = pyo3_async_runtimes::TaskLocals::new(
330+
py_exec_ctx.event_loop.bind(py).clone(),
331+
);
332+
anyhow::Ok(pyo3_async_runtimes::into_future_with_locals(
333+
&task_locals,
334+
prepare_coro.into_bound(py),
335+
)?)
336+
})?
337+
.await?;
338+
anyhow::Ok(Arc::new(PyTargetExecutorContext {
317339
py_export_ctx,
318340
py_exec_ctx,
319341
}) as Arc<dyn Any + Send + Sync>)
320342
}),
321343
setup_key: persistent_key,
322-
desired_setup_state: data_collection.spec,
344+
desired_setup_state: setup_state,
323345
};
324346
build_outputs.push(build_output);
325347
}

0 commit comments

Comments
 (0)