3 changes: 2 additions & 1 deletion backend/app/task/celery.py
@@ -43,7 +43,8 @@ def init_celery() -> celery.Celery:
'group': OVERWRITE_CELERY_RESULT_GROUP_TABLE_NAME,
},
result_extended=True,
# result_expires=0, # Auto-clean task results; 0 or None means never clean up
# result_expires=0, # Clean up task results (daily at 4 a.m. by default); 0 or None means never clean up
# beat_sync_every=1, # Interval for persisting task state; defaults to 3 * 60 seconds
beat_schedule=LOCAL_BEAT_SCHEDULE,
beat_scheduler='backend.app.task.utils.schedulers:DatabaseScheduler',
task_cls='backend.app.task.tasks.base:TaskBase',
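
For context on the updated comment: when result_expires is set, Celery prunes expired results through its built-in celery.backend_cleanup beat task, which by default fires daily at 4 a.m. A minimal sketch of enabling cleanup with an explicit retention window (the seven-day value is illustrative, not from this PR):

from datetime import timedelta

from celery import Celery

app = Celery('tasks')
app.conf.update(
    # results older than 7 days are pruned by celery.backend_cleanup
    result_expires=timedelta(days=7),
)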
94 changes: 94 additions & 0 deletions backend/app/task/utils/schedulers.py
@@ -9,7 +9,9 @@

from celery import current_app, schedules
from celery.beat import ScheduleEntry, Scheduler
from celery.signals import beat_init
from celery.utils.log import get_logger
from redis.asyncio.lock import Lock
from sqlalchemy import select
from sqlalchemy.exc import DatabaseError, InterfaceError

@@ -28,6 +30,38 @@
# This scheduler must wake up more often than the regular 5 minutes, because it has to account for external changes to the schedule
DEFAULT_MAX_INTERVAL = 5 # seconds

# Schedule lock duration, to avoid duplicate creation
DEFAULT_MAX_LOCK_TIMEOUT = 300 # seconds

# Lock check interval; should be shorter than the lock duration
DEFAULT_LOCK_INTERVAL = 60 # seconds

# Copied from:
# https://github.com/andymccurdy/redis-py/blob/master/redis/lock.py#L33
# Changes:
# The second line from the bottom: the original Lua script extends
# the TTL to (remaining lock time + additional time), while the
# script here extends the TTL to an expected expiration time.
# KEYS[1] - lock name
# ARGV[1] - token
# ARGV[2] - additional milliseconds
# return 1 if the lock's time was extended, otherwise 0
LUA_EXTEND_TO_SCRIPT = """
local token = redis.call('get', KEYS[1])
if not token or token ~= ARGV[1] then
return 0
end
local expiration = redis.call('pttl', KEYS[1])
if not expiration then
expiration = 0
end
if expiration < 0 then
return 0
end
redis.call('pexpire', KEYS[1], ARGV[2])
return 1
"""

logger = get_logger('fba.schedulers')


@@ -260,13 +294,18 @@ def _unpack_options(


class DatabaseScheduler(Scheduler):
"""数据库调度程序"""

Entry = ModelEntry

_schedule = None
_last_update = None
_initial_read = True
_heap_invalidated = False

lock: Lock | None = None
lock_key = f'{settings.CELERY_REDIS_PREFIX}:beat_lock'

def __init__(self, *args, **kwargs):
self.app = kwargs['app']
self._dirty = set()
@@ -315,6 +354,16 @@ def reserve(self, entry):
self._dirty.add(new_entry.name)
return new_entry

def close(self):
"""重写父函数"""
if self.lock:
logger.info('beat: Releasing lock')
if run_await(self.lock.owned)():
run_await(self.lock.release)()
self.lock = None

super().close()

def sync(self):
"""重写父函数"""
_tried = set()
@@ -401,3 +450,48 @@ def schedule(self) -> dict[str, ModelEntry]:

# logger.debug(self._schedule)
return self._schedule


async def extend_scheduler_lock(lock):
"""
Extend the scheduler lock

:param lock: the scheduler lock
:return:
"""
while True:
await asyncio.sleep(DEFAULT_LOCK_INTERVAL)
if lock:
try:
await lock.extend(DEFAULT_MAX_LOCK_TIMEOUT)
except Exception as e:
logger.error(f'Failed to extend lock: {e}')


@beat_init.connect
def acquire_distributed_beat_lock(sender=None, *args, **kwargs):
"""
Try to acquire the lock at startup

:param sender: the sender that the receiver should respond to
:return:
"""
scheduler = sender.scheduler
if not scheduler.lock_key:
return

logger.debug('beat: Acquiring lock...')
lock = redis_client.lock(
scheduler.lock_key,
timeout=DEFAULT_MAX_LOCK_TIMEOUT,
sleep=scheduler.max_interval,
)
# Overwrite redis-py's extend script, which adds the additional timeout
# to the remaining TTL instead of extending to a fixed new timeout
lock.lua_extend = redis_client.register_script(LUA_EXTEND_TO_SCRIPT)
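# acquire() blocks here while another beat instance holds the lock, so extra beat processes wait as standbys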
run_await(lock.acquire)()
logger.info('beat: Acquired lock')
scheduler.lock = lock

loop = asyncio.get_event_loop()
loop.create_task(extend_scheduler_lock(scheduler.lock))
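
Taken together: every beat process runs this handler at startup, but only one acquire() returns immediately; the others block inside acquire(), retrying every max_interval seconds, and take over scheduling only once the holder releases the lock or dies and lets it expire. A rough illustration of that contention with two synchronous clients (names and local Redis server are illustrative, not part of this PR):

import redis

holder, standby = redis.Redis(), redis.Redis()

lock_a = holder.lock('demo:beat_lock', timeout=300, sleep=5)
lock_b = standby.lock('demo:beat_lock', timeout=300, sleep=5)

assert lock_a.acquire(blocking=False) is True   # first beat instance becomes the active scheduler
assert lock_b.acquire(blocking=False) is False  # a second instance cannot take the lock while it is held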
46 changes: 26 additions & 20 deletions backend/utils/_await.py
@@ -1,17 +1,16 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import asyncio
import atexit
import threading
import weakref

from typing import Awaitable, Callable, TypeVar
from functools import wraps
from typing import Any, Awaitable, Callable, Coroutine, TypeVar

T = TypeVar('T')


class _TaskRunner:
"""A task runner that runs an asyncio event loop on a background thread."""
"""在后台线程上运行 asyncio 事件循环的任务运行器"""

def __init__(self):
self.__loop: asyncio.AbstractEventLoop | None = None
@@ -20,48 +19,55 @@ def __init__(self):
atexit.register(self.close)

def close(self):
"""关闭事件循环"""
"""关闭事件循环并清理"""
if self.__loop:
self.__loop.stop()
self.__loop = None
if self.__thread:
self.__thread.join()
self.__thread = None
name = f'TaskRunner-{threading.get_ident()}'
_runner_map.pop(name, None)

def _target(self):
"""后台线程目标"""
loop = self.__loop
"""后台线程的目标函数"""
try:
loop.run_forever()
self.__loop.run_forever()
finally:
loop.close()
self.__loop.close()

def run(self, coro):
"""在后台线程上同步运行协程"""
def run(self, coro: Awaitable[T]) -> T:
"""在后台事件循环上运行协程并返回其结果"""
with self.__lock:
name = f'{threading.current_thread().name} - runner'
name = f'TaskRunner-{threading.get_ident()}'
if self.__loop is None:
self.__loop = asyncio.new_event_loop()
self.__thread = threading.Thread(target=self._target, daemon=True, name=name)
self.__thread.start()
fut = asyncio.run_coroutine_threadsafe(coro, self.__loop)
return fut.result(None)
future = asyncio.run_coroutine_threadsafe(coro, self.__loop)
return future.result()


_runner_map = weakref.WeakValueDictionary()


def run_await(coro: Callable[..., Awaitable[T]]) -> Callable[..., T]:
"""将协程包装在一个函数中,该函数会阻塞,直到它执行完为止"""
def run_await(coro: Callable[..., Awaitable[T]] | Callable[..., Coroutine[Any, Any, T]]) -> Callable[..., T]:
"""将协程包装在函数中,该函数将在后台事件循环上运行,直到它执行完为止"""

@wraps(coro)
def wrapped(*args, **kwargs):
name = threading.current_thread().name
inner = coro(*args, **kwargs)
if not asyncio.iscoroutine(inner) and not asyncio.isfuture(inner):
raise TypeError(f'Expected coroutine, got {type(inner)}')
try:
# If a loop is currently running in this thread,
# use the task runner
# If an event loop is already running, dispatch via the task runner
asyncio.get_running_loop()
name = f'TaskRunner-{threading.get_ident()}'
if name not in _runner_map:
_runner_map[name] = _TaskRunner()
return _runner_map[name].run(inner)
except RuntimeError:
# If there is none, create a new event loop
# If not, create a new event loop
loop = asyncio.get_event_loop()
return loop.run_until_complete(inner)
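
As used by schedulers.py above, run_await lets synchronous celery-beat code drive the async Redis client. A minimal usage sketch (the coroutine is illustrative; the import path is assumed from this PR's file layout):

import asyncio

from backend.utils._await import run_await

async def add(x: int, y: int) -> int:
    await asyncio.sleep(0.1)
    return x + y

result = run_await(add)(1, 2)  # blocks until the coroutine finishes, even inside a running loop
assert result == 3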
