Skip to content

Commit 3f2e813

Browse files
authored
Merge pull request #2 from d3nbr0/v1.1
aiocarrot task scheduler
2 parents 8692be1 + 7c6bb29 commit 3f2e813

File tree

10 files changed

+294
-47
lines changed

10 files changed

+294
-47
lines changed

aiocarrot/__meta__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.0.7'
1+
__version__ = '1.1.0'

aiocarrot/carrot.py

Lines changed: 68 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22

33
from loguru import logger
44

5+
from .scheduler import Scheduler
6+
57
from typing import Optional, TYPE_CHECKING
68

7-
import asyncio, aio_pika, ujson, uuid, copy, signal
9+
import asyncio, aio_pika, ujson, uuid, copy, signal, aiormq
810

911
if TYPE_CHECKING:
1012
from aiormq.abc import ConfirmationFrameType
@@ -23,6 +25,7 @@ class Carrot:
2325
_connection: Optional['aio_pika.abc.AbstractConnection'] = None
2426
_channel: Optional['aio_pika.abc.AbstractChannel'] = None
2527
_queue: Optional['aio_pika.abc.AbstractQueue'] = None
28+
_scheduler: Optional['Scheduler'] = None
2629

2730
def __init__(self, url: str, queue_name: str) -> None:
2831
"""
@@ -35,6 +38,7 @@ def __init__(self, url: str, queue_name: str) -> None:
3538
self._url = url
3639
self._tasks = []
3740
self._queue_name = queue_name
41+
self._scheduler = Scheduler(carrot=self)
3842

3943
async def send(self, _cnm: str, **kwargs) -> 'ConfirmationFrameType':
4044
"""
@@ -76,6 +80,19 @@ def setup_consumer(self, consumer: 'Consumer') -> None:
7680

7781
self._consumer = consumer
7882

83+
self._scheduler.clear()
84+
self._scheduler.stop()
85+
86+
for _, message in self._consumer._messages.items():
87+
if not message.schedule:
88+
continue
89+
90+
self._scheduler.add_task(message)
91+
92+
if self._is_consumer_alive and self._scheduler.has_tasks:
93+
scheduler_task = asyncio.create_task(self._scheduler.reload())
94+
self._tasks.append(scheduler_task)
95+
7996
async def run(self) -> None:
8097
"""
8198
Starts the main loop of the Carrot new message listener
@@ -99,6 +116,10 @@ async def run(self) -> None:
99116
logger.info(f' * {message_name}')
100117

101118
logger.info('')
119+
120+
if self._scheduler.has_tasks:
121+
asyncio.create_task(self._scheduler.start())
122+
102123
logger.info('Starting listener loop...')
103124

104125
signal.signal(signal.SIGINT, self._exit_signal_handler)
@@ -131,6 +152,7 @@ async def shutdown(self, silent: bool = False) -> None:
131152
:return:
132153
"""
133154

155+
self._scheduler.stop()
134156
pending_tasks = [x for x in self._tasks if not x.done()]
135157

136158
if len(pending_tasks) > 0:
@@ -163,51 +185,62 @@ async def _consumer_loop(self) -> None:
163185
logger.info('Consumer is successfully connected to queue')
164186

165187
async with queue.iterator() as queue_iterator:
166-
async for message in queue_iterator:
167-
for task in copy.copy(self._tasks):
168-
if task.done():
169-
self._tasks.remove(task)
188+
if not self._is_consumer_alive:
189+
return
190+
191+
try:
192+
await self._iterate_queue(queue_iterator)
193+
except aiormq.ChannelClosed:
194+
return
195+
196+
async def _iterate_queue(self, queue_iterator: 'aio_pika.abc.AbstractQueueIterator') -> None:
197+
""" Iterates over the queue iterator and passes the message on to the handler """
170198

171-
async with message.process():
172-
decoded_message: str = message.body.decode()
199+
async for message in queue_iterator:
200+
for task in copy.copy(self._tasks):
201+
if task.done():
202+
self._tasks.remove(task)
173203

174-
try:
175-
message_payload = ujson.loads(decoded_message)
204+
async with message.process():
205+
decoded_message: str = message.body.decode()
176206

177-
assert isinstance(message_payload, dict)
178-
except ujson.JSONDecodeError:
179-
logger.error(f'Error receiving the message (failed to receive JSON): {decoded_message}')
180-
continue
207+
try:
208+
message_payload = ujson.loads(decoded_message)
181209

