Skip to content
This repository was archived by the owner on Feb 20, 2025. It is now read-only.

Commit 62948af

Browse files
committed
fix: rest of the mypy errors
1 parent 5145119 commit 62948af

File tree

8 files changed

+192
-191
lines changed

8 files changed

+192
-191
lines changed

hatchet_sdk/clients/dispatcher/action_listener.py

Lines changed: 44 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
import json
33
import time
44
from dataclasses import dataclass, field
5-
from typing import Any, AsyncGenerator, List, Optional
5+
from typing import Any, AsyncGenerator, AsyncIterable, AsyncIterator, Optional, cast
66

77
import grpc
8-
from grpc._cython import cygrpc
8+
import grpc.aio
9+
from grpc._cython import cygrpc # type: ignore[attr-defined]
910

1011
from hatchet_sdk.clients.event_ts import Event_ts, read_with_interrupt
1112
from hatchet_sdk.clients.run_event_listener import (
@@ -40,14 +41,14 @@
4041
@dataclass
4142
class GetActionListenerRequest:
4243
worker_name: str
43-
services: List[str]
44-
actions: List[str]
44+
services: list[str]
45+
actions: list[str]
4546
max_runs: Optional[int] = None
4647
_labels: dict[str, str | int] = field(default_factory=dict)
4748

4849
labels: dict[str, WorkerLabels] = field(init=False)
4950

50-
def __post_init__(self):
51+
def __post_init__(self) -> None:
5152
self.labels = {}
5253

5354
for key, value in self._labels.items():
@@ -78,7 +79,7 @@ class Action:
7879
child_workflow_key: str | None = None
7980
parent_workflow_run_id: str | None = None
8081

81-
def __post_init__(self):
82+
def __post_init__(self) -> None:
8283
if isinstance(self.additional_metadata, str) and self.additional_metadata != "":
8384
try:
8485
self.additional_metadata = json.loads(self.additional_metadata)
@@ -114,11 +115,6 @@ def otel_attributes(self) -> dict[str, Any]:
114115
)
115116

116117

117-
START_STEP_RUN = 0
118-
CANCEL_STEP_RUN = 1
119-
START_GET_GROUP_KEY = 2
120-
121-
122118
@dataclass
123119
class ActionListener:
124120
config: ClientConfig
@@ -131,22 +127,22 @@ class ActionListener:
131127
last_connection_attempt: float = field(default=0, init=False)
132128
last_heartbeat_succeeded: bool = field(default=True, init=False)
133129
time_last_hb_succeeded: float = field(default=9999999999999, init=False)
134-
heartbeat_task: Optional[asyncio.Task] = field(default=None, init=False)
130+
heartbeat_task: Optional[asyncio.Task[None]] = field(default=None, init=False)
135131
run_heartbeat: bool = field(default=True, init=False)
136132
listen_strategy: str = field(default="v2", init=False)
137133
stop_signal: bool = field(default=False, init=False)
138134

139135
missed_heartbeats: int = field(default=0, init=False)
140136

141-
def __post_init__(self):
142-
self.client = DispatcherStub(new_conn(self.config, False))
143-
self.aio_client = DispatcherStub(new_conn(self.config, True))
137+
def __post_init__(self) -> None:
138+
self.client = DispatcherStub(new_conn(self.config, False)) # type: ignore[no-untyped-call]
139+
self.aio_client = DispatcherStub(new_conn(self.config, True)) # type: ignore[no-untyped-call]
144140
self.token = self.config.token
145141

146-
def is_healthy(self):
142+
def is_healthy(self) -> bool:
147143
return self.last_heartbeat_succeeded
148144

149-
async def heartbeat(self):
145+
async def heartbeat(self) -> None:
150146
# send a heartbeat every 4 seconds
151147
heartbeat_delay = 4
152148

@@ -206,7 +202,7 @@ async def heartbeat(self):
206202
break
207203
await asyncio.sleep(heartbeat_delay)
208204

209-
async def start_heartbeater(self):
205+
async def start_heartbeater(self) -> None:
210206
if self.heartbeat_task is not None:
211207
return
212208

@@ -220,10 +216,10 @@ async def start_heartbeater(self):
220216
raise e
221217
self.heartbeat_task = loop.create_task(self.heartbeat())
222218

223-
def __aiter__(self):
219+
def __aiter__(self) -> AsyncGenerator[Action | None, None]:
224220
return self._generator()
225221

226-
async def _generator(self) -> AsyncGenerator[Action, None]:
222+
async def _generator(self) -> AsyncGenerator[Action | None, None]:
227223
listener = None
228224

229225
while not self.stop_signal:
@@ -239,6 +235,10 @@ async def _generator(self) -> AsyncGenerator[Action, None]:
239235
try:
240236
while not self.stop_signal:
241237
self.interrupt = Event_ts()
238+
239+
if listener is None:
240+
continue
241+
242242
t = asyncio.create_task(
243243
read_with_interrupt(listener, self.interrupt)
244244
)
@@ -251,7 +251,10 @@ async def _generator(self) -> AsyncGenerator[Action, None]:
251251
)
252252

253253
t.cancel()
254-
listener.cancel()
254+
255+
if listener:
256+
listener.cancel()
257+
255258
break
256259

257260
assigned_action = t.result()
@@ -261,10 +264,9 @@ async def _generator(self) -> AsyncGenerator[Action, None]:
261264
break
262265

263266
self.retries = 0
264-
assigned_action: AssignedAction
265267

266268
# Process the received action
267-
action_type = self.map_action_type(assigned_action.actionType)
269+
action_type = assigned_action.actionType
268270

