Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 88 additions & 7 deletions lib/dl_task_processor/dl_task_processor/arq_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import functools
import logging
from typing import (
Any,
Expand All @@ -8,15 +9,18 @@
runtime_checkable,
)

from arq import (
create_pool,
cron,
)
from arq import ArqRedis
from arq import Retry as ArqRetry
from arq import cron
from arq.connections import RedisSettings as ArqRedisSettings
from arq.constants import (
default_queue_name,
expires_extra_ms,
)
from arq.cron import CronJob as _CronJob
import attr
from redis import RedisError
from redis.asyncio import Sentinel

from dl_configs.enums import RedisMode
from dl_task_processor.executor import (
Expand All @@ -30,10 +34,11 @@
)


EXECUTOR_KEY = "bi_executor"

LOGGER = logging.getLogger(__name__)

EXECUTOR_KEY = "bi_executor"
SOCKET_TIMEOUT = 3 # timeout for a single Redis command


CronTask: TypeAlias = _CronJob

Expand All @@ -58,8 +63,82 @@ async def __call__(self, ctx: dict[Any, Any], *args: Any, **kwargs: Any) -> Any:


async def create_redis_pool(settings: ArqRedisSettings) -> ArqRedis:
    """
    Create an ``ArqRedis`` connection pool for an arq worker.

    Note: mostly copy-n-paste from `arq.create_pool`.
    The only meaningful change to the upstream version is passing additional
    timeouts (``socket_timeout``) to `Redis` & `Sentinel`, so that a single
    hung Redis command cannot block the worker indefinitely.

    :param settings: arq Redis connection settings; ``host`` must be a list of
        ``(host, port)`` pairs when ``sentinel`` is true, a plain host string otherwise.
    :return: a pinged, ready-to-use ``ArqRedis`` pool.
    :raises ValueError: if ``host`` is a plain string while ``sentinel`` is set.
    :raises RedisError/OSError/asyncio.TimeoutError: if the connection still
        fails after ``settings.conn_retries`` attempts.
    """

    LOGGER.info("Creating redis pool for an arq worker on %s, db=%s", settings.host, settings.database)

    # Explicit check instead of `assert`: asserts are stripped under `python -O`,
    # and a misconfigured host/sentinel combination must always be rejected.
    if type(settings.host) is str and settings.sentinel:
        raise ValueError("str provided for 'host' but 'sentinel' is true; list of sentinels expected")

    if settings.sentinel:

        def pool_factory(db: int, username: str | None, password: str | None, encoding: str) -> ArqRedis:
            client = Sentinel(
                sentinels=settings.host,
                ssl=settings.ssl,
                socket_connect_timeout=settings.conn_timeout,
                socket_timeout=SOCKET_TIMEOUT,  # upstream arq does not set this
                db=db,
                username=username,
                password=password,
                encoding=encoding,
            )
            # Resolve the current master through the sentinels and wrap it as ArqRedis.
            return client.master_for(settings.sentinel_master, redis_class=ArqRedis)

    else:
        pool_factory = functools.partial(
            ArqRedis,
            host=settings.host,
            port=settings.port,
            unix_socket_path=settings.unix_socket_path,
            socket_connect_timeout=settings.conn_timeout,
            socket_timeout=SOCKET_TIMEOUT,  # upstream arq does not set this
            ssl=settings.ssl,
            ssl_keyfile=settings.ssl_keyfile,
            ssl_certfile=settings.ssl_certfile,
            ssl_cert_reqs=settings.ssl_cert_reqs,
            ssl_ca_certs=settings.ssl_ca_certs,
            ssl_ca_data=settings.ssl_ca_data,
            ssl_check_hostname=settings.ssl_check_hostname,
        )

    # Retry the initial connection up to `conn_retries` times (mirrors `arq.create_pool`).
    retry = 0
    while True:
        try:
            pool = pool_factory(
                db=settings.database,
                username=settings.username,
                password=settings.password,
                encoding="utf8",
            )
            # Attributes `arq` expects on the pool object; defaults match upstream.
            pool.job_serializer = None
            pool.job_deserializer = None
            pool.default_queue_name = default_queue_name
            pool.expires_extra_ms = expires_extra_ms
            await pool.ping()

        except (ConnectionError, OSError, RedisError, asyncio.TimeoutError) as e:
            if retry < settings.conn_retries:
                LOGGER.warning(
                    "redis connection error %s:%s %s %s, %d retries remaining...",
                    settings.host,
                    settings.port,
                    e.__class__.__name__,
                    e,
                    settings.conn_retries - retry,
                )
                await asyncio.sleep(settings.conn_retry_delay)
                retry += 1
            else:
                # Out of retries: propagate the last connection error to the caller.
                raise
        else:
            if retry > 0:
                LOGGER.info("redis connection successful")
            return pool


@runtime_checkable
Expand All @@ -86,6 +165,7 @@ def create_arq_redis_settings(settings: _BIRedisSettings) -> ArqRedisSettings:
password=settings.PASSWORD,
database=settings.DB,
ssl=settings.SSL or False,
conn_timeout=3,
)
elif settings.MODE == RedisMode.sentinel:
redis_targets = [(host, settings.PORT) for host in settings.HOSTS]
Expand All @@ -96,6 +176,7 @@ def create_arq_redis_settings(settings: _BIRedisSettings) -> ArqRedisSettings:
sentinel_master=settings.CLUSTER_NAME,
database=settings.DB,
ssl=settings.SSL or False,
conn_timeout=3,
)
else:
raise ValueError(f"Unknown redis mode {settings.MODE}")
Expand Down
Loading