|
7 | 7 | from settings_library.redis import RedisDatabase, RedisSettings |
8 | 8 |
|
9 | 9 | from ...redis._client import RedisClientSDK |
| 10 | +from ...redis._utils import handle_redis_returns_union_types |
10 | 11 | from ..models import TaskContext, TaskData, TaskId |
11 | 12 | from .base import BaseStore |
12 | 13 |
|
@@ -46,41 +47,47 @@ def _get_key(self, store_type: str, name: str) -> str: |
46 | 47 | return f"{self.namespace}:{store_type}:{name}" |
47 | 48 |
|
48 | 49 | async def get_task_data(self, task_id: TaskId) -> TaskData | None: |
49 | | - result: Any | None = await self._redis.hget( |
50 | | - self._get_redis_hash_key(_STORE_TYPE_TASK_DATA), task_id |
51 | | - ) # type: ignore[misc] |
| 50 | + result: Any | None = await handle_redis_returns_union_types( |
| 51 | + self._redis.hget(self._get_redis_hash_key(_STORE_TYPE_TASK_DATA), task_id) |
| 52 | + ) |
52 | 53 | return TypeAdapter(TaskData).validate_json(result) if result else None |
53 | 54 |
|
54 | 55 | async def set_task_data(self, task_id: TaskId, value: TaskData) -> None: |
55 | 56 | _logger.debug( |
56 | 57 | "Setting task data for task_id=%s with data value=%s", task_id, value |
57 | 58 | ) |
58 | | - await self._redis.hset( |
59 | | - self._get_redis_hash_key(_STORE_TYPE_TASK_DATA), |
60 | | - task_id, |
61 | | - value.model_dump_json(), |
62 | | - ) # type: ignore[misc] |
| 59 | + await handle_redis_returns_union_types( |
| 60 | + self._redis.hset( |
| 61 | + self._get_redis_hash_key(_STORE_TYPE_TASK_DATA), |
| 62 | + task_id, |
| 63 | + value.model_dump_json(), |
| 64 | + ) |
| 65 | + ) |
63 | 66 |
|
64 | 67 | async def list_tasks_data(self) -> list[TaskData]: |
65 | | - result: list[Any] = await self._redis.hvals( |
66 | | - self._get_redis_hash_key(_STORE_TYPE_TASK_DATA) |
67 | | - ) # type: ignore[misc] |
| 68 | + result: list[Any] = await handle_redis_returns_union_types( |
| 69 | + self._redis.hvals(self._get_redis_hash_key(_STORE_TYPE_TASK_DATA)) |
| 70 | + ) |
68 | 71 | return [TypeAdapter(TaskData).validate_json(item) for item in result] |
69 | 72 |
|
70 | 73 | async def delete_task_data(self, task_id: TaskId) -> None: |
71 | | - await self._redis.hdel(self._get_redis_hash_key(_STORE_TYPE_TASK_DATA), task_id) # type: ignore[misc] |
| 74 | + await handle_redis_returns_union_types( |
| 75 | + self._redis.hdel(self._get_redis_hash_key(_STORE_TYPE_TASK_DATA), task_id) |
| 76 | + ) |
72 | 77 |
|
73 | 78 | async def set_as_cancelled( |
74 | 79 | self, task_id: TaskId, with_task_context: TaskContext |
75 | 80 | ) -> None: |
76 | | - await self._redis.hset( |
77 | | - self._get_redis_hash_key(_STORE_TYPE_CANCELLED_TASKS), |
78 | | - task_id, |
79 | | - json_dumps(with_task_context), |
80 | | - ) # type: ignore[misc] |
| 81 | + await handle_redis_returns_union_types( |
| 82 | + self._redis.hset( |
| 83 | + self._get_redis_hash_key(_STORE_TYPE_CANCELLED_TASKS), |
| 84 | + task_id, |
| 85 | + json_dumps(with_task_context), |
| 86 | + ) |
| 87 | + ) |
81 | 88 |
|
82 | 89 | async def get_cancelled(self) -> dict[TaskId, TaskContext]: |
83 | | - result: dict[str, str | None] = await self._redis.hgetall( |
84 | | - self._get_redis_hash_key(_STORE_TYPE_CANCELLED_TASKS) |
85 | | - ) # type: ignore[misc] |
| 90 | + result: dict[str, str | None] = await handle_redis_returns_union_types( |
| 91 | + self._redis.hgetall(self._get_redis_hash_key(_STORE_TYPE_CANCELLED_TASKS)) |
| 92 | + ) |
86 | 93 | return {task_id: json_loads(context) for task_id, context in result.items()} |
0 commit comments