Skip to content

Commit e2d171b

Browse files
committed
feat: add schedule source
1 parent e452db0 commit e2d171b

File tree

10 files changed

+270
-82
lines changed

10 files changed

+270
-82
lines changed

docker-compose.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
services:
22
ydb:
3-
image: ydbplatform/local-ydb:24.4
3+
image: ydbplatform/local-ydb:25.2
44
platform: linux/amd64
55
hostname: localhost
66
ports:
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from taskiq_ydb.broker import YdbBroker
22
from taskiq_ydb.result_backend import YdbResultBackend
3+
from taskiq_ydb.schedule_source import YdbScheduleSource
34

45

56
__all__ = [
67
'YdbBroker',
78
'YdbResultBackend',
9+
'YdbScheduleSource',
810
]
Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,18 @@ class YdbBroker(AsyncBroker):
2121
def __init__(
2222
self,
2323
driver_config: ydb.aio.driver.DriverConfig,
24-
topic_path: str = 'taskiq-tasks',
24+
topic_path: str = 'taskiq_tasks',
2525
connection_timeout: int = 5,
2626
read_timeout: int = 5,
2727
) -> None:
2828
"""
2929
Construct new broker.
3030
31-
:param driver_config: YDB driver configuration.
32-
:param topic_path: Path to the topic where tasks will be stored.
33-
:param connection_timeout: Timeout for connection to database during startup.
34-
:param read_timeout: Timeout for read topic operations.
31+
Args:
32+
driver_config: YDB driver configuration.
33+
topic_path: Path to the topic where tasks will be stored.
34+
connection_timeout: Timeout for connection to database during startup.
35+
read_timeout: Timeout for read topic operations.
3536
"""
3637
super().__init__()
3738
self._driver = ydb.aio.Driver(driver_config=driver_config)
Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,13 @@ def __init__( # noqa: PLR0913
3636
"""
3737
Construct new result backend.
3838
39-
:param driver_config: YDB driver configuration.
40-
:param table_name: Table name for storing task results.
41-
:param table_primary_key_type: Type of primary key in table.
42-
:param serializer: Serializer for task results.
43-
:param pool_size: YDB session pool size.
44-
:param connection_timeout: Timeout for connection to database during startup.
45-
39+
Args:
40+
driver_config: YDB driver configuration.
41+
table_name: Table name for storing task results.
42+
table_primary_key_type: Type of primary key in table.
43+
serializer: Serializer for task results.
44+
pool_size: YDB session pool size.
45+
connection_timeout: Timeout for connection to database during startup.
4646
"""
4747
self._driver = ydb.aio.Driver(driver_config=driver_config)
4848
self._table_name: tp.Final = table_name
@@ -101,8 +101,9 @@ async def set_result(
101101
"""
102102
Set result to the YDB table.
103103
104-
:param task_id: ID of the task.
105-
:param result: result of the task
104+
Args:
105+
task_id: ID of the task.
106+
result: result of the task
106107
"""
107108
task_id_in_ydb = uuid.UUID(task_id) if self._table_primary_key_type == ydb.PrimitiveType.UUID else task_id
108109
query = f"""
@@ -130,7 +131,8 @@ async def is_result_ready(
130131
"""
131132
Return whether the result is ready.
132133
133-
:param task_id: ID of the task.
134+
Args:
135+
task_id: ID of the task.
134136
"""
135137
task_id_in_ydb = uuid.UUID(task_id) if self._table_primary_key_type == ydb.PrimitiveType.UUID else task_id
136138
query = f"""
@@ -158,10 +160,15 @@ async def get_result(
158160
"""
159161
Retrieve result from the task.
160162
161-
:param task_id: task's id.
162-
:param with_logs: if True it will download task's logs.
163-
:raises ResultIsMissingError: if there is no result when trying to get it.
164-
:return: TaskiqResult.
163+
Args:
164+
task_id: task's id.
165+
with_logs: if True it will download task's logs.
166+
167+
Raises:
168+
ResultIsMissingError: if there is no result when trying to get it.
169+
170+
Returns:
171+
TaskiqResult.
165172
"""
166173
task_id_in_ydb = uuid.UUID(task_id) if self._table_primary_key_type == ydb.PrimitiveType.UUID else task_id
167174
query = f"""

src/taskiq_ydb/schedule_source.py

Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
import logging
5+
import typing as tp
6+
import uuid
7+
8+
import ydb # type: ignore[import-untyped]
9+
import ydb.aio # type: ignore[import-untyped]
10+
from pydantic import ValidationError
11+
from taskiq import ScheduleSource
12+
from taskiq.scheduler.scheduled_task import ScheduledTask
13+
from taskiq.serializers import PickleSerializer
14+
15+
from taskiq_ydb.exceptions import DatabaseConnectionError
16+
17+
18+
if tp.TYPE_CHECKING:
19+
from taskiq.abc.broker import AsyncBroker
20+
21+
22+
logger = logging.getLogger(__name__)
23+
24+
25+
class YdbScheduleSource(ScheduleSource):
26+
"""Schedule source that uses YDB to store schedules in YDB database."""
27+
28+
def __init__(
29+
self,
30+
broker: AsyncBroker,
31+
driver_config: ydb.aio.driver.DriverConfig,
32+
table_name: str = 'taskiq_schedules',
33+
pool_size: int = 5,
34+
connection_timeout: int = 5,
35+
) -> None:
36+
"""
37+
Construct new schedule source.
38+
39+
Args:
40+
broker: The TaskIQ broker instance to use for finding and managing tasks.
41+
driver_config: YDB driver configuration.
42+
table_name: Table name for storing task results.
43+
pool_size: YDB session pool size.
44+
connection_timeout: Timeout for connection to database during startup.
45+
"""
46+
self._broker: tp.Final = broker
47+
self._driver = ydb.aio.Driver(driver_config=driver_config)
48+
self._table_name: tp.Final = table_name
49+
self._pool_size: tp.Final = pool_size
50+
self._pool: ydb.aio.SessionPool
51+
self._connection_timeout: tp.Final = connection_timeout
52+
self._serializer: tp.Final = PickleSerializer()
53+
54+
55+
async def startup(self) -> None:
56+
"""
57+
Initialize the result backend.
58+
59+
Construct new connection pool
60+
and create new table for results if not exists.
61+
"""
62+
try:
63+
logger.debug('Waiting for YDB driver to be ready')
64+
await self._driver.wait(fail_fast=True, timeout=self._connection_timeout)
65+
except (ydb.issues.ConnectionLost, asyncio.exceptions.TimeoutError) as exception:
66+
raise DatabaseConnectionError from exception
67+
self._pool = ydb.aio.SessionPool(self._driver, size=self._pool_size)
68+
session = await self._pool.acquire()
69+
70+
table_path = f'{self._driver._driver_config.database}/{self._table_name}' # noqa: SLF001
71+
try:
72+
logger.debug('Checking if table %s exists', self._table_name)
73+
existing_table = await session.describe_table(table_path)
74+
except ydb.issues.SchemeError:
75+
existing_table = None
76+
if not existing_table:
77+
logger.debug('Table %s does not exist, creating...', self._table_name)
78+
await session.create_table(
79+
table_path,
80+
ydb.TableDescription()
81+
.with_column(ydb.Column('id', ydb.PrimitiveType.UUID))
82+
.with_column(ydb.Column('task_name', ydb.PrimitiveType.Utf8))
83+
.with_column(ydb.Column('schedule', ydb.PrimitiveType.String))
84+
.with_primary_key('id'),
85+
)
86+
logger.debug('Table %s created', self._table_name)
87+
else:
88+
logger.debug('Table %s already exists', self._table_name)
89+
90+
# Load existing schedules from labels in tasks
91+
schedule_tasks = self._extract_scheduled_tasks_from_broker()
92+
for task in await self.get_schedules():
93+
await self.delete_schedule(task.schedule_id)
94+
for schedule_task in schedule_tasks:
95+
await self.add_schedule(schedule_task)
96+
97+
async def shutdown(self) -> None:
98+
"""Close the connection pool."""
99+
await asyncio.to_thread(self._driver.topic_client.close)
100+
if hasattr(self, '_pool'):
101+
await self._pool.stop(timeout=10)
102+
await self._driver.stop(timeout=10)
103+
104+
def _extract_scheduled_tasks_from_broker(self) -> list[ScheduledTask]:
105+
"""
106+
Extract schedules from tasks that were registered in broker.
107+
108+
Returns:
109+
A list of ScheduledTask instances extracted from the task's labels.
110+
"""
111+
scheduled_tasks_for_creation: list[ScheduledTask] = []
112+
for task_name, task in self._broker.get_all_tasks().items():
113+
if 'schedule' not in task.labels:
114+
logger.debug('Task %s has no schedule, skipping', task_name)
115+
continue
116+
if not isinstance(task.labels['schedule'], list):
117+
logger.warning(
118+
'Schedule for task %s is not a list, skipping',
119+
task_name,
120+
)
121+
continue
122+
for schedule in task.labels['schedule']:
123+
try:
124+
new_schedule = ScheduledTask.model_validate(
125+
{
126+
'task_name': task_name,
127+
'labels': schedule.get('labels', {}),
128+
'args': schedule.get('args', []),
129+
'kwargs': schedule.get('kwargs', {}),
130+
'schedule_id': str(uuid.uuid4()),
131+
'cron': schedule.get('cron', None),
132+
'cron_offset': schedule.get('cron_offset', None),
133+
'time': schedule.get('time', None),
134+
},
135+
)
136+
scheduled_tasks_for_creation.append(new_schedule)
137+
except ValidationError: # noqa: PERF203
138+
logger.exception(
139+
'Schedule for task %s is not valid, skipping',
140+
task_name,
141+
)
142+
continue
143+
return scheduled_tasks_for_creation
144+
145+
async def get_schedules(self) -> list[ScheduledTask]:
146+
"""Get list of taskiq schedules."""
147+
query = f"""
148+
SELECT schedule FROM {self._table_name};
149+
""" # noqa: S608
150+
session = await self._pool.acquire()
151+
result_sets = await session.transaction().execute(
152+
await session.prepare(query),
153+
commit_tx=True,
154+
)
155+
await self._pool.release(session)
156+
scheduled_tasks = []
157+
for result_set in result_sets:
158+
rows = result_set.rows
159+
for row in rows:
160+
scheduled_tasks.append( # noqa: PERF401
161+
self._serializer.loadb(row.schedule),
162+
)
163+
return scheduled_tasks
164+
165+
async def add_schedule(
166+
self,
167+
schedule: ScheduledTask,
168+
) -> None:
169+
"""
170+
Add a new schedule.
171+
172+
This function is used to add new schedules.
173+
It's a convenient helper for people who want to add new schedules
174+
for the current source.
175+
176+
Args:
177+
schedule: schedule to add.
178+
"""
179+
schedule_id = uuid.UUID(schedule.schedule_id)
180+
query = f"""
181+
DECLARE $id AS Uuid;
182+
DECLARE $task_name AS Utf8;
183+
DECLARE $schedule AS String;
184+
185+
UPSERT INTO {self._table_name} (id, task_name, schedule)
186+
VALUES ($id, $task_name, $schedule);
187+
"""
188+
session = await self._pool.acquire()
189+
await session.transaction().execute(
190+
await session.prepare(query),
191+
{
192+
'$id': schedule_id,
193+
'$task_name': schedule.task_name,
194+
'$schedule': self._serializer.dumpb(schedule),
195+
},
196+
commit_tx=True,
197+
)
198+
await self._pool.release(session)
199+
200+
async def delete_schedule(self, schedule_id: str) -> None:
201+
"""
202+
Method to delete schedule by id.
203+
204+
This is useful for schedule cancelation.
205+
206+
Args:
207+
schedule_id: id of schedule to delete.
208+
"""
209+
schedule_id_uuid = uuid.UUID(schedule_id)
210+
query = f"""
211+
DECLARE $id AS Uuid;
212+
213+
DELETE FROM {self._table_name}
214+
WHERE id = $id;
215+
""" # noqa: S608
216+
session = await self._pool.acquire()
217+
await session.transaction().execute(
218+
await session.prepare(query),
219+
{
220+
'$id': schedule_id_uuid,
221+
},
222+
commit_tx=True,
223+
)
224+
await self._pool.release(session)
225+
226+
async def post_send(
227+
self,
228+
task: ScheduledTask,
229+
) -> None:
230+
"""
231+
Delete schedule if it was one-time task.
232+
233+
Args:
234+
task: task that just have sent
235+
"""
236+
if task.time:
237+
await self.delete_schedule(task.schedule_id)

tests/conftest.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,6 @@
88
import taskiq_ydb
99

1010

11-
if tp.TYPE_CHECKING:
12-
import asyncio
13-
14-
15-
@pytest.fixture(scope='session')
16-
def event_loop() -> 'tp.Generator[asyncio.AbstractEventLoop, None]':
17-
import asyncio
18-
19-
loop = asyncio.new_event_loop()
20-
asyncio.set_event_loop(loop)
21-
22-
yield loop
23-
loop.close()
24-
25-
26-
2711
@pytest.fixture
2812
def driver_config() -> ydb.aio.driver.DriverConfig:
2913
return ydb.aio.driver.DriverConfig(

tests/test_broker.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,16 +116,16 @@ async def test_when_two_workers_are_listening__then_one_worker_receive_message(
116116
await ydb_broker.kick(valid_broker_message)
117117
await asyncio.sleep(0.3)
118118

119-
recieved = 0
119+
received = 0
120120
for task in [worker1_task, worker2_task]:
121121
try:
122122
task.result()
123-
recieved += 1
123+
received += 1
124124
except asyncio.exceptions.InvalidStateError: # noqa: PERF203
125125
pass
126126

127127
# then
128-
assert recieved == 1
128+
assert received == 1
129129

130130
worker1_task.cancel()
131131
worker2_task.cancel()

0 commit comments

Comments
 (0)