Skip to content

Commit 80418f9

Browse files
committed
Add distributed lock for scheduled task
1 parent 016361b commit 80418f9

File tree

3 files changed

+87
-52
lines changed

3 files changed

+87
-52
lines changed

backend/app/task/service/scheduler_service.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,7 @@ async def create(*, obj: CreateTaskSchedulerParam) -> None:
6565
if task_scheduler:
6666
raise errors.ConflictError(msg='任务调度已存在')
6767
if obj.type == TaskSchedulerType.CRONTAB:
68-
crontab_split = obj.crontab.split(' ')
69-
if len(crontab_split) != 5:
70-
raise errors.RequestError(msg='Crontab 表达式非法')
71-
crontab_verify('m', crontab_split[0])
72-
crontab_verify('h', crontab_split[1])
73-
crontab_verify('dow', crontab_split[2])
74-
crontab_verify('dom', crontab_split[3])
75-
crontab_verify('moy', crontab_split[4])
68+
crontab_verify(obj.crontab)
7669
await task_scheduler_dao.create(db, obj)
7770

7871
@staticmethod
@@ -92,14 +85,7 @@ async def update(*, pk: int, obj: UpdateTaskSchedulerParam) -> int:
9285
if await task_scheduler_dao.get_by_name(db, obj.name):
9386
raise errors.ConflictError(msg='任务调度已存在')
9487
if task_scheduler.type == TaskSchedulerType.CRONTAB:
95-
crontab_split = obj.crontab.split(' ')
96-
if len(crontab_split) != 5:
97-
raise errors.RequestError(msg='Crontab 表达式非法')
98-
crontab_verify('m', crontab_split[0])
99-
crontab_verify('h', crontab_split[1])
100-
crontab_verify('dow', crontab_split[2])
101-
crontab_verify('dom', crontab_split[3])
102-
crontab_verify('moy', crontab_split[4])
88+
crontab_verify(obj.crontab)
10389
count = await task_scheduler_dao.update(db, pk, obj)
10490
return count
10591

backend/app/task/utils/schedulers.py

Lines changed: 74 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from celery import current_app, schedules
1111
from celery.beat import ScheduleEntry, Scheduler
12+
from celery.signals import beat_init
1213
from celery.utils.log import get_logger
1314
from sqlalchemy import select
1415
from sqlalchemy.exc import DatabaseError, InterfaceError
@@ -28,6 +29,35 @@
2829
# 此计划程序必须比常规的 5 分钟更频繁地唤醒,因为它需要考虑对计划的外部更改
2930
DEFAULT_MAX_INTERVAL = 5 # seconds
3031

32+
# 计划锁定时长,避免重复运行
33+
DEFAULT_MAX_LOCK = 300 # seconds
34+
35+
# Copied from:
36+
# https://github.com/andymccurdy/redis-py/blob/master/redis/lock.py#L33
37+
# Changes:
38+
# The second line from the bottom: The original Lua script intends
39+
# to extend time to (lock remaining time + additional time); while
40+
# the script here extend time to an expected expiration time.
41+
# KEYS[1] - lock name
42+
# ARGS[1] - token
43+
# ARGS[2] - additional milliseconds
44+
# return 1 if the locks time was extended, otherwise 0
45+
LUA_EXTEND_TO_SCRIPT = """
46+
local token = redis.call('get', KEYS[1])
47+
if not token or token ~= ARGV[1] then
48+
return 0
49+
end
50+
local expiration = redis.call('pttl', KEYS[1])
51+
if not expiration then
52+
expiration = 0
53+
end
54+
if expiration < 0 then
55+
return 0
56+
end
57+
redis.call('pexpire', KEYS[1], ARGV[2])
58+
return 1
59+
"""
60+
3161
logger = get_logger('fba.schedulers')
3262

3363

