2626)
2727from ..util .cli .arg import Arg
2828from ..util .cli .cmd import CMD
29+ from ..util .asynchelper import context_stacker , aenter_stack
2930from ..util .entrypoint import base_entry_point
3031
3132
@@ -101,7 +102,7 @@ def add_label(cls, *above):
101102 return list (above ) + cls .op .name .split ("_" )
102103
103104
104- def op (** kwargs ):
105+ def op (imp_enter = None , ctx_enter = None , ** kwargs ):
105106 def wrap (func ):
106107 if not "name" in kwargs :
107108 kwargs ["name" ] = func .__name__
@@ -115,7 +116,9 @@ def wrap(func):
115116 func , OperationImplementationContext
116117 ):
117118
118- class Implementation (OperationImplementation ):
119+ class Implementation (
120+ context_stacker (OperationImplementation , imp_enter )
121+ ):
119122
120123 op = func .op
121124 CONTEXT = func
@@ -124,14 +127,25 @@ class Implementation(OperationImplementation):
124127 return func
125128 else :
126129
127- class ImplementationContext (OperationImplementationContext ):
130+ class ImplementationContext (
131+ context_stacker (OperationImplementationContext , ctx_enter )
132+ ):
128133 async def run (
129134 self , inputs : Dict [str , Any ]
130135 ) -> Union [bool , Dict [str , Any ]]:
131136 # TODO Add auto thread pooling of non-async functions
137+ # If imp_enter or ctx_enter exist then bind the function to
138+ # the ImplementationContext so that it has access to the
139+ # context and it's parent
140+ if imp_enter is not None or ctx_enter is not None :
141+ return await (
142+ func .__get__ (self , self .__class__ )(** inputs )
143+ )
132144 return await func (** inputs )
133145
134- class Implementation (OperationImplementation ):
146+ class Implementation (
147+ context_stacker (OperationImplementation , imp_enter )
148+ ):
135149
136150 op = func .op
137151 CONTEXT = ImplementationContext
@@ -609,37 +623,6 @@ class BaseOrchestratorConfig(BaseConfig, NamedTuple):
609623
610624
611625class BaseOrchestratorContext (BaseDataFlowObjectContext ):
612- def __init__ (self , parent : "BaseOrchestrator" ) -> None :
613- super ().__init__ (parent )
614- self .__stack = None
615-
616- async def __aenter__ (self ) -> "BaseOrchestratorContext" :
617- """
618- Ahoy matey, enter if ye dare into the management of the contexts. Eh
619- well not sure if there's really any context being managed here...
620- """
621- self .__stack = AsyncExitStack ()
622- await self .__stack .__aenter__ ()
623- self .rctx = await self .__stack .enter_async_context (
624- self .parent .rchecker ()
625- )
626- self .ictx = await self .__stack .enter_async_context (
627- self .parent .input_network ()
628- )
629- self .octx = await self .__stack .enter_async_context (
630- self .parent .operation_network ()
631- )
632- self .lctx = await self .__stack .enter_async_context (
633- self .parent .lock_network ()
634- )
635- self .nctx = await self .__stack .enter_async_context (
636- self .parent .opimp_network ()
637- )
638- return self
639-
640- async def __aexit__ (self , exc_type , exc_value , traceback ):
641- await self .__stack .aclose ()
642-
643626 @abc .abstractmethod
644627 async def run_operations (
645628 self , strict : bool = False
@@ -651,29 +634,4 @@ async def run_operations(
651634
652635@base_entry_point ("dffml.orchestrator" , "dff" )
653636class BaseOrchestrator (BaseDataFlowObject ):
654- def __init__ (self , config : "BaseConfig" ) -> None :
655- super ().__init__ (config )
656- self .__stack = None
657-
658- async def __aenter__ (self ) -> "DataFlowFacilitator" :
659- self .__stack = AsyncExitStack ()
660- await self .__stack .__aenter__ ()
661- self .rchecker = await self .__stack .enter_async_context (
662- self .config .rchecker
663- )
664- self .input_network = await self .__stack .enter_async_context (
665- self .config .input_network
666- )
667- self .operation_network = await self .__stack .enter_async_context (
668- self .config .operation_network
669- )
670- self .lock_network = await self .__stack .enter_async_context (
671- self .config .lock_network
672- )
673- self .opimp_network = await self .__stack .enter_async_context (
674- self .config .opimp_network
675- )
676- return self
677-
678- async def __aexit__ (self , exc_type , exc_value , traceback ):
679- await self .__stack .aclose ()
637+ pass # pragma: no cov
0 commit comments