Skip to content

Commit a2cb434

Browse files
fix: typecheck
1 parent 8e58365 commit a2cb434

File tree

3 files changed

+15
-10
lines changed

3 files changed

+15
-10
lines changed

packages/celery-library/src/celery_library/backends/redis.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,9 @@ async def consume_task_events(
172172
continue
173173

174174
try:
175-
event = TypeAdapter(TaskEvent).validate_json(raw_event)
175+
event: TaskEvent = TypeAdapter(TaskEvent).validate_json(
176+
raw_event
177+
)
176178
event.event_id = msg_id
177179
yield event
178180
except ValidationError:

packages/celery-library/src/celery_library/task_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ async def consume_task_events(
192192
self,
193193
task_filter: TaskFilter,
194194
task_uuid: TaskUUID,
195-
last_id: str,
195+
last_id: str | None = None,
196196
) -> AsyncIterator[TaskEvent]:
197197
task_id = task_filter.create_task_id(task_uuid=task_uuid)
198198
async for event in self._task_info_store.consume_task_events(

packages/service-library/src/servicelib/celery/models.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -127,20 +127,23 @@ class TaskMetadata(BaseModel):
127127
queue: TasksQueue = TasksQueue.DEFAULT
128128

129129

130-
class TaskDataEvent(BaseModel):
131-
type: Literal["data"] = "data"
130+
class BaseTaskEvent(BaseModel):
132131
event_id: str | None = None
132+
133+
134+
class TaskDataEvent(BaseTaskEvent):
135+
type: Literal["data"]
133136
data: Any
134137

135138

136-
class TaskStatusEvent(BaseModel):
137-
type: Literal["status"] = "status"
138-
event_id: str | None = None
139+
class TaskStatusEvent(BaseTaskEvent):
140+
type: Literal["status"]
139141
data: Literal["done", "error"]
140142

141143

142-
TaskEvent: TypeAlias = Annotated[
143-
TaskDataEvent | TaskStatusEvent, Field(discriminator="type")
144+
TaskEvent = Annotated[
145+
TaskDataEvent | TaskStatusEvent,
146+
Field(discriminator="type"),
144147
]
145148

146149

@@ -213,7 +216,7 @@ async def publish_task_event(self, task_id: TaskID, event: TaskEvent) -> None: .
213216
def consume_task_events(
214217
self,
215218
task_id: TaskID,
216-
last_id: str,
219+
last_id: str | None,
217220
) -> AsyncIterator[TaskEvent]: ...
218221

219222

0 commit comments

Comments
 (0)