|
1 | 1 | import contextlib |
| 2 | +import json |
2 | 3 | import logging |
| 4 | +from collections.abc import AsyncIterable |
3 | 5 | from datetime import timedelta |
4 | 6 | from typing import Final |
5 | 7 |
|
@@ -125,6 +127,36 @@ async def list_tasks(self, task_filter: TaskFilter) -> list[Task]: |
125 | 127 | async def remove_task(self, task_id: TaskID) -> None: |
126 | 128 | await self._redis_client_sdk.redis.delete(_build_key(task_id)) |
127 | 129 |
|
| 130 | + async def append_task_result(self, task_id: TaskID, data: dict) -> None: |
| 131 | + stream_key = f"task:{task_id}" |
| 132 | + await self._redis_client_sdk.redis.xadd( |
| 133 | + stream_key, |
| 134 | + data, |
| 135 | + ) |
| 136 | + |
| 137 | + async def stream_task_result( |
| 138 | + self, task_id: str, last_id: str = "0-0" |
| 139 | + ) -> AsyncIterable[dict]: |
| 140 | + stream_key = f"task:{task_id}" |
| 141 | + while True: |
| 142 | + result = await self._redis_client_sdk.redis.xread( |
| 143 | + {stream_key: last_id}, block=5000 |
| 144 | + ) |
| 145 | + if not result: |
| 146 | + continue |
| 147 | + |
| 148 | + for _, entries in result: |
| 149 | + for entry_id, fields in entries: |
| 150 | + last_id = entry_id |
| 151 | + data = { |
| 152 | + k: json.loads(v) if k == "data" else v |
| 153 | + for k, v in fields.items() |
| 154 | + } |
| 155 | + yield data |
| 156 | + |
| 157 | + if data.get("type") == "done": |
| 158 | + return |
| 159 | + |
128 | 160 | async def set_task_progress(self, task_id: TaskID, report: ProgressReport) -> None: |
129 | 161 | await self._redis_client_sdk.redis.hset( |
130 | 162 | name=_build_key(task_id), |
|
0 commit comments