1111from typing import (
1212 Any ,
1313 Optional ,
14+ Protocol ,
1415 Union ,
16+ runtime_checkable ,
1517)
1618
1719from cached_property import (
4547)
4648
4749
48- class LocalSagaStepDecoratorWrapper (SagaStepDecoratorWrapper ):
50+ @runtime_checkable
51+ class LocalSagaStepDecoratorWrapper (SagaStepDecoratorWrapper , Protocol ):
4952 """TODO"""
5053
5154 meta : LocalSagaStepDecoratorMeta
@@ -112,10 +115,12 @@ def __call__(self, func: LocalCallback) -> LocalSagaStepDecoratorWrapper:
112115 # noinspection PyTypeChecker
113116 return func
114117
115- def on_execute (self , callback : LocalCallback , parameters : Optional [SagaContext ] = None , ** kwargs ) -> LocalSagaStep :
118+ def on_execute (
119+ self , operation : Union [SagaOperation , LocalCallback ], parameters : Optional [SagaContext ] = None , ** kwargs
120+ ) -> LocalSagaStep :
116121 """On execute method.
117122
118- :param callback : The callback function to be called.
123+ :param operation : The callback function to be called.
119124 :param parameters: A mapping of named parameters to be passed to the callback.
120125 :param kwargs: A set of named arguments to be passed to the callback. ``parameters`` has order if it is not
121126 ``None``.
@@ -124,14 +129,19 @@ def on_execute(self, callback: LocalCallback, parameters: Optional[SagaContext]
124129 if self .on_execute_operation is not None :
125130 raise MultipleOnExecuteException ()
126131
127- self .on_execute_operation = SagaOperation (callback , parameters , ** kwargs )
132+ if not isinstance (operation , SagaOperation ):
133+ operation = SagaOperation (operation , parameters , ** kwargs )
134+
135+ self .on_execute_operation = operation
128136
129137 return self
130138
131- def on_failure (self , callback : LocalCallback , parameters : Optional [SagaContext ] = None , ** kwargs ) -> LocalSagaStep :
139+ def on_failure (
140+ self , operation : Union [SagaOperation , LocalCallback ], parameters : Optional [SagaContext ] = None , ** kwargs
141+ ) -> LocalSagaStep :
132142 """On failure method.
133143
134- :param callback : The callback function to be called.
144+ :param operation : The callback function to be called.
135145 :param parameters: A mapping of named parameters to be passed to the callback.
136146 :param kwargs: A set of named arguments to be passed to the callback. ``parameters`` has order if it is not
137147 ``None``.
@@ -140,7 +150,10 @@ def on_failure(self, callback: LocalCallback, parameters: Optional[SagaContext]
140150 if self .on_failure_operation is not None :
141151 raise MultipleOnFailureException ()
142152
143- self .on_failure_operation = SagaOperation (callback , parameters , ** kwargs )
153+ if not isinstance (operation , SagaOperation ):
154+ operation = SagaOperation (operation , parameters , ** kwargs )
155+
156+ self .on_failure_operation = operation
144157
145158 return self
146159
0 commit comments