@@ -188,21 +218,12 @@ async def to_model_schedule(name: str, task: str, schedule: schedules.schedule |
188218
if not obj:
189219
obj = TaskScheduler(**CreateTaskSchedulerParam(task=task, **spec).model_dump())
190220
elif isinstance(schedule, schedules.crontab):
191-
crontab_minute = schedule._orig_minute if crontab_verify('m', schedule._orig_minute, False) else '*'
192-
crontab_hour = schedule._orig_hour if crontab_verify('h', schedule._orig_hour, False) else '*'
193-
crontab_day_of_week = (
194-
schedule._orig_day_of_week if crontab_verify('dom', schedule._orig_day_of_week, False) else '*'
195-
)
196-
crontab_day_of_month = (
197-
schedule._orig_day_of_month if crontab_verify('dom', schedule._orig_day_of_month, False) else '*'
198-
)
199-
crontab_month_of_year = (
200-
schedule._orig_month_of_year if crontab_verify('moy', schedule._orig_month_of_year, False) else '*'
201-
)
221+
crontab = f'{schedule._orig_minute} {schedule._orig_hour} {schedule._orig_day_of_week} {schedule._orig_day_of_month} {schedule._orig_month_of_year}' # noqa: E501
222+
crontab_verify(crontab)
202223
spec = {
203224
'name': name,
204225
'type': TaskSchedulerType.CRONTAB.value,
205-
'crontab': f'{crontab_minute} {crontab_hour} {crontab_day_of_week} {crontab_day_of_month} {crontab_month_of_year}', # noqa: E501
226+
'crontab': crontab,
206227
}
207228
stmt = select(TaskScheduler).filter_by(**spec)
208229
query = await db.execute(stmt)
@@ -269,13 +290,18 @@ def _unpack_options(
269290

270291

271292
class DatabaseScheduler(Scheduler):
293+
"""数据库调度程序"""
294+
272295
Entry = ModelEntry
273296

274297
_schedule = None
275298
_last_update = None
276299
_initial_read = True
277300
_heap_invalidated = False
278301

302+
lock = None
303+
lock_key = f'{settings.CELERY_REDIS_PREFIX}:beat_lock'
304+
279305
def __init__(self, *args, **kwargs):
280306
self.app = kwargs['app']
281307
self._dirty = set()
@@ -324,6 +350,16 @@ def reserve(self, entry):
324350
self._dirty.add(new_entry.name)
325351
return new_entry
326352

353+
def close(self):
354+
"""重写父函数"""
355+
if self.lock:
356+
logger.info('beat: Releasing lock')
357+
if run_await(self.lock.owned)():
358+
run_await(self.lock.release)()
359+
self.lock = None
360+
361+
self.sync()
362+
327363
def sync(self):
328364
"""重写父函数"""
329365
_tried = set()
@@ -410,3 +446,29 @@ def schedule(self) -> dict[str, ModelEntry]:
410446

411447
# logger.debug(self._schedule)
412448
return self._schedule
449+
450+
451+
@beat_init.connect
452+
def acquire_distributed_beat_lock(sender=None, *args, **kwargs):
453+
"""
454+
尝试在启动时获取锁
455+
456+
:param sender: 接收方应响应的发送方
457+
:return:
458+
"""
459+
scheduler = sender.scheduler
460+
if not scheduler.lock_key:
461+
return
462+
463+
logger.debug('beat: Acquiring lock...')
464+
lock = redis_client.lock(
465+
scheduler.lock_key,
466+
timeout=DEFAULT_MAX_LOCK,
467+
sleep=scheduler.max_interval,
468+
)
469+
# overwrite redis-py's extend script
470+
# which will add additional timeout instead of extend to a new timeout
471+
lock.lua_extend = redis_client.register_script(LUA_EXTEND_TO_SCRIPT)
472+
run_await(lock.acquire)()
473+
logger.info('beat: Acquired lock')
474+
scheduler.lock = lock

backend/app/task/utils/tzcrontab.py

Lines changed: 11 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#!/usr/bin/env python3
22
# -*- coding: utf-8 -*-
33
from datetime import datetime
4-
from typing import Literal
54

65
from celery import schedules
76
from celery.schedules import ParseException, crontab_parser
@@ -53,34 +52,22 @@ def __reduce__(self) -> tuple[type, tuple[str, str, str, str, str], None]:
5352
)
5453

5554

56-
def crontab_verify(filed: Literal['m', 'h', 'dow', 'dom', 'moy'], value: str, raise_exc: bool = True) -> bool:
55+
def crontab_verify(crontab: str) -> None:
5756
"""
5857
验证 Celery crontab 表达式
5958
60-
:param filed: 验证的字段
61-
:param value: 验证的值
62-
:param raise_exc: 是否抛出异常
59+
:param crontab: 计划表达式
6360
:return:
6461
"""
65-
valid = True
62+
crontab_split = crontab.split(' ')
63+
if len(crontab_split) != 5:
64+
raise errors.RequestError(msg='Crontab 表达式非法')
6665

6766
try:
68-
match filed:
69-
case 'm':
70-
crontab_parser(60, 0).parse(value)
71-
case 'h':
72-
crontab_parser(24, 0).parse(value)
73-
case 'dow':
74-
crontab_parser(7, 0).parse(value)
75-
case 'dom':
76-
crontab_parser(31, 1).parse(value)
77-
case 'moy':
78-
crontab_parser(12, 1).parse(value)
79-
case _:
80-
raise errors.ServerError(msg=f'无效字段:{filed}')
67+
crontab_parser(60, 0).parse(crontab_split[0]) # minute
68+
crontab_parser(24, 0).parse(crontab_split[1]) # hour
69+
crontab_parser(7, 0).parse(crontab_split[2]) # day_of_week
70+
crontab_parser(31, 1).parse(crontab_split[3]) # day_of_month
71+
crontab_parser(12, 1).parse(crontab_split[4]) # month_of_year
8172
except ParseException:
82-
valid = False
83-
if raise_exc:
84-
raise errors.RequestError(msg=f'crontab 值 {value} 非法')
85-
86-
return valid
73+
raise errors.RequestError(msg='Crontab 表达式非法')

0 commit comments

Comments
 (0)