@@ -4,6 +4,7 @@
 import pickle
 import threading
 from datetime import date, datetime, timezone
+from enum import Enum
 from time import time
 from typing import Any
 
@@ -32,6 +33,8 @@
 # Debounce our JSON validation a bit in order to not cause too much additional
 # load everywhere
 _last_validation_log: float | None = None
+Pipeline = Any
+# TODO type Pipeline instead of using Any here
 
 
 def _validate_json_roundtrip(value: dict[str, Any], model: type[models.Model]) -> None:
@@ -49,6 +52,13 @@ def _validate_json_roundtrip(value: dict[str, Any], model: type[models.Model]) -> None:
         logger.exception("buffer.invalid_value", extra={"value": value, "model": model})
 
 
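+# Redis commands used by the buffer: each value is the redis-py/rb client
+# method name, resolved with getattr() in _execute_redis_operation().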
+class RedisOperation(Enum):
+    SET_ADD = "sadd"
+    SET_GET = "smembers"
+    HASH_ADD = "hset"
+    HASH_GET_ALL = "hgetall"
+
+
 class PendingBuffer:
     def __init__(self, size: int):
         assert size > 0
@@ -208,6 +218,48 @@ def get(
             col: (int(results[i]) if results[i] is not None else 0) for i, col in enumerate(columns)
         }
 
+    def get_redis_connection(self, key: str) -> Pipeline | None:
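+        """Return a pipeline from the client responsible for ``key``: the whole
+        cluster when running Redis Cluster, otherwise the rb node owning the key."""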
+        if is_instance_redis_cluster(self.cluster, self.is_redis_cluster):
+            conn = self.cluster
+        elif is_instance_rb_cluster(self.cluster, self.is_redis_cluster):
+            conn = self.cluster.get_local_client_for_key(key)
+        else:
+            raise AssertionError("unreachable")
+
+        pipe = conn.pipeline()
+        return pipe
+
+    def _execute_redis_operation(self, key: str, operation: RedisOperation, *args: Any) -> Any:
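+        """Run ``operation`` on ``key`` through a pipeline routed by the pending
+        key, refreshing the key's TTL for writes (calls that pass ``args``)."""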
+        pending_key = self._make_pending_key_from_key(key)
+        pipe = self.get_redis_connection(pending_key)
+        if pipe:
+            getattr(pipe, operation.value)(key, *args)
+            if args:
+                pipe.expire(key, self.key_expire)
+            return pipe.execute()
+
+    def push_to_set(self, key: str, value: list[int] | int) -> None:
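+        """Add ``value`` to the Redis set stored at ``key`` (SADD)."""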
+        self._execute_redis_operation(key, RedisOperation.SET_ADD, value)
+
+    def get_set(self, key: str) -> list[set[int]]:
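+        """Fetch the members of the Redis set stored at ``key`` (SMEMBERS)."""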
+        return self._execute_redis_operation(key, RedisOperation.SET_GET)
+
+    def push_to_hash(
+        self,
+        model: type[models.Model],
+        filters: dict[str, models.Model | str | int],
+        field: str,
+        value: int,
+    ) -> None:
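+        """Set ``field`` to ``value`` in the hash keyed by ``model``/``filters`` (HSET)."""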
+        key = self._make_key(model, filters)
+        self._execute_redis_operation(key, RedisOperation.HASH_ADD, field, value)
+
+    def get_hash(
+        self, model: type[models.Model], field: dict[str, models.Model | str | int]
+    ) -> dict[str, str]:
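+        """Fetch the full hash keyed by ``model``/``field`` (HGETALL)."""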
+        key = self._make_key(model, field)
+        return self._execute_redis_operation(key, RedisOperation.HASH_GET_ALL)
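+    # Illustrative usage only; ``Group`` and the filters/field below are
+    # hypothetical examples, not part of this change:
+    #   buffer.push_to_hash(Group, {"id": 1}, field="times_seen", value=10)
+    #   buffer.get_hash(Group, {"id": 1})  # pipeline result, e.g. [{"times_seen": "10"}]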
+
     def incr(
         self,
         model: type[models.Model],
@@ -226,19 +278,13 @@ def incr(
         - Perform a set on signal_only (only if True)
         - Add hashmap key to pending flushes
         """
-
         key = self._make_key(model, filters)
         pending_key = self._make_pending_key_from_key(key)
         # We can't use conn.map() due to wanting to support multiple pending
         # keys (one per Redis partition)
-        if is_instance_redis_cluster(self.cluster, self.is_redis_cluster):
-            conn = self.cluster
-        elif is_instance_rb_cluster(self.cluster, self.is_redis_cluster):
-            conn = self.cluster.get_local_client_for_key(key)
-        else:
-            raise AssertionError("unreachable")
-
-        pipe = conn.pipeline()
+        pipe = self.get_redis_connection(key)
+        if not pipe:
+            return
         pipe.hsetnx(key, "m", f"{model.__module__}.{model.__name__}")
         _validate_json_roundtrip(filters, model)
 