22import json
33import time
44from dataclasses import dataclass , field
5- from typing import Any , AsyncGenerator , List , Optional
5+ from typing import Any , AsyncGenerator , AsyncIterable , AsyncIterator , Optional , cast
66
77import grpc
8- from grpc ._cython import cygrpc
8+ import grpc .aio
9+ from grpc ._cython import cygrpc # type: ignore[attr-defined]
910
1011from hatchet_sdk .clients .event_ts import Event_ts , read_with_interrupt
1112from hatchet_sdk .clients .run_event_listener import (
4041@dataclass
4142class GetActionListenerRequest :
4243 worker_name : str
43- services : List [str ]
44- actions : List [str ]
44+ services : list [str ]
45+ actions : list [str ]
4546 max_runs : Optional [int ] = None
4647 _labels : dict [str , str | int ] = field (default_factory = dict )
4748
4849 labels : dict [str , WorkerLabels ] = field (init = False )
4950
50- def __post_init__ (self ):
51+ def __post_init__ (self ) -> None :
5152 self .labels = {}
5253
5354 for key , value in self ._labels .items ():
@@ -78,7 +79,7 @@ class Action:
7879 child_workflow_key : str | None = None
7980 parent_workflow_run_id : str | None = None
8081
81- def __post_init__ (self ):
82+ def __post_init__ (self ) -> None :
8283 if isinstance (self .additional_metadata , str ) and self .additional_metadata != "" :
8384 try :
8485 self .additional_metadata = json .loads (self .additional_metadata )
@@ -114,11 +115,6 @@ def otel_attributes(self) -> dict[str, Any]:
114115 )
115116
116117
117- START_STEP_RUN = 0
118- CANCEL_STEP_RUN = 1
119- START_GET_GROUP_KEY = 2
120-
121-
122118@dataclass
123119class ActionListener :
124120 config : ClientConfig
@@ -131,22 +127,22 @@ class ActionListener:
131127 last_connection_attempt : float = field (default = 0 , init = False )
132128 last_heartbeat_succeeded : bool = field (default = True , init = False )
133129 time_last_hb_succeeded : float = field (default = 9999999999999 , init = False )
134- heartbeat_task : Optional [asyncio .Task ] = field (default = None , init = False )
130+ heartbeat_task : Optional [asyncio .Task [ None ] ] = field (default = None , init = False )
135131 run_heartbeat : bool = field (default = True , init = False )
136132 listen_strategy : str = field (default = "v2" , init = False )
137133 stop_signal : bool = field (default = False , init = False )
138134
139135 missed_heartbeats : int = field (default = 0 , init = False )
140136
141- def __post_init__ (self ):
142- self .client = DispatcherStub (new_conn (self .config , False ))
143- self .aio_client = DispatcherStub (new_conn (self .config , True ))
137+ def __post_init__ (self ) -> None :
138+ self .client = DispatcherStub (new_conn (self .config , False )) # type: ignore[no-untyped-call]
139+ self .aio_client = DispatcherStub (new_conn (self .config , True )) # type: ignore[no-untyped-call]
144140 self .token = self .config .token
145141
146- def is_healthy (self ):
142+ def is_healthy (self ) -> bool :
147143 return self .last_heartbeat_succeeded
148144
149- async def heartbeat (self ):
145+ async def heartbeat (self ) -> None :
150146 # send a heartbeat every 4 seconds
151147 heartbeat_delay = 4
152148
@@ -206,7 +202,7 @@ async def heartbeat(self):
206202 break
207203 await asyncio .sleep (heartbeat_delay )
208204
209- async def start_heartbeater (self ):
205+ async def start_heartbeater (self ) -> None :
210206 if self .heartbeat_task is not None :
211207 return
212208
@@ -220,10 +216,10 @@ async def start_heartbeater(self):
220216 raise e
221217 self .heartbeat_task = loop .create_task (self .heartbeat ())
222218
223- def __aiter__ (self ):
219+ def __aiter__ (self ) -> AsyncGenerator [ Action | None , None ] :
224220 return self ._generator ()
225221
226- async def _generator (self ) -> AsyncGenerator [Action , None ]:
222+ async def _generator (self ) -> AsyncGenerator [Action | None , None ]:
227223 listener = None
228224
229225 while not self .stop_signal :
@@ -239,6 +235,10 @@ async def _generator(self) -> AsyncGenerator[Action, None]:
239235 try :
240236 while not self .stop_signal :
241237 self .interrupt = Event_ts ()
238+
239+ if listener is None :
240+ continue
241+
242242 t = asyncio .create_task (
243243 read_with_interrupt (listener , self .interrupt )
244244 )
@@ -251,7 +251,10 @@ async def _generator(self) -> AsyncGenerator[Action, None]:
251251 )
252252
253253 t .cancel ()
254- listener .cancel ()
254+
255+ if listener :
256+ listener .cancel ()
257+
255258 break
256259
257260 assigned_action = t .result ()
@@ -261,10 +264,9 @@ async def _generator(self) -> AsyncGenerator[Action, None]:
261264 break
262265
263266 self .retries = 0
264- assigned_action : AssignedAction
265267
266268 # Process the received action
267- action_type = self . map_action_type ( assigned_action .actionType )
269+ action_type = assigned_action .actionType
268270
269271 if (
270272 assigned_action .actionPayload is None
@@ -287,7 +289,8 @@ async def _generator(self) -> AsyncGenerator[Action, None]:
287289 step_id = assigned_action .stepId ,
288290 step_run_id = assigned_action .stepRunId ,
289291 action_id = assigned_action .actionId ,
290- action_payload = action_payload ,
292+ ## TODO: Figure out this type - maybe needs to be dumped to JSON?
293+ action_payload = action_payload , # type: ignore[arg-type]
291294 action_type = action_type ,
292295 retry_count = assigned_action .retryCount ,
293296 additional_metadata = assigned_action .additional_metadata ,
@@ -324,25 +327,15 @@ async def _generator(self) -> AsyncGenerator[Action, None]:
324327
325328 self .retries = self .retries + 1
326329
327- def parse_action_payload (self , payload : str ):
330+ def parse_action_payload (self , payload : str ) -> JSONSerializableDict :
328331 try :
329- payload_data = json .loads (payload )
332+ return cast ( JSONSerializableDict , json .loads (payload ) )
330333 except json .JSONDecodeError as e :
331334 raise ValueError (f"Error decoding payload: { e } " )
332- return payload_data
333-
334- def map_action_type (self , action_type ):
335- if action_type == ActionType .START_STEP_RUN :
336- return START_STEP_RUN
337- elif action_type == ActionType .CANCEL_STEP_RUN :
338- return CANCEL_STEP_RUN
339- elif action_type == ActionType .START_GET_GROUP_KEY :
340- return START_GET_GROUP_KEY
341- else :
342- # logger.error(f"Unknown action type: {action_type}")
343- return None
344335
345- async def get_listen_client (self ):
336+ async def get_listen_client (
337+ self ,
338+ ) -> grpc .aio .UnaryStreamCall [WorkerListenRequest , AssignedAction ]:
346339 current_time = int (time .time ())
347340
348341 if (
@@ -370,7 +363,8 @@ async def get_listen_client(self):
370363 f"action listener connection interrupted, retrying... ({ self .retries } /{ DEFAULT_ACTION_LISTENER_RETRY_COUNT } )"
371364 )
372365
373- self .aio_client = DispatcherStub (new_conn (self .config , True ))
366+ ## TODO: Figure out how to get type support for these
367+ self .aio_client = DispatcherStub (new_conn (self .config , True )) # type: ignore[no-untyped-call]
374368
375369 if self .listen_strategy == "v2" :
376370 # we should await for the listener to be established before
@@ -391,11 +385,14 @@ async def get_listen_client(self):
391385
392386 self .last_connection_attempt = current_time
393387
394- return listener
388+ return cast (
389+ grpc .aio .UnaryStreamCall [WorkerListenRequest , AssignedAction ], listener
390+ )
395391
396392 def cleanup (self ) -> None :
397393 self .run_heartbeat = False
398- self .heartbeat_task .cancel ()
394+ if self .heartbeat_task is not None :
395+ self .heartbeat_task .cancel ()
399396
400397 try :
401398 self .unregister ()
@@ -405,9 +402,11 @@ def cleanup(self) -> None:
405402 if self .interrupt :
406403 self .interrupt .set ()
407404
408- def unregister (self ):
405+ def unregister (self ) -> WorkerUnsubscribeRequest :
409406 self .run_heartbeat = False
410- self .heartbeat_task .cancel ()
407+
408+ if self .heartbeat_task is not None :
409+ self .heartbeat_task .cancel ()
411410
412411 try :
413412 req = self .aio_client .Unsubscribe (
@@ -417,6 +416,6 @@ def unregister(self):
417416 )
418417 if self .interrupt is not None :
419418 self .interrupt .set ()
420- return req
419+ return cast ( WorkerUnsubscribeRequest , req )
421420 except grpc .RpcError as e :
422421 raise Exception (f"Failed to unsubscribe: { e } " )
0 commit comments