182-
message_id = message_payload.get('_cid')
183-
message_name = message_payload.get('_cnm')
210+
assert isinstance(message_payload, dict)
211+
except ujson.JSONDecodeError:
212+
logger.error(f'Error receiving the message (failed to receive JSON): {decoded_message}')
213+
continue
184214

185-
if not message_id:
186-
logger.error(
187-
'The message format could not be determined (identifier is missing): '
188-
f'{message_payload}'
189-
)
215+
message_id = message_payload.get('_cid')
216+
message_name = message_payload.get('_cnm')
190217

191-
continue
218+
if not message_id:
219+
logger.error(
220+
'The message format could not be determined (identifier is missing): '
221+
f'{message_payload}'
222+
)
223+
224+
continue
192225

193-
if not message_name:
194-
logger.error(
195-
'The message format could not be determined (message name is missing): '
196-
f'{message_payload}'
197-
)
226+
if not message_name:
227+
logger.error(
228+
'The message format could not be determined (message name is missing): '
229+
f'{message_payload}'
230+
)
198231

199-
continue
232+
continue
200233

201-
del message_payload['_cid']
202-
del message_payload['_cnm']
234+
del message_payload['_cid']
235+
del message_payload['_cnm']
203236

204-
task = asyncio.create_task(self._consumer.on_message(
205-
message_id,
206-
message_name,
207-
**message_payload,
208-
))
237+
task = asyncio.create_task(self._consumer.on_message(
238+
message_id,
239+
message_name,
240+
**message_payload,
241+
))
209242

210-
self._tasks.append(task)
243+
self._tasks.append(task)
211244

