66import dataclasses
77import inspect
88from enum import Enum
9- from typing import Any , Awaitable , Callable , Protocol , dataclass_transform , Annotated
9+ from typing import (
10+ Any ,
11+ Awaitable ,
12+ Callable ,
13+ Protocol ,
14+ dataclass_transform ,
15+ Annotated ,
16+ get_args ,
17+ )
1018
1119from . import _engine # type: ignore
12- from .convert import encode_engine_value , make_engine_value_decoder
13- from .typing import TypeAttr , encode_enriched_type , resolve_forward_ref
20+ from .convert import (
21+ encode_engine_value ,
22+ make_engine_value_decoder ,
23+ make_engine_struct_decoder ,
24+ )
25+ from .typing import (
26+ TypeAttr ,
27+ encode_enriched_type ,
28+ resolve_forward_ref ,
29+ analyze_type_info ,
30+ AnalyzedAnyType ,
31+ AnalyzedDictType ,
32+ )
1433
1534
1635class OpCategory (Enum ):
@@ -65,6 +84,22 @@ class Executor(Protocol):
6584 op_category : OpCategory
6685
6786
87+ def _load_spec_from_engine (spec_cls : type , spec : dict [str , Any ]) -> Any :
88+ """
89+ Load a spec from the engine.
90+ """
91+ return spec_cls (** spec )
92+
93+
94+ def _get_required_method (cls : type , name : str ) -> Callable [..., Any ]:
95+ method = getattr (cls , name , None )
96+ if method is None :
97+ raise ValueError (f"Method { name } () is required for { cls .__name__ } " )
98+ if not inspect .isfunction (method ):
99+ raise ValueError (f"Method { cls .__name__ } .{ name } () is not a function" )
100+ return method
101+
102+
68103class _FunctionExecutorFactory :
69104 _spec_cls : type
70105 _executor_cls : type
@@ -76,7 +111,7 @@ def __init__(self, spec_cls: type, executor_cls: type):
76111 def __call__ (
77112 self , spec : dict [str , Any ], * args : Any , ** kwargs : Any
78113 ) -> tuple [dict [str , Any ], Executor ]:
79- spec = self ._spec_cls ( ** spec )
114+ spec = _load_spec_from_engine ( self ._spec_cls , spec )
80115 executor = self ._executor_cls (spec )
81116 result_type = executor .analyze (* args , ** kwargs )
82117 return (encode_enriched_type (result_type ), executor )
@@ -359,3 +394,201 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
359394 return _Spec ()
360395
361396 return _inner
397+
398+
399+ ########################################################
400+ # Customized target connector
401+ ########################################################
402+
403+
404+ @dataclasses .dataclass
405+ class _TargetConnectorContext :
406+ target_name : str
407+ spec : Any
408+ key_decoder : Callable [[Any ], Any ]
409+ value_decoder : Callable [[Any ], Any ]
410+
411+
412+ class _TargetConnector :
413+ """
414+ The connector class passed to the engine.
415+ """
416+
417+ _spec_cls : type
418+ _connector_cls : type
419+
420+ _get_persistent_key_fn : Callable [[_TargetConnectorContext , str ], Any ]
421+ _apply_setup_change_async_fn : Callable [
422+ [Any , dict [str , Any ] | None , dict [str , Any ] | None ], Awaitable [None ]
423+ ]
424+ _mutate_async_fn : Callable [..., Awaitable [None ]]
425+ _mutatation_type : AnalyzedDictType | None
426+
427+ def __init__ (self , spec_cls : type , connector_cls : type ):
428+ self ._spec_cls = spec_cls
429+ self ._connector_cls = connector_cls
430+
431+ self ._get_persistent_key_fn = _get_required_method (
432+ connector_cls , "get_persistent_key"
433+ )
434+ self ._apply_setup_change_async_fn = _to_async_call (
435+ _get_required_method (connector_cls , "apply_setup_change" )
436+ )
437+
438+ mutate_fn = _get_required_method (connector_cls , "mutate" )
439+ self ._mutate_async_fn = _to_async_call (mutate_fn )
440+
441+ # Store the type annotation for later use
442+ self ._mutatation_type = self ._analyze_mutate_mutation_type (
443+ connector_cls , mutate_fn
444+ )
445+
446+ @staticmethod
447+ def _analyze_mutate_mutation_type (
448+ connector_cls : type , mutate_fn : Callable [..., Any ]
449+ ) -> AnalyzedDictType | None :
450+ # Validate mutate_fn signature and extract type annotation
451+ mutate_sig = inspect .signature (mutate_fn )
452+ params = list (mutate_sig .parameters .values ())
453+
454+ if len (params ) != 1 :
455+ raise ValueError (
456+ f"Method { connector_cls .__name__ } .mutate(*args) must have exactly one parameter, "
457+ f"got { len (params )} "
458+ )
459+
460+ param = params [0 ]
461+ if param .kind != inspect .Parameter .VAR_POSITIONAL :
462+ raise ValueError (
463+ f"Method { connector_cls .__name__ } .mutate(*args) parameter must be *args format, "
464+ f"got { param .kind .name } "
465+ )
466+
467+ # Extract type annotation
468+ analyzed_args_type = analyze_type_info (param .annotation )
469+ if isinstance (analyzed_args_type .variant , AnalyzedAnyType ):
470+ return None
471+
472+ if analyzed_args_type .base_type is tuple :
473+ args = get_args (analyzed_args_type .core_type )
474+ if not args :
475+ return None
476+ if len (args ) == 2 :
477+ mutation_type = analyze_type_info (args [1 ])
478+ if isinstance (mutation_type .variant , AnalyzedAnyType ):
479+ return None
480+ if isinstance (mutation_type .variant , AnalyzedDictType ):
481+ return mutation_type .variant
482+
483+ raise ValueError (
484+ f"Method { connector_cls .__name__ } .mutate(*args) parameter must be a tuple with "
485+ f"2 elements (tuple[SpecType, dict[str, ValueStruct]], spec and mutation in dict), "
486+ "got {args_type}"
487+ )
488+
489+ def create_export_context (
490+ self ,
491+ name : str ,
492+ spec : dict [str , Any ],
493+ key_fields_schema : list [Any ],
494+ value_fields_schema : list [Any ],
495+ ) -> _TargetConnectorContext :
496+ key_annotation , value_annotation = (
497+ (
498+ self ._mutatation_type .key_type ,
499+ self ._mutatation_type .value_type ,
500+ )
501+ if self ._mutatation_type is not None
502+ else (None , None )
503+ )
504+
505+ if len (key_fields_schema ) == 1 :
506+ key_decoder = make_engine_value_decoder (
507+ ["(key)" ], key_fields_schema [0 ]["type" ], key_annotation
508+ )
509+ else :
510+ key_decoder = make_engine_struct_decoder (
511+ ["(key)" ], key_fields_schema , analyze_type_info (key_annotation )
512+ )
513+
514+ value_decoder = make_engine_struct_decoder (
515+ ["(value)" ], value_fields_schema , analyze_type_info (value_annotation )
516+ )
517+
518+ return _TargetConnectorContext (
519+ target_name = name ,
520+ spec = _load_spec_from_engine (self ._spec_cls , spec ),
521+ key_decoder = key_decoder ,
522+ value_decoder = value_decoder ,
523+ )
524+
525+ def get_persistent_key (self , export_context : _TargetConnectorContext ) -> Any :
526+ return self ._get_persistent_key_fn (
527+ export_context .spec , export_context .target_name
528+ )
529+
530+ def describe_resource (self , key : Any ) -> str :
531+ describe_fn = getattr (self ._connector_cls , "describe" , None )
532+ if describe_fn is None :
533+ return str (key )
534+ return str (describe_fn (key ))
535+
536+ async def apply_setup_changes_async (
537+ self ,
538+ changes : list [tuple [Any , list [dict [str , Any ] | None ], dict [str , Any ] | None ]],
539+ ) -> None :
540+ for key , previous , current in changes :
541+ prev_specs = [
542+ _load_spec_from_engine (self ._spec_cls , spec )
543+ if spec is not None
544+ else None
545+ for spec in previous
546+ ]
547+ curr_spec = (
548+ _load_spec_from_engine (self ._spec_cls , current )
549+ if current is not None
550+ else None
551+ )
552+ for prev_spec in prev_specs :
553+ await self ._apply_setup_change_async_fn (key , prev_spec , curr_spec )
554+
555+ @staticmethod
556+ def _decode_mutation (
557+ context : _TargetConnectorContext , mutation : list [tuple [Any , Any | None ]]
558+ ) -> tuple [Any , dict [Any , Any | None ]]:
559+ return (
560+ context .spec ,
561+ {
562+ context .key_decoder (key ): context .value_decoder (value )
563+ for key , value in mutation
564+ },
565+ )
566+
567+ async def mutate_async (
568+ self ,
569+ mutations : list [tuple [_TargetConnectorContext , list [tuple [Any , Any | None ]]]],
570+ ) -> None :
571+ await self ._mutate_async_fn (
572+ * (
573+ self ._decode_mutation (context , mutation )
574+ for context , mutation in mutations
575+ )
576+ )
577+
578+
579+ def target_connector (spec_cls : type ) -> Callable [[type ], type ]:
580+ """
581+ Decorate a class to provide a target connector for an op.
582+ """
583+
584+ # Validate the spec_cls is a TargetSpec.
585+ if not issubclass (spec_cls , TargetSpec ):
586+ raise ValueError (f"Expect a TargetSpec, got { spec_cls } " )
587+
588+ # Register the target connector.
589+ def _inner (connector_cls : type ) -> type :
590+ connector = _TargetConnector (spec_cls , connector_cls )
591+ _engine .register_target_connector (spec_cls .__name__ , connector )
592+ return connector_cls
593+
594+ return _inner
0 commit comments