1616import logging
1717import pickle
1818from collections import defaultdict
19- from typing import TYPE_CHECKING , Any , Dict , Iterable , Optional , Union
19+ from typing import TYPE_CHECKING , Any , Dict , List , Optional , Union
2020
2121import jump
2222from prometheus_client import Counter , Histogram
23+ from txredisapi import RedisError
2324
2425from twisted .internet import defer
2526
@@ -82,6 +83,7 @@ def __init__(self, hs: "HomeServer"):
8283 host = shard ["host" ],
8384 port = shard ["port" ],
8485 reconnect = True ,
86+ replyTimeout = 5 ,
8587 ),
8688 )
8789
@@ -131,17 +133,25 @@ async def mset(
131133 self ._redis_shards [shard_id ].mset (values )
132134 for shard_id , values in shard_id_to_encoded_values .items ()
133135 ]
134- await make_deferred_yieldable (
135- defer .gatherResults (deferreds , consumeErrors = True )
136- ).addErrback (unwrapFirstError )
136+ try :
137+ await make_deferred_yieldable (
138+ defer .gatherResults (deferreds , consumeErrors = True )
139+ ).addErrback (unwrapFirstError )
140+ except RedisError as e :
141+ logger .error ("Failed to set on one or more Redis shards: %r" , e )
137142
138143 async def set (self , cache_name : str , key : str , value : Any ) -> None :
139144 await self .mset (cache_name , {key : value })
140145
141146 async def _mget_shard (
142147 self , shard_id : int , key_mapping : Dict [str , str ]
143148 ) -> Dict [str , Any ]:
144- results = await self ._redis_shards [shard_id ].mget (list (key_mapping .values ()))
149+ shard = self ._redis_shards [shard_id ]
150+ try :
151+ results = await shard .mget (list (key_mapping .values ()))
152+ except RedisError as e :
153+ logger .error ("Failed to get from Redis %r: %r" , shard , e )
154+ return {}
145155 original_keys = list (key_mapping .keys ())
146156 mapped_results : Dict [str , Any ] = {}
147157 for i , result in enumerate (results ):
@@ -150,12 +160,12 @@ async def _mget_shard(
150160 try :
151161 result = pickle .loads (result )
152162 except Exception as e :
153- logger .warning ("Failed to decode cache result: %r" , e )
163+ logger .error ("Failed to decode cache result: %r" , e )
154164 else :
155165 mapped_results [original_keys [i ]] = result
156166 return mapped_results
157167
158- async def mget (self , cache_name : str , keys : Iterable [str ]) -> Dict [str , Any ]:
168+ async def mget (self , cache_name : str , keys : List [str ]) -> Dict [str , Any ]:
159169 """Look up a key/value combinations in the named cache."""
160170
161171 if not self .is_enabled ():
0 commit comments