212245
async def _get_queue(self) -> 'aio_pika.abc.AbstractQueue':
213246
"""

aiocarrot/consumer/consumer.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from loguru import logger
77

88
from abc import ABC
9-
from typing import TYPE_CHECKING
9+
from typing import Optional, TYPE_CHECKING
1010
from copy import deepcopy
1111

1212
if TYPE_CHECKING:
@@ -52,7 +52,7 @@ class Consumer(AbstractConsumer):
5252
def __init__(self) -> None:
5353
self._messages = {}
5454

55-
def create_message(self, name: str, handler: 'Callable') -> None:
55+
def create_message(self, name: str, handler: 'Callable', schedule: Optional[str] = None) -> None:
5656
split_name = name.split()
5757

5858
if len(split_name) != 1:
@@ -62,19 +62,20 @@ def create_message(self, name: str, handler: 'Callable') -> None:
6262
name=name,
6363
handler=handler,
6464
dependant=get_dependant(handler=handler),
65+
schedule=schedule,
6566
)
6667

6768
self._messages[name] = message
6869

69-
def message(self, name: str | list[str]) -> 'Callable':
70+
def message(self, name: str | list[str], schedule: Optional[str] = None) -> 'Callable':
7071
def decorator(func: 'Callable') -> 'Callable':
7172
if not isinstance(name, list):
7273
message_names = [name]
7374
else:
7475
message_names = name
7576

7677
for message_name in message_names:
77-
self.create_message(name=message_name, handler=func)
78+
self.create_message(name=message_name, handler=func, schedule=schedule)
7879

7980
return func
8081

aiocarrot/consumer/types.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
from pydantic import TypeAdapter, ValidationError, ConfigDict
1+
from typing import Optional, Callable, NamedTuple, TypeVar, Annotated
2+
from dataclasses import dataclass
3+
4+
from pydantic import TypeAdapter, ValidationError
25
from pydantic.fields import FieldInfo
36
from pydantic_core import PydanticUndefined as Undefined
47

5-
from typing import Callable, NamedTuple, TypeVar, Annotated
6-
from dataclasses import dataclass
7-
88

99
T = TypeVar('T')
1010

@@ -33,7 +33,7 @@ def __post_init__(self) -> None:
3333
Annotated[self.field_info.annotation, self.field_info],
3434
)
3535

36-
def validate(self, value: any) -> tuple[T | None, str | None]:
36+
def validate(self, value: any) -> tuple[Optional[T], Optional[str]]:
3737
try:
3838
return self._type_adapter.validate_python(value, from_attributes=True), None
3939
except ValidationError as exc:
@@ -48,9 +48,10 @@ def __init__(self, *, params: list[Field] = None):
4848

4949

5050
class Message(NamedTuple):
51-
name: str | None
51+
name: Optional[str]
5252
handler: Callable
5353
dependant: Dependant
54+
schedule: Optional[str] = None
5455

5556

5657
__all__ = (

aiocarrot/scheduler/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .scheduler import Scheduler

aiocarrot/scheduler/scheduler.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import asyncio
2+
3+
from typing import TYPE_CHECKING
4+
from datetime import datetime, timezone
5+
6+
from loguru import logger
7+
8+
from .types import Task
9+
10+
if TYPE_CHECKING:
11+
from ..carrot import Carrot
12+
from ..consumer.types import Message
13+
14+
15+
class Scheduler:
16+
""" Common Carrot task scheduler """
17+
18+
_carrot: 'Carrot'
19+
_tasks: list[Task] = None
20+
_is_alive: bool = False
21+
_max_interval: int = 10
22+
_adjust: float = 0.001
23+
24+
def __init__(self, carrot: 'Carrot') -> None:
25+
""" The basic and basic task scheduler for the aiocarrot framework """
26+
27+
self._carrot = carrot
28+
self._tasks = []
29+
30+
def add_task(self, message: 'Message') -> None:
31+
if not message.schedule:
32+
raise AttributeError(f'Message <{message.name}> has no schedule attribute')
33+
34+
try:
35+
task = Task(message=message, carrot=self._carrot)
36+
except (ValueError, TypeError):
37+
return logger.error(f'Issue cron format error for message <{message.name}>')
38+
39+
self._tasks.append(task)
40+
41+
async def start(self) -> None:
42+
if self._is_alive:
43+
return logger.error('Scheduler already running')
44+
45+
if len(self._tasks) == 0:
46+
return logger.warning('Scheduler cannot be started because there are no scheduled tasks')
47+
48+
logger.info(f'Scheduler registered {len(self._tasks)} messages')
49+
self._is_alive = True
50+
51+
while self._is_alive:
52+
next_task_at = self._get_next_time()
53+
now = datetime.now(timezone.utc)
54+
55+
delta = (next_task_at - now).total_seconds()
56+
if delta < 0:
57+
delta = self._max_interval
58+
59+
interval = min(delta, self._max_interval) + self._adjust
60+
61+
await asyncio.sleep(interval)
62+
63+
now = datetime.now(timezone.utc)
64+
65+
for task in self._tasks:
66+
if not task.is_due(now):
67+
continue
68+
69+
await task.run()
70+
71+
async def reload(self) -> None:
72+
logger.info(f'Reloading scheduler after changing consumer...')
73+
74+
await asyncio.sleep(5)
75+
await self.start()
76+
77+
def stop(self) -> None:
78+
self._is_alive = False
79+
80+
def clear(self) -> None:
81+
self._tasks = []
82+
83+
@property
84+
def has_tasks(self) -> bool:
85+
return self._tasks and len(self._tasks) > 0
86+
87+
def _get_next_time(self) -> datetime:
88+
return min(self._tasks, key=lambda x: x.next_run).next_run
89+
90+
91+
__all__ = (
92+
'Scheduler',
93+
)

aiocarrot/scheduler/types.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from typing import TYPE_CHECKING
2+
from datetime import datetime, timezone
3+
4+
from loguru import logger
5+
6+
from croniter import croniter
7+
8+
if TYPE_CHECKING:
9+
from ..carrot import Carrot
10+
from ..consumer.types import Message
11+
12+
13+
class Task:
14+
""" The basic class of the task for the scheduler """
15+
16+
_carrot: 'Carrot'
17+
_message: 'Message'
18+
_cron: croniter
19+
_next_run: datetime
20+
_resync_at_next_tick: bool = False
21+
22+
def __init__(self, carrot: 'Carrot', message: 'Message') -> None:
23+
self._carrot = carrot
24+
self._message = message
25+
26+
now = datetime.now(timezone.utc)
27+
28+
self._cron = croniter(message.schedule, now)
29+
self._sync(now)
30+
31+
async def run(self):
32+
logger.info(f'Scheduled task <{self._message.name}> has been queued')
33+
await self._carrot.send(self._message.name)
34+
self._resync_at_next_tick = True
35+
36+
def is_due(self, now: datetime) -> bool:
37+
if self._resync_at_next_tick:
38+
self._sync(now)
39+
40+
return now >= self._next_run
41+
42+
@property
43+
def next_run(self) -> datetime:
44+
return self._next_run
45+
46+
def _sync(self, start_datetime: datetime) -> None:
47+
self._next_run = self._cron.get_next(datetime, start_time=start_datetime)
48+
self._resync_at_next_tick = False
49+
50+
51+
__all__ = (
52+
'Task',
53+
)

0 commit comments

Comments
 (0)