 from urllib.parse import urlparse
 from uuid import uuid4
 
-import aioredis
-from aioredis import MultiExecError, Redis
 from pydantic.validators import make_arbitrary_type_validator
+from redis.asyncio import ConnectionPool, Redis
+from redis.asyncio.sentinel import Sentinel
+from redis.exceptions import RedisError, WatchError
 
 from .constants import default_queue_name, job_key_prefix, result_key_prefix
 from .jobs import Deserializer, Job, JobDef, JobResult, Serializer, deserialize_job, serialize_job
@@ -70,20 +71,20 @@ def __repr__(self) -> str:
 expires_extra_ms = 86_400_000
 
 
-class ArqRedis(Redis):  # type: ignore
+class ArqRedis(Redis):  # type: ignore[misc]
     """
-    Thin subclass of ``aioredis.Redis`` which adds :func:`arq.connections.enqueue_job`.
+    Thin subclass of ``redis.asyncio.Redis`` which adds :func:`arq.connections.enqueue_job`.
 
     :param redis_settings: an instance of ``arq.connections.RedisSettings``.
     :param job_serializer: a function that serializes Python objects to bytes, defaults to pickle.dumps
     :param job_deserializer: a function that deserializes bytes into Python objects, defaults to pickle.loads
     :param default_queue_name: the default queue name to use, defaults to ``arq.queue``.
-    :param kwargs: keyword arguments directly passed to ``aioredis.Redis``.
+    :param kwargs: keyword arguments directly passed to ``redis.asyncio.Redis``.
     """
 
     def __init__(
         self,
-        pool_or_conn: Any,
+        pool_or_conn: Optional[ConnectionPool] = None,
         job_serializer: Optional[Serializer] = None,
         job_deserializer: Optional[Deserializer] = None,
         default_queue_name: str = default_queue_name,
@@ -92,7 +93,9 @@ def __init__(
         self.job_serializer = job_serializer
         self.job_deserializer = job_deserializer
         self.default_queue_name = default_queue_name
-        super().__init__(pool_or_conn, **kwargs)
+        if pool_or_conn:
+            kwargs['connection_pool'] = pool_or_conn
+        super().__init__(**kwargs)
 
     async def enqueue_job(
         self,
@@ -129,14 +132,10 @@ async def enqueue_job(
         defer_by_ms = to_ms(_defer_by)
         expires_ms = to_ms(_expires)
 
-        with await self as conn:
-            pipe = conn.pipeline()
-            pipe.unwatch()
-            pipe.watch(job_key)
-            job_exists = pipe.exists(job_key)
-            job_result_exists = pipe.exists(result_key_prefix + job_id)
-            await pipe.execute()
-            if await job_exists or await job_result_exists:
+        async with self.pipeline(transaction=True) as pipe:
+            await pipe.watch(job_key)
+            if any(await asyncio.gather(pipe.exists(job_key), pipe.exists(result_key_prefix + job_id))):
+                await pipe.reset()
                 return None
 
             enqueue_time_ms = timestamp_ms()
@@ -150,24 +149,22 @@ async def enqueue_job(
             expires_ms = expires_ms or score - enqueue_time_ms + expires_extra_ms
 
             job = serialize_job(function, args, kwargs, _job_try, enqueue_time_ms, serializer=self.job_serializer)
-            tr = conn.multi_exec()
-            tr.psetex(job_key, expires_ms, job)
-            tr.zadd(_queue_name, score, job_id)
+            pipe.multi()
+            pipe.psetex(job_key, expires_ms, job)
+            pipe.zadd(_queue_name, {job_id: score})
             try:
-                await tr.execute()
-            except MultiExecError:
+                await pipe.execute()
+            except WatchError:
                 # job got enqueued since we checked 'job_exists'
-                # https://github.com/samuelcolvin/arq/issues/131, avoid warnings in log
-                await asyncio.gather(*tr._results, return_exceptions=True)
                 return None
         return Job(job_id, redis=self, _queue_name=_queue_name, _deserializer=self.job_deserializer)
 
-    async def _get_job_result(self, key: str) -> JobResult:
-        job_id = key[len(result_key_prefix):]
+    async def _get_job_result(self, key: bytes) -> JobResult:
+        job_id = key[len(result_key_prefix):].decode()
         job = Job(job_id, self, _deserializer=self.job_deserializer)
         r = await job.result_info()
         if r is None:
-            raise KeyError(f'job "{key}" not found')
+            raise KeyError(f'job "{key.decode()}" not found')
         r.job_id = job_id
         return r
 
@@ -179,8 +176,8 @@ async def all_job_results(self) -> List[JobResult]:
         results = await asyncio.gather(*[self._get_job_result(k) for k in keys])
         return sorted(results, key=attrgetter('enqueue_time'))
 
-    async def _get_job_def(self, job_id: str, score: int) -> JobDef:
-        v = await self.get(job_key_prefix + job_id, encoding=None)
+    async def _get_job_def(self, job_id: bytes, score: int) -> JobDef:
+        v = await self.get(job_key_prefix + job_id.decode())
         jd = deserialize_job(v, deserializer=self.job_deserializer)
         jd.score = score
         return jd
@@ -189,8 +186,8 @@ async def queued_jobs(self, *, queue_name: str = default_queue_name) -> List[Job
189186 """
190187 Get information about queued, mostly useful when testing.
191188 """
192- jobs = await self .zrange (queue_name , withscores = True )
193- return await asyncio .gather (* [self ._get_job_def (job_id , score ) for job_id , score in jobs ])
189+ jobs = await self .zrange (queue_name , withscores = True , start = 0 , end = - 1 )
190+ return await asyncio .gather (* [self ._get_job_def (job_id , int ( score ) ) for job_id , score in jobs ])
194191
195192
196193async def create_pool (
@@ -204,8 +201,7 @@ async def create_pool(
204201 """
205202 Create a new redis pool, retrying up to ``conn_retries`` times if the connection fails.
206203
207- Similar to ``aioredis.create_redis_pool`` except it returns a :class:`arq.connections.ArqRedis` instance,
208- thus allowing job enqueuing.
204+ Returns a :class:`arq.connections.ArqRedis` instance, thus allowing job enqueuing.
209205 """
210206 settings : RedisSettings = RedisSettings () if settings_ is None else settings_
211207
@@ -214,32 +210,33 @@ async def create_pool(
214210 ), "str provided for 'host' but 'sentinel' is true; list of sentinels expected"
215211
216212 if settings .sentinel :
217- addr : Any = settings .host
218213
219- async def pool_factory (* args : Any , ** kwargs : Any ) -> Redis :
220- client = await aioredis . sentinel . create_sentinel_pool (* args , ssl = settings .ssl , ** kwargs )
221- return client .master_for (settings .sentinel_master )
214+ def pool_factory (* args : Any , ** kwargs : Any ) -> ArqRedis :
215+ client = Sentinel (* args , sentinels = settings . host , ssl = settings .ssl , ** kwargs )
216+ return client .master_for (settings .sentinel_master , redis_class = ArqRedis )
222217
223218 else :
224219 pool_factory = functools .partial (
225- aioredis .create_pool , create_connection_timeout = settings .conn_timeout , ssl = settings .ssl
220+ ArqRedis ,
221+ host = settings .host ,
222+ port = settings .port ,
223+ socket_connect_timeout = settings .conn_timeout ,
224+ ssl = settings .ssl ,
226225 )
227- addr = settings .host , settings .port
228226
229227 try :
230- pool = await pool_factory (addr , db = settings .database , password = settings .password , encoding = 'utf8' )
231- pool = ArqRedis (
232- pool ,
233- job_serializer = job_serializer ,
234- job_deserializer = job_deserializer ,
235- default_queue_name = default_queue_name ,
236- )
228+ pool = pool_factory (db = settings .database , password = settings .password , encoding = 'utf8' )
229+ pool .job_serializer = job_serializer
230+ pool .job_deserializer = job_deserializer
231+ pool .default_queue_name = default_queue_name
232+ await pool .ping ()
237233
238- except (ConnectionError , OSError , aioredis . RedisError , asyncio .TimeoutError ) as e :
234+ except (ConnectionError , OSError , RedisError , asyncio .TimeoutError ) as e :
239235 if retry < settings .conn_retries :
240236 logger .warning (
241- 'redis connection error %s %s %s, %d retries remaining...' ,
242- addr ,
237+ 'redis connection error %s:%s %s %s, %d retries remaining...' ,
238+ settings .host ,
239+ settings .port ,
243240 e .__class__ .__name__ ,
244241 e ,
245242 settings .conn_retries - retry ,
@@ -264,17 +261,16 @@ async def pool_factory(*args: Any, **kwargs: Any) -> Redis:
 
 
 async def log_redis_info(redis: Redis, log_func: Callable[[str], Any]) -> None:
-    with await redis as r:
-        info_server, info_memory, info_clients, key_count = await asyncio.gather(
-            r.info(section='Server'),
-            r.info(section='Memory'),
-            r.info(section='Clients'),
-            r.dbsize(),
-        )
-
-    redis_version = info_server.get('server', {}).get('redis_version', '?')
-    mem_usage = info_memory.get('memory', {}).get('used_memory_human', '?')
-    clients_connected = info_clients.get('clients', {}).get('connected_clients', '?')
+    async with redis.pipeline(transaction=True) as pipe:
+        pipe.info(section='Server')
+        pipe.info(section='Memory')
+        pipe.info(section='Clients')
+        pipe.dbsize()
+        info_server, info_memory, info_clients, key_count = await pipe.execute()
+
+    redis_version = info_server.get('redis_version', '?')
+    mem_usage = info_memory.get('used_memory_human', '?')
+    clients_connected = info_clients.get('connected_clients', '?')
 
     log_func(
         f'redis_version={redis_version}'
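
For reference, a minimal usage sketch of the enqueuing API after this migration (not part of the diff). It assumes a Redis server on localhost and a worker that registers a hypothetical download_content task; the _job_id illustrates the duplicate check above, where enqueue_job returns None if a job with the same ID is still queued or its result is still stored.

import asyncio

from arq.connections import RedisSettings, create_pool


async def main() -> None:
    # RedisSettings defaults to localhost:6379; adjust for your deployment.
    redis = await create_pool(RedisSettings(host='localhost', port=6379))
    # 'download_content' and 'download-1' are hypothetical names used only for
    # illustration; the WATCH/MULTI transaction in enqueue_job makes this
    # idempotent per job ID.
    job = await redis.enqueue_job('download_content', 'https://example.com', _job_id='download-1')
    if job is None:
        print('job already queued or its result is still cached')
    else:
        print('queued:', job.job_id)
    await redis.close()


if __name__ == '__main__':
    asyncio.run(main())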