diff --git a/backend/app/task/celery.py b/backend/app/task/celery.py index e86e29912..8e9ffe1a0 100644 --- a/backend/app/task/celery.py +++ b/backend/app/task/celery.py @@ -43,7 +43,8 @@ def init_celery() -> celery.Celery: 'group': OVERWRITE_CELERY_RESULT_GROUP_TABLE_NAME, }, result_extended=True, - # result_expires=0, # 任务结果自动清理,0 或 None 表示不清理 + # result_expires=0, # 清理任务结果,默认每天凌晨 4 点,0 或 None 表示不清理 + # beat_sync_every=1, # 保存任务状态周期,默认 3 * 60 秒 beat_schedule=LOCAL_BEAT_SCHEDULE, beat_scheduler='backend.app.task.utils.schedulers:DatabaseScheduler', task_cls='backend.app.task.tasks.base:TaskBase', diff --git a/backend/app/task/utils/schedulers.py b/backend/app/task/utils/schedulers.py index 5644fe59f..80a6d4808 100644 --- a/backend/app/task/utils/schedulers.py +++ b/backend/app/task/utils/schedulers.py @@ -9,7 +9,9 @@ from celery import current_app, schedules from celery.beat import ScheduleEntry, Scheduler +from celery.signals import beat_init from celery.utils.log import get_logger +from redis.asyncio.lock import Lock from sqlalchemy import select from sqlalchemy.exc import DatabaseError, InterfaceError @@ -28,6 +30,38 @@ # 此计划程序必须比常规的 5 分钟更频繁地唤醒,因为它需要考虑对计划的外部更改 DEFAULT_MAX_INTERVAL = 5 # seconds +# 计划锁时长,避免重复创建 +DEFAULT_MAX_LOCK_TIMEOUT = 300 # seconds + +# 锁检测周期,应小于计划锁时长 +DEFAULT_LOCK_INTERVAL = 60 # seconds + +# Copied from: +# https://github.com/andymccurdy/redis-py/blob/master/redis/lock.py#L33 +# Changes: +# The second line from the bottom: The original Lua script intends +# to extend time to (lock remaining time + additional time); while +# the script here extend time to an expected expiration time. +# KEYS[1] - lock name +# ARGS[1] - token +# ARGS[2] - additional milliseconds +# return 1 if the locks time was extended, otherwise 0 +LUA_EXTEND_TO_SCRIPT = """ + local token = redis.call('get', KEYS[1]) + if not token or token ~= ARGV[1] then + return 0 + end + local expiration = redis.call('pttl', KEYS[1]) + if not expiration then + expiration = 0 + end + if expiration < 0 then + return 0 + end + redis.call('pexpire', KEYS[1], ARGV[2]) + return 1 +""" + logger = get_logger('fba.schedulers') @@ -260,6 +294,8 @@ def _unpack_options( class DatabaseScheduler(Scheduler): + """数据库调度程序""" + Entry = ModelEntry _schedule = None @@ -267,6 +303,9 @@ class DatabaseScheduler(Scheduler): _initial_read = True _heap_invalidated = False + lock: Lock | None = None + lock_key = f'{settings.CELERY_REDIS_PREFIX}:beat_lock' + def __init__(self, *args, **kwargs): self.app = kwargs['app'] self._dirty = set() @@ -315,6 +354,16 @@ def reserve(self, entry): self._dirty.add(new_entry.name) return new_entry + def close(self): + """重写父函数""" + if self.lock: + logger.info('beat: Releasing lock') + if run_await(self.lock.owned)(): + run_await(self.lock.release)() + self.lock = None + + super().close() + def sync(self): """重写父函数""" _tried = set() @@ -401,3 +450,48 @@ def schedule(self) -> dict[str, ModelEntry]: # logger.debug(self._schedule) return self._schedule + + +async def extend_scheduler_lock(lock): + """ + 延长调度程序锁 + + :param lock: 计划程序锁 + :return: + """ + while True: + await asyncio.sleep(DEFAULT_LOCK_INTERVAL) + if lock: + try: + await lock.extend(DEFAULT_MAX_LOCK_TIMEOUT) + except Exception as e: + logger.error(f'Failed to extend lock: {e}') + + +@beat_init.connect +def acquire_distributed_beat_lock(sender=None, *args, **kwargs): + """ + 尝试在启动时获取锁 + + :param sender: 接收方应响应的发送方 + :return: + """ + scheduler = sender.scheduler + if not scheduler.lock_key: + return + + logger.debug('beat: Acquiring lock...') + lock = redis_client.lock( + scheduler.lock_key, + timeout=DEFAULT_MAX_LOCK_TIMEOUT, + sleep=scheduler.max_interval, + ) + # overwrite redis-py's extend script + # which will add additional timeout instead of extend to a new timeout + lock.lua_extend = redis_client.register_script(LUA_EXTEND_TO_SCRIPT) + run_await(lock.acquire)() + logger.info('beat: Acquired lock') + scheduler.lock = lock + + loop = asyncio.get_event_loop() + loop.create_task(extend_scheduler_lock(scheduler.lock)) diff --git a/backend/utils/_await.py b/backend/utils/_await.py index ae61d9767..e1950fdea 100644 --- a/backend/utils/_await.py +++ b/backend/utils/_await.py @@ -1,17 +1,16 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- import asyncio import atexit import threading import weakref -from typing import Awaitable, Callable, TypeVar +from functools import wraps +from typing import Any, Awaitable, Callable, Coroutine, TypeVar T = TypeVar('T') class _TaskRunner: - """A task runner that runs an asyncio event loop on a background thread.""" + """在后台线程上运行 asyncio 事件循环的任务运行器""" def __init__(self): self.__loop: asyncio.AbstractEventLoop | None = None @@ -20,48 +19,55 @@ def __init__(self): atexit.register(self.close) def close(self): - """关闭事件循环""" + """关闭事件循环并清理""" if self.__loop: self.__loop.stop() + self.__loop = None + if self.__thread: + self.__thread.join() + self.__thread = None + name = f'TaskRunner-{threading.get_ident()}' + _runner_map.pop(name, None) def _target(self): - """后台线程目标""" - loop = self.__loop + """后台线程的目标函数""" try: - loop.run_forever() + self.__loop.run_forever() finally: - loop.close() + self.__loop.close() - def run(self, coro): - """在后台线程上同步运行协程""" + def run(self, coro: Awaitable[T]) -> T: + """在后台事件循环上运行协程并返回其结果""" with self.__lock: - name = f'{threading.current_thread().name} - runner' + name = f'TaskRunner-{threading.get_ident()}' if self.__loop is None: self.__loop = asyncio.new_event_loop() self.__thread = threading.Thread(target=self._target, daemon=True, name=name) self.__thread.start() - fut = asyncio.run_coroutine_threadsafe(coro, self.__loop) - return fut.result(None) + future = asyncio.run_coroutine_threadsafe(coro, self.__loop) + return future.result() _runner_map = weakref.WeakValueDictionary() -def run_await(coro: Callable[..., Awaitable[T]]) -> Callable[..., T]: - """将协程包装在一个函数中,该函数会阻塞,直到它执行完为止""" +def run_await(coro: Callable[..., Awaitable[T]] | Callable[..., Coroutine[Any, Any, T]]) -> Callable[..., T]: + """将协程包装在函数中,该函数将在后台事件循环上运行,直到它执行完为止""" + @wraps(coro) def wrapped(*args, **kwargs): - name = threading.current_thread().name inner = coro(*args, **kwargs) + if not asyncio.iscoroutine(inner) and not asyncio.isfuture(inner): + raise TypeError(f'Expected coroutine, got {type(inner)}') try: - # 如果当前此线程中正在运行循环 - # 使用任务运行程序 + # 如果事件循环正在运行,则使用任务调用 asyncio.get_running_loop() + name = f'TaskRunner-{threading.get_ident()}' if name not in _runner_map: _runner_map[name] = _TaskRunner() return _runner_map[name].run(inner) except RuntimeError: - # 如果没有,请创建一个新的事件循环 + # 如果没有,则创建一个新的事件循环 loop = asyncio.get_event_loop() return loop.run_until_complete(inner)