1818from . import _engine # type: ignore
1919from .subprocess_exec import executor_stub
2020from .convert import (
21+ dump_engine_object ,
2122 make_engine_value_encoder ,
2223 make_engine_value_decoder ,
2324 make_engine_key_decoder ,
3233 AnalyzedDictType ,
3334 EnrichedValueType ,
3435 decode_engine_field_schemas ,
36+ FieldSchema ,
3537)
3638from .runtime import to_async_call
3739
@@ -432,16 +434,43 @@ class _TargetConnectorContext:
432434 target_name : str
433435 spec : Any
434436 prepared_spec : Any
437+ key_fields_schema : list [FieldSchema ]
435438 key_decoder : Callable [[Any ], Any ]
439+ value_fields_schema : list [FieldSchema ]
436440 value_decoder : Callable [[Any ], Any ]
437441
438442
443+ def _build_args (
444+ method : Callable [..., Any ], num_required_args : int , ** kwargs : Any
445+ ) -> list [Any ]:
446+ signature = inspect .signature (method )
447+ for param in signature .parameters .values ():
448+ if param .kind not in (
449+ inspect .Parameter .POSITIONAL_ONLY ,
450+ inspect .Parameter .POSITIONAL_OR_KEYWORD ,
451+ ):
452+ raise ValueError (
453+ f"Method { method .__name__ } should only have positional arguments, got { param .kind .name } "
454+ )
455+ if len (signature .parameters ) < num_required_args :
456+ raise ValueError (
457+ f"Method { method .__name__ } must have at least { num_required_args } required arguments: "
458+ f"{ ', ' .join (list (kwargs .keys ())[:num_required_args ])} "
459+ )
460+ if len (kwargs ) > len (kwargs ):
461+ raise ValueError (
462+ f"Method { method .__name__ } can only have at most { num_required_args } arguments: { ', ' .join (kwargs .keys ())} "
463+ )
464+ return [v for _ , v in zip (signature .parameters , kwargs .values ())]
465+
466+
439467class _TargetConnector :
440468 """
441469 The connector class passed to the engine.
442470 """
443471
444472 _spec_cls : type
473+ _state_cls : type
445474 _connector_cls : type
446475
447476 _get_persistent_key_fn : Callable [[_TargetConnectorContext , str ], Any ]
@@ -451,8 +480,9 @@ class _TargetConnector:
451480 _mutate_async_fn : Callable [..., Awaitable [None ]]
452481 _mutatation_type : AnalyzedDictType | None
453482
454- def __init__ (self , spec_cls : type , connector_cls : type ):
483+ def __init__ (self , spec_cls : type , state_cls : type , connector_cls : type ):
455484 self ._spec_cls = spec_cls
485+ self ._state_cls = state_cls
456486 self ._connector_cls = connector_cls
457487
458488 self ._get_persistent_key_fn = _get_required_method (
@@ -517,8 +547,8 @@ def create_export_context(
517547 self ,
518548 name : str ,
519549 spec : dict [str , Any ],
520- key_fields_schema : list [Any ],
521- value_fields_schema : list [Any ],
550+ raw_key_fields_schema : list [Any ],
551+ raw_value_fields_schema : list [Any ],
522552 ) -> _TargetConnectorContext :
523553 key_annotation , value_annotation = (
524554 (
@@ -529,36 +559,72 @@ def create_export_context(
529559 else (Any , Any )
530560 )
531561
562+ key_fields_schema = decode_engine_field_schemas (raw_key_fields_schema )
532563 key_decoder = make_engine_key_decoder (
533- ["(key)" ],
534- decode_engine_field_schemas (key_fields_schema ),
535- analyze_type_info (key_annotation ),
564+ ["<key>" ], key_fields_schema , analyze_type_info (key_annotation )
536565 )
566+ value_fields_schema = decode_engine_field_schemas (raw_value_fields_schema )
537567 value_decoder = make_engine_struct_decoder (
538- ["(value)" ],
539- decode_engine_field_schemas (value_fields_schema ),
540- analyze_type_info (value_annotation ),
568+ ["<value>" ], value_fields_schema , analyze_type_info (value_annotation )
541569 )
542570
543571 loaded_spec = _load_spec_from_engine (self ._spec_cls , spec )
544- prepare_method = getattr (self ._connector_cls , "prepare" , None )
545- if prepare_method is None :
546- prepared_spec = loaded_spec
547- else :
548- prepared_spec = prepare_method (loaded_spec )
549-
550572 return _TargetConnectorContext (
551573 target_name = name ,
552574 spec = loaded_spec ,
553- prepared_spec = prepared_spec ,
575+ prepared_spec = None ,
576+ key_fields_schema = key_fields_schema ,
554577 key_decoder = key_decoder ,
578+ value_fields_schema = value_fields_schema ,
555579 value_decoder = value_decoder ,
556580 )
557581
558582 def get_persistent_key (self , export_context : _TargetConnectorContext ) -> Any :
559- return self ._get_persistent_key_fn (
560- export_context .spec , export_context .target_name
583+ args = _build_args (
584+ self ._get_persistent_key_fn ,
585+ 1 ,
586+ spec = export_context .spec ,
587+ target_name = export_context .target_name ,
588+ )
589+ return dump_engine_object (self ._get_persistent_key_fn (* args ))
590+
591+ def get_setup_state (self , export_context : _TargetConnectorContext ) -> Any :
592+ get_persistent_state_fn = getattr (self ._connector_cls , "get_setup_state" , None )
593+ if get_persistent_state_fn is None :
594+ state = export_context .spec
595+ if not isinstance (state , self ._state_cls ):
596+ raise ValueError (
597+ f"Expect a get_setup_state() method for { self ._connector_cls } that returns an instance of { self ._state_cls } "
598+ )
599+ else :
600+ args = _build_args (
601+ get_persistent_state_fn ,
602+ 1 ,
603+ spec = export_context .spec ,
604+ key_fields_schema = export_context .key_fields_schema ,
605+ value_fields_schema = export_context .value_fields_schema ,
606+ )
607+ state = get_persistent_state_fn (* args )
608+ if not isinstance (state , self ._state_cls ):
609+ raise ValueError (
610+ f"Method { get_persistent_state_fn .__name__ } must return an instance of { self ._state_cls } , got { type (state )} "
611+ )
612+ return dump_engine_object (state )
613+
614+ async def prepare_async (self , export_context : _TargetConnectorContext ) -> None :
615+ prepare_fn = getattr (self ._connector_cls , "prepare" , None )
616+ if prepare_fn is None :
617+ export_context .prepared_spec = export_context .spec
618+ return
619+ args = _build_args (
620+ prepare_fn ,
621+ 1 ,
622+ spec = export_context .spec ,
623+ key_fields_schema = export_context .key_fields_schema ,
624+ value_fields_schema = export_context .value_fields_schema ,
561625 )
626+ async_prepare_fn = to_async_call (prepare_fn )
627+ export_context .prepared_spec = await async_prepare_fn (* args )
562628
563629 def describe_resource (self , key : Any ) -> str :
564630 describe_fn = getattr (self ._connector_cls , "describe" , None )
@@ -572,13 +638,13 @@ async def apply_setup_changes_async(
572638 ) -> None :
573639 for key , previous , current in changes :
574640 prev_specs = [
575- _load_spec_from_engine (self ._spec_cls , spec )
641+ _load_spec_from_engine (self ._state_cls , spec )
576642 if spec is not None
577643 else None
578644 for spec in previous
579645 ]
580646 curr_spec = (
581- _load_spec_from_engine (self ._spec_cls , current )
647+ _load_spec_from_engine (self ._state_cls , current )
582648 if current is not None
583649 else None
584650 )
@@ -611,7 +677,9 @@ async def mutate_async(
611677 )
612678
613679
614- def target_connector (spec_cls : type ) -> Callable [[type ], type ]:
680+ def target_connector (
681+ spec_cls : type , state_cls : type | None = None
682+ ) -> Callable [[type ], type ]:
615683 """
616684 Decorate a class to provide a target connector for an op.
617685 """
@@ -622,7 +690,7 @@ def target_connector(spec_cls: type) -> Callable[[type], type]:
622690
623691 # Register the target connector.
624692 def _inner (connector_cls : type ) -> type :
625- connector = _TargetConnector (spec_cls , connector_cls )
693+ connector = _TargetConnector (spec_cls , state_cls or spec_cls , connector_cls )
626694 _engine .register_target_connector (spec_cls .__name__ , connector )
627695 return connector_cls
628696
0 commit comments