Skip to content
This repository was archived by the owner on Feb 20, 2025. It is now read-only.

Commit 0973017

Browse files
committed
fix: remove some Any types
1 parent 49693cd commit 0973017

File tree

3 files changed

+41
-26
lines changed

3 files changed

+41
-26
lines changed

hatchet_sdk/clients/dispatcher/dispatcher.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Any, cast
22

3+
import grpc.aio
34
from google.protobuf.timestamp_pb2 import Timestamp
45

56
from hatchet_sdk.clients.dispatcher.action_listener import (
@@ -69,7 +70,7 @@ async def get_action_listener(
6970

7071
async def send_step_action_event(
7172
self, action: Action, event_type: StepActionEventType, payload: str
72-
) -> Any:
73+
) -> grpc.aio.UnaryUnaryCall[StepActionEvent, ActionEventResponse] | None:
7374
try:
7475
return await self._try_send_step_action_event(action, event_type, payload)
7576
except Exception as e:
@@ -84,12 +85,12 @@ async def send_step_action_event(
8485
"Failed to send finished event: " + str(e),
8586
)
8687

87-
return
88+
return None
8889

8990
@tenacity_retry
9091
async def _try_send_step_action_event(
9192
self, action: Action, event_type: StepActionEventType, payload: str
92-
) -> Any:
93+
) -> grpc.aio.UnaryUnaryCall[StepActionEvent, ActionEventResponse]:
9394
eventTimestamp = Timestamp()
9495
eventTimestamp.GetCurrentTime()
9596

@@ -105,15 +106,17 @@ async def _try_send_step_action_event(
105106
eventPayload=payload,
106107
)
107108

108-
## TODO: What does this return?
109-
return await self.aio_client.SendStepActionEvent(
110-
event,
111-
metadata=get_metadata(self.token),
109+
return cast(
110+
grpc.aio.UnaryUnaryCall[StepActionEvent, ActionEventResponse],
111+
await self.aio_client.SendStepActionEvent(
112+
event,
113+
metadata=get_metadata(self.token),
114+
),
112115
)
113116

114117
async def send_group_key_action_event(
115118
self, action: Action, event_type: GroupKeyActionEventType, payload: str
116-
) -> Any:
119+
) -> grpc.aio.UnaryUnaryCall[GroupKeyActionEvent, ActionEventResponse]:
117120
eventTimestamp = Timestamp()
118121
eventTimestamp.GetCurrentTime()
119122

@@ -128,9 +131,12 @@ async def send_group_key_action_event(
128131
)
129132

130133
## TODO: What does this return?
131-
return await self.aio_client.SendGroupKeyActionEvent(
132-
event,
133-
metadata=get_metadata(self.token),
134+
return cast(
135+
grpc.aio.UnaryUnaryCall[GroupKeyActionEvent, ActionEventResponse],
136+
await self.aio_client.SendGroupKeyActionEvent(
137+
event,
138+
metadata=get_metadata(self.token),
139+
),
134140
)
135141

136142
def put_overrides_data(self, data: OverridesData) -> ActionEventResponse:

hatchet_sdk/clients/event_ts.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Any, TypeVar, cast
33

44
import grpc.aio
5+
from grpc._cython import cygrpc # type: ignore[attr-defined]
56

67

78
class Event_ts(asyncio.Event):
@@ -28,8 +29,13 @@ def clear(self) -> None:
2829

2930
async def read_with_interrupt(
3031
listener: grpc.aio.UnaryStreamCall[TRequest, TResponse], interrupt: Event_ts
31-
) -> Any:
32+
) -> TResponse:
3233
try:
33-
return cast(Any, await listener.read())
34+
result = await listener.read()
35+
36+
if result is cygrpc.EOF:
37+
raise ValueError("Unexpected EOF")
38+
39+
return cast(TResponse, result)
3440
finally:
3541
interrupt.set()

hatchet_sdk/worker/worker.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,20 @@
1919

2020
from hatchet_sdk import Context
2121
from hatchet_sdk.client import Client, new_client_raw
22+
from hatchet_sdk.clients.dispatcher.action_listener import Action
2223
from hatchet_sdk.contracts.workflows_pb2 import CreateWorkflowVersionOpts
2324
from hatchet_sdk.loader import ClientConfig
2425
from hatchet_sdk.logger import logger
2526
from hatchet_sdk.utils.types import WorkflowValidator
2627
from hatchet_sdk.utils.typing import is_basemodel_subclass
27-
from hatchet_sdk.worker.action_listener_process import worker_action_listener_process
28-
from hatchet_sdk.worker.runner.run_loop_manager import WorkerActionRunLoopManager
28+
from hatchet_sdk.worker.action_listener_process import (
29+
ActionEvent,
30+
worker_action_listener_process,
31+
)
32+
from hatchet_sdk.worker.runner.run_loop_manager import (
33+
STOP_LOOP_TYPE,
34+
WorkerActionRunLoopManager,
35+
)
2936
from hatchet_sdk.workflow import WorkflowInterface
3037

3138
T = TypeVar("T")
@@ -74,13 +81,13 @@ def __init__(
7481
self._status: WorkerStatus
7582

7683
self.action_listener_process: BaseProcess
77-
self.action_listener_health_check: asyncio.Task[Any]
84+
self.action_listener_health_check: asyncio.Task[None]
7885
self.action_runner: WorkerActionRunLoopManager
7986

8087
self.ctx = multiprocessing.get_context("spawn")
8188

82-
self.action_queue: "Queue[Any]" = self.ctx.Queue()
83-
self.event_queue: "Queue[Any]" = self.ctx.Queue()
89+
self.action_queue: "Queue[Action | STOP_LOOP_TYPE]" = self.ctx.Queue()
90+
self.event_queue: "Queue[ActionEvent]" = self.ctx.Queue()
8491

8592
self.loop: asyncio.AbstractEventLoop
8693

@@ -193,12 +200,10 @@ async def start_health_server(self) -> None:
193200

194201
logger.info(f"healthcheck server running on port {port}")
195202

196-
def start(
197-
self, options: WorkerStartOptions = WorkerStartOptions()
198-
) -> Future[asyncio.Task[Any] | None]:
203+
def start(self, options: WorkerStartOptions = WorkerStartOptions()) -> None:
199204
self.owned_loop = self.setup_loop(options.loop)
200205

201-
f = asyncio.run_coroutine_threadsafe(
206+
asyncio.run_coroutine_threadsafe(
202207
self.async_start(options, _from_start=True), self.loop
203208
)
204209

@@ -209,14 +214,12 @@ def start(
209214
if self.handle_kill:
210215
sys.exit(0)
211216

212-
return f
213-
214217
## Start methods
215218
async def async_start(
216219
self,
217220
options: WorkerStartOptions = WorkerStartOptions(),
218221
_from_start: bool = False,
219-
) -> Any | None:
222+
) -> None:
220223
main_pid = os.getpid()
221224
logger.info("------------------------------------------")
222225
logger.info("STARTING HATCHET...")
@@ -245,7 +248,7 @@ async def async_start(
245248
self._check_listener_health()
246249
)
247250

248-
return await self.action_listener_health_check
251+
await self.action_listener_health_check
249252

250253
def _run_action_runner(self) -> WorkerActionRunLoopManager:
251254
# Retrieve the shared queue

0 commit comments

Comments
 (0)