22import json
33import time
44from dataclasses import dataclass , field
5- from typing import Any , AsyncGenerator , List , Optional
5+ from typing import Any , AsyncGenerator , AsyncIterable , AsyncIterator , Optional , cast
66
77import grpc
8- from grpc ._cython import cygrpc
8+ import grpc .aio
9+ from grpc ._cython import cygrpc # type: ignore[attr-defined]
910
1011from hatchet_sdk .clients .event_ts import Event_ts , read_with_interrupt
1112from hatchet_sdk .clients .run_event_listener import (
4041@dataclass
4142class GetActionListenerRequest :
4243 worker_name : str
43- services : List [str ]
44- actions : List [str ]
44+ services : list [str ]
45+ actions : list [str ]
4546 max_runs : Optional [int ] = None
4647 _labels : dict [str , str | int ] = field (default_factory = dict )
4748
4849 labels : dict [str , WorkerLabels ] = field (init = False )
4950
50- def __post_init__ (self ):
51+ def __post_init__ (self ) -> None :
5152 self .labels = {}
5253
5354 for key , value in self ._labels .items ():
@@ -69,16 +70,16 @@ class Action:
6970 step_id : str
7071 step_run_id : str
7172 action_id : str
72- action_payload : str
7373 action_type : ActionType
7474 retry_count : int
75+ action_payload : JSONSerializableDict = field (default_factory = dict )
7576 additional_metadata : JSONSerializableDict = field (default_factory = dict )
7677
7778 child_workflow_index : int | None = None
7879 child_workflow_key : str | None = None
7980 parent_workflow_run_id : str | None = None
8081
81- def __post_init__ (self ):
82+ def __post_init__ (self ) -> None :
8283 if isinstance (self .additional_metadata , str ) and self .additional_metadata != "" :
8384 try :
8485 self .additional_metadata = json .loads (self .additional_metadata )
@@ -114,11 +115,6 @@ def otel_attributes(self) -> dict[str, Any]:
114115 )
115116
116117
117- START_STEP_RUN = 0
118- CANCEL_STEP_RUN = 1
119- START_GET_GROUP_KEY = 2
120-
121-
122118@dataclass
123119class ActionListener :
124120 config : ClientConfig
@@ -131,22 +127,22 @@ class ActionListener:
131127 last_connection_attempt : float = field (default = 0 , init = False )
132128 last_heartbeat_succeeded : bool = field (default = True , init = False )
133129 time_last_hb_succeeded : float = field (default = 9999999999999 , init = False )
134- heartbeat_task : Optional [asyncio .Task ] = field (default = None , init = False )
130+ heartbeat_task : Optional [asyncio .Task [ None ] ] = field (default = None , init = False )
135131 run_heartbeat : bool = field (default = True , init = False )
136132 listen_strategy : str = field (default = "v2" , init = False )
137133 stop_signal : bool = field (default = False , init = False )
138134
139135 missed_heartbeats : int = field (default = 0 , init = False )
140136
141- def __post_init__ (self ):
142- self .client = DispatcherStub (new_conn (self .config , False ))
143- self .aio_client = DispatcherStub (new_conn (self .config , True ))
137+ def __post_init__ (self ) -> None :
138+ self .client = DispatcherStub (new_conn (self .config , False )) # type: ignore[no-untyped-call]
139+ self .aio_client = DispatcherStub (new_conn (self .config , True )) # type: ignore[no-untyped-call]
144140 self .token = self .config .token
145141
146- def is_healthy (self ):
142+ def is_healthy (self ) -> bool :
147143 return self .last_heartbeat_succeeded
148144
149- async def heartbeat (self ):
145+ async def heartbeat (self ) -> None :
150146 # send a heartbeat every 4 seconds
151147 heartbeat_delay = 4
152148
@@ -206,7 +202,7 @@ async def heartbeat(self):
206202 break
207203 await asyncio .sleep (heartbeat_delay )
208204
209- async def start_heartbeater (self ):
205+ async def start_heartbeater (self ) -> None :
210206 if self .heartbeat_task is not None :
211207 return
212208
@@ -220,10 +216,10 @@ async def start_heartbeater(self):
220216 raise e
221217 self .heartbeat_task = loop .create_task (self .heartbeat ())
222218
223- def __aiter__ (self ):
219+ def __aiter__ (self ) -> AsyncGenerator [ Action | None , None ] :
224220 return self ._generator ()
225221
226- async def _generator (self ) -> AsyncGenerator [Action , None ]:
222+ async def _generator (self ) -> AsyncGenerator [Action | None , None ]:
227223 listener = None
228224
229225 while not self .stop_signal :
@@ -239,6 +235,10 @@ async def _generator(self) -> AsyncGenerator[Action, None]:
239235 try :
240236 while not self .stop_signal :
241237 self .interrupt = Event_ts ()
238+
239+ if listener is None :
240+ continue
241+
242242 t = asyncio .create_task (
243243 read_with_interrupt (listener , self .interrupt )
244244 )
@@ -251,7 +251,10 @@ async def _generator(self) -> AsyncGenerator[Action, None]:
251251 )
252252
253253 t .cancel ()
254- listener .cancel ()
254+
255+ if listener :
256+ listener .cancel ()
257+
255258 break
256259
257260 assigned_action = t .result ()
@@ -261,20 +264,23 @@ async def _generator(self) -> AsyncGenerator[Action, None]:
261264 break
262265
263266 self .retries = 0
264- assigned_action : AssignedAction
265267
266268 # Process the received action
267- action_type = self . map_action_type ( assigned_action .actionType )
269+ action_type = assigned_action .actionType
268270
269- if (
270- assigned_action .actionPayload is None
271- or assigned_action .actionPayload == ""
272- ):
273- action_payload = None
274- else :
275- action_payload = self .parse_action_payload (
276- assigned_action .actionPayload
271+ action_payload = (
272+ {}
273+ if not assigned_action .actionPayload
274+ else self .parse_action_payload (assigned_action .actionPayload )
275+ )
276+
277+ try :
278+ additional_metadata = cast (
279+ dict [str , Any ],
280+ json .loads (assigned_action .additional_metadata ),
277281 )
282+ except json .JSONDecodeError :
283+ additional_metadata = {}
278284
279285 action = Action (
280286 tenant_id = assigned_action .tenantId ,
@@ -290,7 +296,7 @@ async def _generator(self) -> AsyncGenerator[Action, None]:
290296 action_payload = action_payload ,
291297 action_type = action_type ,
292298 retry_count = assigned_action .retryCount ,
293- additional_metadata = assigned_action . additional_metadata ,
299+ additional_metadata = additional_metadata ,
294300 child_workflow_index = assigned_action .child_workflow_index ,
295301 child_workflow_key = assigned_action .child_workflow_key ,
296302 parent_workflow_run_id = assigned_action .parent_workflow_run_id ,
@@ -324,25 +330,15 @@ async def _generator(self) -> AsyncGenerator[Action, None]:
324330
325331 self .retries = self .retries + 1
326332
327- def parse_action_payload (self , payload : str ):
333+ def parse_action_payload (self , payload : str ) -> JSONSerializableDict :
328334 try :
329- payload_data = json .loads (payload )
335+ return cast ( JSONSerializableDict , json .loads (payload ) )
330336 except json .JSONDecodeError as e :
331337 raise ValueError (f"Error decoding payload: { e } " )
332- return payload_data
333-
334- def map_action_type (self , action_type ):
335- if action_type == ActionType .START_STEP_RUN :
336- return START_STEP_RUN
337- elif action_type == ActionType .CANCEL_STEP_RUN :
338- return CANCEL_STEP_RUN
339- elif action_type == ActionType .START_GET_GROUP_KEY :
340- return START_GET_GROUP_KEY
341- else :
342- # logger.error(f"Unknown action type: {action_type}")
343- return None
344338
345- async def get_listen_client (self ):
339+ async def get_listen_client (
340+ self ,
341+ ) -> grpc .aio .UnaryStreamCall [WorkerListenRequest , AssignedAction ]:
346342 current_time = int (time .time ())
347343
348344 if (
@@ -370,7 +366,7 @@ async def get_listen_client(self):
370366 f"action listener connection interrupted, retrying... ({ self .retries } /{ DEFAULT_ACTION_LISTENER_RETRY_COUNT } )"
371367 )
372368
373- self .aio_client = DispatcherStub (new_conn (self .config , True ))
369+ self .aio_client = DispatcherStub (new_conn (self .config , True )) # type: ignore[no-untyped-call]
374370
375371 if self .listen_strategy == "v2" :
376372 # we should await for the listener to be established before
@@ -391,11 +387,14 @@ async def get_listen_client(self):
391387
392388 self .last_connection_attempt = current_time
393389
394- return listener
390+ return cast (
391+ grpc .aio .UnaryStreamCall [WorkerListenRequest , AssignedAction ], listener
392+ )
395393
396394 def cleanup (self ) -> None :
397395 self .run_heartbeat = False
398- self .heartbeat_task .cancel ()
396+ if self .heartbeat_task is not None :
397+ self .heartbeat_task .cancel ()
399398
400399 try :
401400 self .unregister ()
@@ -405,9 +404,11 @@ def cleanup(self) -> None:
405404 if self .interrupt :
406405 self .interrupt .set ()
407406
408- def unregister (self ):
407+ def unregister (self ) -> WorkerUnsubscribeRequest :
409408 self .run_heartbeat = False
410- self .heartbeat_task .cancel ()
409+
410+ if self .heartbeat_task is not None :
411+ self .heartbeat_task .cancel ()
411412
412413 try :
413414 req = self .aio_client .Unsubscribe (
@@ -417,6 +418,6 @@ def unregister(self):
417418 )
418419 if self .interrupt is not None :
419420 self .interrupt .set ()
420- return req
421+ return cast ( WorkerUnsubscribeRequest , req )
421422 except grpc .RpcError as e :
422423 raise Exception (f"Failed to unsubscribe: { e } " )
0 commit comments