269271
if (
270272
assigned_action.actionPayload is None
@@ -287,7 +289,8 @@ async def _generator(self) -> AsyncGenerator[Action, None]:
287289
step_id=assigned_action.stepId,
288290
step_run_id=assigned_action.stepRunId,
289291
action_id=assigned_action.actionId,
290-
action_payload=action_payload,
292+
## TODO: Figure out this type - maybe needs to be dumped to JSON?
293+
action_payload=action_payload, # type: ignore[arg-type]
291294
action_type=action_type,
292295
retry_count=assigned_action.retryCount,
293296
additional_metadata=assigned_action.additional_metadata,
@@ -324,25 +327,15 @@ async def _generator(self) -> AsyncGenerator[Action, None]:
324327

325328
self.retries = self.retries + 1
326329

327-
def parse_action_payload(self, payload: str):
330+
def parse_action_payload(self, payload: str) -> JSONSerializableDict:
328331
try:
329-
payload_data = json.loads(payload)
332+
return cast(JSONSerializableDict, json.loads(payload))
330333
except json.JSONDecodeError as e:
331334
raise ValueError(f"Error decoding payload: {e}")
332-
return payload_data
333-
334-
def map_action_type(self, action_type):
335-
if action_type == ActionType.START_STEP_RUN:
336-
return START_STEP_RUN
337-
elif action_type == ActionType.CANCEL_STEP_RUN:
338-
return CANCEL_STEP_RUN
339-
elif action_type == ActionType.START_GET_GROUP_KEY:
340-
return START_GET_GROUP_KEY
341-
else:
342-
# logger.error(f"Unknown action type: {action_type}")
343-
return None
344335

345-
async def get_listen_client(self):
336+
async def get_listen_client(
337+
self,
338+
) -> grpc.aio.UnaryStreamCall[WorkerListenRequest, AssignedAction]:
346339
current_time = int(time.time())
347340

348341
if (
@@ -370,7 +363,8 @@ async def get_listen_client(self):
370363
f"action listener connection interrupted, retrying... ({self.retries}/{DEFAULT_ACTION_LISTENER_RETRY_COUNT})"
371364
)
372365

373-
self.aio_client = DispatcherStub(new_conn(self.config, True))
366+
## TODO: Figure out how to get type support for these
367+
self.aio_client = DispatcherStub(new_conn(self.config, True)) # type: ignore[no-untyped-call]
374368

375369
if self.listen_strategy == "v2":
376370
# we should await for the listener to be established before
@@ -391,11 +385,14 @@ async def get_listen_client(self):
391385

392386
self.last_connection_attempt = current_time
393387

394-
return listener
388+
return cast(
389+
grpc.aio.UnaryStreamCall[WorkerListenRequest, AssignedAction], listener
390+
)
395391

396392
def cleanup(self) -> None:
397393
self.run_heartbeat = False
398-
self.heartbeat_task.cancel()
394+
if self.heartbeat_task is not None:
395+
self.heartbeat_task.cancel()
399396

400397
try:
401398
self.unregister()
@@ -405,9 +402,11 @@ def cleanup(self) -> None:
405402
if self.interrupt:
406403
self.interrupt.set()
407404

408-
def unregister(self):
405+
def unregister(self) -> WorkerUnsubscribeRequest:
409406
self.run_heartbeat = False
410-
self.heartbeat_task.cancel()
407+
408+
if self.heartbeat_task is not None:
409+
self.heartbeat_task.cancel()
411410

412411
try:
413412
req = self.aio_client.Unsubscribe(
@@ -417,6 +416,6 @@ def unregister(self):
417416
)
418417
if self.interrupt is not None:
419418
self.interrupt.set()
420-
return req
419+
return cast(WorkerUnsubscribeRequest, req)
421420
except grpc.RpcError as e:
422421
raise Exception(f"Failed to unsubscribe: {e}")

hatchet_sdk/clients/event_ts.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import asyncio
2-
from typing import Any
2+
from typing import Any, TypeVar, cast
3+
4+
import grpc.aio
35

46

57
class Event_ts(asyncio.Event):
@@ -20,9 +22,14 @@ def clear(self) -> None:
2022
self._loop.call_soon_threadsafe(super().clear)
2123

2224

23-
async def read_with_interrupt(listener: Any, interrupt: Event_ts) -> Any:
25+
TRequest = TypeVar("TRequest")
26+
TResponse = TypeVar("TResponse")
27+
28+
29+
async def read_with_interrupt(
30+
listener: grpc.aio.UnaryStreamCall[TRequest, TResponse], interrupt: Event_ts
31+
) -> Any:
2432
try:
25-
result = await listener.read()
26-
return result
33+
return cast(Any, await listener.read())
2734
finally:
2835
interrupt.set()

hatchet_sdk/clients/rest/api_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ async def __aenter__(self):
9797
async def __aexit__(self, exc_type, exc_value, traceback):
9898
await self.close()
9999

100-
async def close(self):
100+
async def close(self) -> None:
101101
await self.rest_client.close()
102102

103103
@property

hatchet_sdk/clients/rest/tenacity_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def tenacity_alert_retry(retry_state: tenacity.RetryCallState) -> None:
2727
)
2828

2929

30-
def tenacity_should_retry(ex: Exception) -> bool:
30+
def tenacity_should_retry(ex: BaseException) -> bool:
3131
if isinstance(ex, (grpc.aio.AioRpcError, grpc.RpcError)):
3232
if ex.code() in [
3333
grpc.StatusCode.UNIMPLEMENTED,

0 commit comments

Comments
 (0)