Skip to content
This repository was archived by the owner on Feb 20, 2025. It is now read-only.

Commit f3cd051

Browse files
authored
Merge pull request #4 from hatchet-dev/feat-finally-fix-mypy
Feat: Finally fixing mypy all the way
2 parents e04a237 + e8722ca commit f3cd051

File tree

15 files changed

+275
-254
lines changed

15 files changed

+275
-254
lines changed

hatchet_sdk/clients/dispatcher/action_listener.py

Lines changed: 55 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
import json
33
import time
44
from dataclasses import dataclass, field
5-
from typing import Any, AsyncGenerator, List, Optional
5+
from typing import Any, AsyncGenerator, AsyncIterable, AsyncIterator, Optional, cast
66

77
import grpc
8-
from grpc._cython import cygrpc
8+
import grpc.aio
9+
from grpc._cython import cygrpc # type: ignore[attr-defined]
910

1011
from hatchet_sdk.clients.event_ts import Event_ts, read_with_interrupt
1112
from hatchet_sdk.clients.run_event_listener import (
@@ -40,14 +41,14 @@
4041
@dataclass
4142
class GetActionListenerRequest:
4243
worker_name: str
43-
services: List[str]
44-
actions: List[str]
44+
services: list[str]
45+
actions: list[str]
4546
max_runs: Optional[int] = None
4647
_labels: dict[str, str | int] = field(default_factory=dict)
4748

4849
labels: dict[str, WorkerLabels] = field(init=False)
4950

50-
def __post_init__(self):
51+
def __post_init__(self) -> None:
5152
self.labels = {}
5253

5354
for key, value in self._labels.items():
@@ -69,16 +70,16 @@ class Action:
6970
step_id: str
7071
step_run_id: str
7172
action_id: str
72-
action_payload: str
7373
action_type: ActionType
7474
retry_count: int
75+
action_payload: JSONSerializableDict = field(default_factory=dict)
7576
additional_metadata: JSONSerializableDict = field(default_factory=dict)
7677

7778
child_workflow_index: int | None = None
7879
child_workflow_key: str | None = None
7980
parent_workflow_run_id: str | None = None
8081

81-
def __post_init__(self):
82+
def __post_init__(self) -> None:
8283
if isinstance(self.additional_metadata, str) and self.additional_metadata != "":
8384
try:
8485
self.additional_metadata = json.loads(self.additional_metadata)
@@ -114,11 +115,6 @@ def otel_attributes(self) -> dict[str, Any]:
114115
)
115116

116117

117-
START_STEP_RUN = 0
118-
CANCEL_STEP_RUN = 1
119-
START_GET_GROUP_KEY = 2
120-
121-
122118
@dataclass
123119
class ActionListener:
124120
config: ClientConfig
@@ -131,22 +127,22 @@ class ActionListener:
131127
last_connection_attempt: float = field(default=0, init=False)
132128
last_heartbeat_succeeded: bool = field(default=True, init=False)
133129
time_last_hb_succeeded: float = field(default=9999999999999, init=False)
134-
heartbeat_task: Optional[asyncio.Task] = field(default=None, init=False)
130+
heartbeat_task: Optional[asyncio.Task[None]] = field(default=None, init=False)
135131
run_heartbeat: bool = field(default=True, init=False)
136132
listen_strategy: str = field(default="v2", init=False)
137133
stop_signal: bool = field(default=False, init=False)
138134

139135
missed_heartbeats: int = field(default=0, init=False)
140136

141-
def __post_init__(self):
142-
self.client = DispatcherStub(new_conn(self.config, False))
143-
self.aio_client = DispatcherStub(new_conn(self.config, True))
137+
def __post_init__(self) -> None:
138+
self.client = DispatcherStub(new_conn(self.config, False)) # type: ignore[no-untyped-call]
139+
self.aio_client = DispatcherStub(new_conn(self.config, True)) # type: ignore[no-untyped-call]
144140
self.token = self.config.token
145141

146-
def is_healthy(self):
142+
def is_healthy(self) -> bool:
147143
return self.last_heartbeat_succeeded
148144

149-
async def heartbeat(self):
145+
async def heartbeat(self) -> None:
150146
# send a heartbeat every 4 seconds
151147
heartbeat_delay = 4
152148

@@ -206,7 +202,7 @@ async def heartbeat(self):
206202
break
207203
await asyncio.sleep(heartbeat_delay)
208204

209-
async def start_heartbeater(self):
205+
async def start_heartbeater(self) -> None:
210206
if self.heartbeat_task is not None:
211207
return
212208

@@ -220,10 +216,10 @@ async def start_heartbeater(self):
220216
raise e
221217
self.heartbeat_task = loop.create_task(self.heartbeat())
222218

223-
def __aiter__(self):
219+
def __aiter__(self) -> AsyncGenerator[Action | None, None]:
224220
return self._generator()
225221

226-
async def _generator(self) -> AsyncGenerator[Action, None]:
222+
async def _generator(self) -> AsyncGenerator[Action | None, None]:
227223
listener = None
228224

