From 80418f93188ddce0d972c3700ee3ab7333705688 Mon Sep 17 00:00:00 2001 From: Wu Clan Date: Fri, 18 Jul 2025 21:05:04 +0800 Subject: [PATCH 1/3] Add distributed lock for scheduled task --- backend/app/task/service/scheduler_service.py | 18 +--- backend/app/task/utils/schedulers.py | 86 ++++++++++++++++--- backend/app/task/utils/tzcrontab.py | 35 +++----- 3 files changed, 87 insertions(+), 52 deletions(-) diff --git a/backend/app/task/service/scheduler_service.py b/backend/app/task/service/scheduler_service.py index 8c472101..5f474296 100644 --- a/backend/app/task/service/scheduler_service.py +++ b/backend/app/task/service/scheduler_service.py @@ -65,14 +65,7 @@ async def create(*, obj: CreateTaskSchedulerParam) -> None: if task_scheduler: raise errors.ConflictError(msg='任务调度已存在') if obj.type == TaskSchedulerType.CRONTAB: - crontab_split = obj.crontab.split(' ') - if len(crontab_split) != 5: - raise errors.RequestError(msg='Crontab 表达式非法') - crontab_verify('m', crontab_split[0]) - crontab_verify('h', crontab_split[1]) - crontab_verify('dow', crontab_split[2]) - crontab_verify('dom', crontab_split[3]) - crontab_verify('moy', crontab_split[4]) + crontab_verify(obj.crontab) await task_scheduler_dao.create(db, obj) @staticmethod @@ -92,14 +85,7 @@ async def update(*, pk: int, obj: UpdateTaskSchedulerParam) -> int: if await task_scheduler_dao.get_by_name(db, obj.name): raise errors.ConflictError(msg='任务调度已存在') if task_scheduler.type == TaskSchedulerType.CRONTAB: - crontab_split = obj.crontab.split(' ') - if len(crontab_split) != 5: - raise errors.RequestError(msg='Crontab 表达式非法') - crontab_verify('m', crontab_split[0]) - crontab_verify('h', crontab_split[1]) - crontab_verify('dow', crontab_split[2]) - crontab_verify('dom', crontab_split[3]) - crontab_verify('moy', crontab_split[4]) + crontab_verify(obj.crontab) count = await task_scheduler_dao.update(db, pk, obj) return count diff --git a/backend/app/task/utils/schedulers.py b/backend/app/task/utils/schedulers.py index 
b6a1119f..14b705c4 100644 --- a/backend/app/task/utils/schedulers.py +++ b/backend/app/task/utils/schedulers.py @@ -9,6 +9,7 @@ from celery import current_app, schedules from celery.beat import ScheduleEntry, Scheduler +from celery.signals import beat_init from celery.utils.log import get_logger from sqlalchemy import select from sqlalchemy.exc import DatabaseError, InterfaceError @@ -28,6 +29,35 @@ # 此计划程序必须比常规的 5 分钟更频繁地唤醒,因为它需要考虑对计划的外部更改 DEFAULT_MAX_INTERVAL = 5 # seconds +# 计划锁定时长,避免重复运行 +DEFAULT_MAX_LOCK = 300 # seconds + +# Copied from: +# https://github.com/andymccurdy/redis-py/blob/master/redis/lock.py#L33 +# Changes: +# The second line from the bottom: The original Lua script intends +# to extend time to (lock remaining time + additional time); while +# the script here extends the time to an expected expiration time. +# KEYS[1] - lock name +# ARGV[1] - token +# ARGV[2] - new expiration time in milliseconds +# return 1 if the lock's time was extended, otherwise 0 +LUA_EXTEND_TO_SCRIPT = """ + local token = redis.call('get', KEYS[1]) + if not token or token ~= ARGV[1] then + return 0 + end + local expiration = redis.call('pttl', KEYS[1]) + if not expiration then + expiration = 0 + end + if expiration < 0 then + return 0 + end + redis.call('pexpire', KEYS[1], ARGV[2]) + return 1 +""" + logger = get_logger('fba.schedulers') @@ -188,21 +218,12 @@ async def to_model_schedule(name: str, task: str, schedule: schedules.schedule | if not obj: obj = TaskScheduler(**CreateTaskSchedulerParam(task=task, **spec).model_dump()) elif isinstance(schedule, schedules.crontab): - crontab_minute = schedule._orig_minute if crontab_verify('m', schedule._orig_minute, False) else '*' - crontab_hour = schedule._orig_hour if crontab_verify('h', schedule._orig_hour, False) else '*' - crontab_day_of_week = ( - schedule._orig_day_of_week if crontab_verify('dom', schedule._orig_day_of_week, False) else '*' - ) - crontab_day_of_month = ( - schedule._orig_day_of_month if crontab_verify('dom', 
schedule._orig_day_of_month, False) else '*' - ) - crontab_month_of_year = ( - schedule._orig_month_of_year if crontab_verify('moy', schedule._orig_month_of_year, False) else '*' - ) + crontab = f'{schedule._orig_minute} {schedule._orig_hour} {schedule._orig_day_of_week} {schedule._orig_day_of_month} {schedule._orig_month_of_year}' # noqa: E501 + crontab_verify(crontab) spec = { 'name': name, 'type': TaskSchedulerType.CRONTAB.value, - 'crontab': f'{crontab_minute} {crontab_hour} {crontab_day_of_week} {crontab_day_of_month} {crontab_month_of_year}', # noqa: E501 + 'crontab': crontab, } stmt = select(TaskScheduler).filter_by(**spec) query = await db.execute(stmt) @@ -269,6 +290,8 @@ def _unpack_options( class DatabaseScheduler(Scheduler): + """数据库调度程序""" + Entry = ModelEntry _schedule = None @@ -276,6 +299,9 @@ class DatabaseScheduler(Scheduler): _initial_read = True _heap_invalidated = False + lock = None + lock_key = f'{settings.CELERY_REDIS_PREFIX}:beat_lock' + def __init__(self, *args, **kwargs): self.app = kwargs['app'] self._dirty = set() @@ -324,6 +350,16 @@ def reserve(self, entry): self._dirty.add(new_entry.name) return new_entry + def close(self): + """重写父函数""" + if self.lock: + logger.info('beat: Releasing lock') + if run_await(self.lock.owned)(): + run_await(self.lock.release)() + self.lock = None + + self.sync() + def sync(self): """重写父函数""" _tried = set() @@ -410,3 +446,29 @@ def schedule(self) -> dict[str, ModelEntry]: # logger.debug(self._schedule) return self._schedule + + +@beat_init.connect +def acquire_distributed_beat_lock(sender=None, *args, **kwargs): + """ + 尝试在启动时获取锁 + + :param sender: 接收方应响应的发送方 + :return: + """ + scheduler = sender.scheduler + if not scheduler.lock_key: + return + + logger.debug('beat: Acquiring lock...') + lock = redis_client.lock( + scheduler.lock_key, + timeout=DEFAULT_MAX_LOCK, + sleep=scheduler.max_interval, + ) + # override redis-py's extend script, + # which adds to the remaining TTL instead of setting a new 
timeout + lock.lua_extend = redis_client.register_script(LUA_EXTEND_TO_SCRIPT) + run_await(lock.acquire)() + logger.info('beat: Acquired lock') + scheduler.lock = lock diff --git a/backend/app/task/utils/tzcrontab.py b/backend/app/task/utils/tzcrontab.py index 9a995e2c..2260efc1 100644 --- a/backend/app/task/utils/tzcrontab.py +++ b/backend/app/task/utils/tzcrontab.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- from datetime import datetime -from typing import Literal from celery import schedules from celery.schedules import ParseException, crontab_parser @@ -53,34 +52,22 @@ def __reduce__(self) -> tuple[type, tuple[str, str, str, str, str], None]: ) -def crontab_verify(filed: Literal['m', 'h', 'dow', 'dom', 'moy'], value: str, raise_exc: bool = True) -> bool: +def crontab_verify(crontab: str) -> None: """ 验证 Celery crontab 表达式 - :param filed: 验证的字段 - :param value: 验证的值 - :param raise_exc: 是否抛出异常 + :param crontab: 计划表达式 :return: """ - valid = True + crontab_split = crontab.split(' ') + if len(crontab_split) != 5: + raise errors.RequestError(msg='Crontab 表达式非法') try: - match filed: - case 'm': - crontab_parser(60, 0).parse(value) - case 'h': - crontab_parser(24, 0).parse(value) - case 'dow': - crontab_parser(7, 0).parse(value) - case 'dom': - crontab_parser(31, 1).parse(value) - case 'moy': - crontab_parser(12, 1).parse(value) - case _: - raise errors.ServerError(msg=f'无效字段:{filed}') + crontab_parser(60, 0).parse(crontab_split[0]) # minute + crontab_parser(24, 0).parse(crontab_split[1]) # hour + crontab_parser(7, 0).parse(crontab_split[2]) # day_of_week + crontab_parser(31, 1).parse(crontab_split[3]) # day_of_month + crontab_parser(12, 1).parse(crontab_split[4]) # month_of_year except ParseException: - valid = False - if raise_exc: - raise errors.RequestError(msg=f'crontab 值 {value} 非法') - - return valid + raise errors.RequestError(msg='Crontab 表达式非法') From 7aa3a66dfdf52155e5fa03b714640935568c6b30 Mon Sep 17 00:00:00 2001 From: Wu Clan Date: Mon, 
21 Jul 2025 12:21:20 +0800 Subject: [PATCH 2/3] Add the task to extend lock --- backend/app/task/celery.py | 3 +- backend/app/task/utils/schedulers.py | 31 ++++++++++++++++--- backend/utils/_await.py | 46 ++++++++++++++++------------ 3 files changed, 55 insertions(+), 25 deletions(-) diff --git a/backend/app/task/celery.py b/backend/app/task/celery.py index e86e2991..8e9ffe1a 100644 --- a/backend/app/task/celery.py +++ b/backend/app/task/celery.py @@ -43,7 +43,8 @@ def init_celery() -> celery.Celery: 'group': OVERWRITE_CELERY_RESULT_GROUP_TABLE_NAME, }, result_extended=True, - # result_expires=0, # 任务结果自动清理,0 或 None 表示不清理 + # result_expires=0, # 清理任务结果,默认每天凌晨 4 点,0 或 None 表示不清理 + # beat_sync_every=1, # 保存任务状态周期,默认 3 * 60 秒 beat_schedule=LOCAL_BEAT_SCHEDULE, beat_scheduler='backend.app.task.utils.schedulers:DatabaseScheduler', task_cls='backend.app.task.tasks.base:TaskBase', diff --git a/backend/app/task/utils/schedulers.py b/backend/app/task/utils/schedulers.py index 14b705c4..ccdd3bcf 100644 --- a/backend/app/task/utils/schedulers.py +++ b/backend/app/task/utils/schedulers.py @@ -11,6 +11,7 @@ from celery.beat import ScheduleEntry, Scheduler from celery.signals import beat_init from celery.utils.log import get_logger +from redis.asyncio.lock import Lock from sqlalchemy import select from sqlalchemy.exc import DatabaseError, InterfaceError @@ -29,8 +30,11 @@ # 此计划程序必须比常规的 5 分钟更频繁地唤醒,因为它需要考虑对计划的外部更改 DEFAULT_MAX_INTERVAL = 5 # seconds -# 计划锁定时长,避免重复运行 -DEFAULT_MAX_LOCK = 300 # seconds +# 计划锁时长,避免重复创建 +DEFAULT_MAX_LOCK_TIMEOUT = 300 # seconds + +# 锁检测周期,应小于计划锁时长 +DEFAULT_LOCK_INTERVAL = 60 # seconds # Copied from: # https://github.com/andymccurdy/redis-py/blob/master/redis/lock.py#L33 @@ -299,7 +303,7 @@ class DatabaseScheduler(Scheduler): _initial_read = True _heap_invalidated = False - lock = None + lock: Lock | None = None lock_key = f'{settings.CELERY_REDIS_PREFIX}:beat_lock' def __init__(self, *args, **kwargs): @@ -448,6 +452,22 @@ def schedule(self) -> dict[str, 
ModelEntry]: return self._schedule +async def extend_scheduler_lock(lock): + """ + 延长调度程序锁 + + :param lock: 计划程序锁 + :return: + """ + while True: + await asyncio.sleep(DEFAULT_LOCK_INTERVAL) + if lock: + try: + await lock.extend(DEFAULT_MAX_LOCK_TIMEOUT) + except Exception as e: + logger.error(f'Failed to extend lock: {e}') + + @beat_init.connect def acquire_distributed_beat_lock(sender=None, *args, **kwargs): """ @@ -463,7 +483,7 @@ def acquire_distributed_beat_lock(sender=None, *args, **kwargs): logger.debug('beat: Acquiring lock...') lock = redis_client.lock( scheduler.lock_key, - timeout=DEFAULT_MAX_LOCK, + timeout=DEFAULT_MAX_LOCK_TIMEOUT, sleep=scheduler.max_interval, ) # overwrite redis-py's extend script @@ -472,3 +492,6 @@ def acquire_distributed_beat_lock(sender=None, *args, **kwargs): run_await(lock.acquire)() logger.info('beat: Acquired lock') scheduler.lock = lock + + loop = asyncio.get_event_loop() + loop.create_task(extend_scheduler_lock(scheduler.lock)) diff --git a/backend/utils/_await.py b/backend/utils/_await.py index ae61d976..e1950fde 100644 --- a/backend/utils/_await.py +++ b/backend/utils/_await.py @@ -1,17 +1,16 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- import asyncio import atexit import threading import weakref -from typing import Awaitable, Callable, TypeVar +from functools import wraps +from typing import Any, Awaitable, Callable, Coroutine, TypeVar T = TypeVar('T') class _TaskRunner: - """A task runner that runs an asyncio event loop on a background thread.""" + """在后台线程上运行 asyncio 事件循环的任务运行器""" def __init__(self): self.__loop: asyncio.AbstractEventLoop | None = None @@ -20,48 +19,55 @@ def __init__(self): atexit.register(self.close) def close(self): - """关闭事件循环""" + """关闭事件循环并清理""" if self.__loop: self.__loop.stop() + self.__loop = None + if self.__thread: + self.__thread.join() + self.__thread = None + name = f'TaskRunner-{threading.get_ident()}' + _runner_map.pop(name, None) def _target(self): - """后台线程目标""" - loop = 
self.__loop + """后台线程的目标函数""" try: - loop.run_forever() + self.__loop.run_forever() finally: - loop.close() + self.__loop.close() - def run(self, coro): - """在后台线程上同步运行协程""" + def run(self, coro: Awaitable[T]) -> T: + """在后台事件循环上运行协程并返回其结果""" with self.__lock: - name = f'{threading.current_thread().name} - runner' + name = f'TaskRunner-{threading.get_ident()}' if self.__loop is None: self.__loop = asyncio.new_event_loop() self.__thread = threading.Thread(target=self._target, daemon=True, name=name) self.__thread.start() - fut = asyncio.run_coroutine_threadsafe(coro, self.__loop) - return fut.result(None) + future = asyncio.run_coroutine_threadsafe(coro, self.__loop) + return future.result() _runner_map = weakref.WeakValueDictionary() -def run_await(coro: Callable[..., Awaitable[T]]) -> Callable[..., T]: - """将协程包装在一个函数中,该函数会阻塞,直到它执行完为止""" +def run_await(coro: Callable[..., Awaitable[T]] | Callable[..., Coroutine[Any, Any, T]]) -> Callable[..., T]: + """将协程包装在函数中,该函数将在后台事件循环上运行,直到它执行完为止""" + @wraps(coro) def wrapped(*args, **kwargs): - name = threading.current_thread().name inner = coro(*args, **kwargs) + if not asyncio.iscoroutine(inner) and not asyncio.isfuture(inner): + raise TypeError(f'Expected coroutine, got {type(inner)}') try: - # 如果当前此线程中正在运行循环 - # 使用任务运行程序 + # 如果事件循环正在运行,则使用任务调用 asyncio.get_running_loop() + name = f'TaskRunner-{threading.get_ident()}' if name not in _runner_map: _runner_map[name] = _TaskRunner() return _runner_map[name].run(inner) except RuntimeError: - # 如果没有,请创建一个新的事件循环 + # 如果没有,则创建一个新的事件循环 loop = asyncio.get_event_loop() return loop.run_until_complete(inner) From 367b83d188c157e419804b0f25698b7e2e6c877a Mon Sep 17 00:00:00 2001 From: Wu Clan Date: Mon, 21 Jul 2025 12:24:15 +0800 Subject: [PATCH 3/3] Fix the close --- backend/app/task/utils/schedulers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/app/task/utils/schedulers.py b/backend/app/task/utils/schedulers.py index ccdd3bcf..80a6d480 100644 --- 
a/backend/app/task/utils/schedulers.py +++ b/backend/app/task/utils/schedulers.py @@ -362,7 +362,7 @@ def close(self): run_await(self.lock.release)() self.lock = None - self.sync() + super().close() def sync(self): """重写父函数"""