229225
while not self.stop_signal:
@@ -239,6 +235,10 @@ async def _generator(self) -> AsyncGenerator[Action, None]:
239235
try:
240236
while not self.stop_signal:
241237
self.interrupt = Event_ts()
238+
239+
if listener is None:
240+
continue
241+
242242
t = asyncio.create_task(
243243
read_with_interrupt(listener, self.interrupt)
244244
)
@@ -251,7 +251,10 @@ async def _generator(self) -> AsyncGenerator[Action, None]:
251251
)
252252

253253
t.cancel()
254-
listener.cancel()
254+
255+
if listener:
256+
listener.cancel()
257+
255258
break
256259

257260
assigned_action = t.result()
@@ -261,20 +264,23 @@ async def _generator(self) -> AsyncGenerator[Action, None]:
261264
break
262265

263266
self.retries = 0
264-
assigned_action: AssignedAction
265267

266268
# Process the received action
267-
action_type = self.map_action_type(assigned_action.actionType)
269+
action_type = assigned_action.actionType
268270

269-
if (
270-
assigned_action.actionPayload is None
271-
or assigned_action.actionPayload == ""
272-
):
273-
action_payload = None
274-
else:
275-
action_payload = self.parse_action_payload(
276-
assigned_action.actionPayload
271+
action_payload = (
272+
{}
273+
if not assigned_action.actionPayload
274+
else self.parse_action_payload(assigned_action.actionPayload)
275+
)
276+
277+
try:
278+
additional_metadata = cast(
279+
dict[str, Any],
280+
json.loads(assigned_action.additional_metadata),
277281
)
282+
except json.JSONDecodeError:
283+
additional_metadata = {}
278284

279285
action = Action(
280286
tenant_id=assigned_action.tenantId,
@@ -290,7 +296,7 @@ async def _generator(self) -> AsyncGenerator[Action, None]:
290296
action_payload=action_payload,
291297
action_type=action_type,
292298
retry_count=assigned_action.retryCount,
293-
additional_metadata=assigned_action.additional_metadata,
299+
additional_metadata=additional_metadata,
294300
child_workflow_index=assigned_action.child_workflow_index,
295301
child_workflow_key=assigned_action.child_workflow_key,
296302
parent_workflow_run_id=assigned_action.parent_workflow_run_id,
@@ -324,25 +330,15 @@ async def _generator(self) -> AsyncGenerator[Action, None]:
324330

325331
self.retries = self.retries + 1
326332

327-
def parse_action_payload(self, payload: str):
333+
def parse_action_payload(self, payload: str) -> JSONSerializableDict:
328334
try:
329-
payload_data = json.loads(payload)
335+
return cast(JSONSerializableDict, json.loads(payload))
330336
except json.JSONDecodeError as e:
331337
raise ValueError(f"Error decoding payload: {e}")
332-
return payload_data
333-
334-
def map_action_type(self, action_type):
335-
if action_type == ActionType.START_STEP_RUN:
336-
return START_STEP_RUN
337-
elif action_type == ActionType.CANCEL_STEP_RUN:
338-
return CANCEL_STEP_RUN
339-
elif action_type == ActionType.START_GET_GROUP_KEY:
340-
return START_GET_GROUP_KEY
341-
else:
342-
# logger.error(f"Unknown action type: {action_type}")
343-
return None
344338

345-
async def get_listen_client(self):
339+
async def get_listen_client(
340+
self,
341+
) -> grpc.aio.UnaryStreamCall[WorkerListenRequest, AssignedAction]:
346342
current_time = int(time.time())
347343

348344
if (
@@ -370,7 +366,7 @@ async def get_listen_client(self):
370366
f"action listener connection interrupted, retrying... ({self.retries}/{DEFAULT_ACTION_LISTENER_RETRY_COUNT})"
371367
)
372368

373-
self.aio_client = DispatcherStub(new_conn(self.config, True))
369+
self.aio_client = DispatcherStub(new_conn(self.config, True)) # type: ignore[no-untyped-call]
374370

375371
if self.listen_strategy == "v2":
376372
# we should await for the listener to be established before
@@ -391,11 +387,14 @@ async def get_listen_client(self):
391387

392388
self.last_connection_attempt = current_time
393389

394-
return listener
390+
return cast(
391+
grpc.aio.UnaryStreamCall[WorkerListenRequest, AssignedAction], listener
392+
)
395393

396394
def cleanup(self) -> None:
397395
self.run_heartbeat = False
398-
self.heartbeat_task.cancel()
396+
if self.heartbeat_task is not None:
397+
self.heartbeat_task.cancel()
399398

400399
try:
401400
self.unregister()
@@ -405,9 +404,11 @@ def cleanup(self) -> None:
405404
if self.interrupt:
406405
self.interrupt.set()
407406

408-
def unregister(self):
407+
def unregister(self) -> WorkerUnsubscribeRequest:
409408
self.run_heartbeat = False
410-
self.heartbeat_task.cancel()
409+
410+
if self.heartbeat_task is not None:
411+
self.heartbeat_task.cancel()
411412

412413
try:
413414
req = self.aio_client.Unsubscribe(
@@ -417,6 +418,6 @@ def unregister(self):
417418
)
418419
if self.interrupt is not None:
419420
self.interrupt.set()
420-
return req
421+
return cast(WorkerUnsubscribeRequest, req)
421422
except grpc.RpcError as e:
422423
raise Exception(f"Failed to unsubscribe: {e}")

hatchet_sdk/clients/dispatcher/dispatcher.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Any, cast
22

3+
import grpc.aio
34
from google.protobuf.timestamp_pb2 import Timestamp
45

56
from hatchet_sdk.clients.dispatcher.action_listener import (
@@ -69,7 +70,7 @@ async def get_action_listener(
6970

7071
async def send_step_action_event(
7172
self, action: Action, event_type: StepActionEventType, payload: str
72-
) -> Any:
73+
) -> grpc.aio.UnaryUnaryCall[StepActionEvent, ActionEventResponse] | None:
7374
try:
7475
return await self._try_send_step_action_event(action, event_type, payload)
7576
except Exception as e:
@@ -84,12 +85,12 @@ async def send_step_action_event(
8485
"Failed to send finished event: " + str(e),
8586
)
8687

87-
return
88+
return None
8889

8990
@tenacity_retry
9091
async def _try_send_step_action_event(
9192
self, action: Action, event_type: StepActionEventType, payload: str
92-
) -> Any:
93+
) -> grpc.aio.UnaryUnaryCall[StepActionEvent, ActionEventResponse]:
9394
eventTimestamp = Timestamp()
9495
eventTimestamp.GetCurrentTime()
9596

@@ -105,15 +106,17 @@ async def _try_send_step_action_event(
105106
eventPayload=payload,
106107
)
107108

108-
## TODO: What does this return?
109-
return await self.aio_client.SendStepActionEvent(
110-
event,
111-
metadata=get_metadata(self.token),
109+
return cast(
110+
grpc.aio.UnaryUnaryCall[StepActionEvent, ActionEventResponse],
111+
await self.aio_client.SendStepActionEvent(
112+
event,
113+
metadata=get_metadata(self.token),
114+
),
112115
)
113116

114117
async def send_group_key_action_event(
115118
self, action: Action, event_type: GroupKeyActionEventType, payload: str
116-
) -> Any:
119+
) -> grpc.aio.UnaryUnaryCall[GroupKeyActionEvent, ActionEventResponse]:
117120
eventTimestamp = Timestamp()
118121
eventTimestamp.GetCurrentTime()
119122

@@ -128,9 +131,12 @@ async def send_group_key_action_event(
128131
)
129132

130133
## TODO: What does this return?
131-
return await self.aio_client.SendGroupKeyActionEvent(
132-
event,
133-
metadata=get_metadata(self.token),
134+
return cast(
135+
grpc.aio.UnaryUnaryCall[GroupKeyActionEvent, ActionEventResponse],
136+
await self.aio_client.SendGroupKeyActionEvent(
137+
event,
138+
metadata=get_metadata(self.token),
139+
),
134140
)
135141

136142
def put_overrides_data(self, data: OverridesData) -> ActionEventResponse:

hatchet_sdk/clients/event_ts.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import asyncio
2-
from typing import Any
2+
from typing import Any, TypeVar, cast
3+
4+
import grpc.aio
5+
from grpc._cython import cygrpc # type: ignore[attr-defined]
36

47

58
class Event_ts(asyncio.Event):
@@ -20,9 +23,19 @@ def clear(self) -> None:
2023
self._loop.call_soon_threadsafe(super().clear)
2124

2225

23-
async def read_with_interrupt(listener: Any, interrupt: Event_ts) -> Any:
26+
TRequest = TypeVar("TRequest")
27+
TResponse = TypeVar("TResponse")
28+
29+
30+
async def read_with_interrupt(
31+
listener: grpc.aio.UnaryStreamCall[TRequest, TResponse], interrupt: Event_ts
32+
) -> TResponse:
2433
try:
2534
result = await listener.read()
26-
return result
35+
36+
if result is cygrpc.EOF:
37+
raise ValueError("Unexpected EOF")
38+
39+
return cast(TResponse, result)
2740
finally:
2841
interrupt.set()

hatchet_sdk/clients/rest/api_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ async def __aenter__(self):
9797
async def __aexit__(self, exc_type, exc_value, traceback):
9898
await self.close()
9999

100-
async def close(self):
100+
async def close(self) -> None:
101101
await self.rest_client.close()
102102

103103
@property

hatchet_sdk/clients/rest/tenacity_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def tenacity_alert_retry(retry_state: tenacity.RetryCallState) -> None:
2727
)
2828

2929

30-
def tenacity_should_retry(ex: Exception) -> bool:
30+
def tenacity_should_retry(ex: BaseException) -> bool:
3131
if isinstance(ex, (grpc.aio.AioRpcError, grpc.RpcError)):
3232
if ex.code() in [
3333
grpc.StatusCode.UNIMPLEMENTED,

0 commit comments

Comments
 (0)