Skip to content

Commit 0bb4f32

Browse files
Add support for the set functions from issue #597
Co-authored-by: Ali Rezaei <[email protected]>
1 parent ce47b30 commit 0bb4f32

File tree

5 files changed

+654
-3
lines changed

5 files changed

+654
-3
lines changed

changelog.d/730.feature

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Support for sets and support basic operations, sadd, scard, sdiff, sdiffstore, sinter, sinterstore, smismember, sismember, smembers, smove, spop, srandmember, srem, sscan, sscan_iter, sunion, sunionstore

django_redis/cache.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,74 @@ def close(self, **kwargs):
185185
def touch(self, *args, **kwargs):
186186
return self.client.touch(*args, **kwargs)
187187

188+
@omit_exception
189+
def sadd(self, *args, **kwargs):
190+
return self.client.sadd(*args, **kwargs)
191+
192+
@omit_exception
193+
def scard(self, *args, **kwargs):
194+
return self.client.scard(*args, **kwargs)
195+
196+
@omit_exception
197+
def sdiff(self, *args, **kwargs):
198+
return self.client.sdiff(*args, **kwargs)
199+
200+
@omit_exception
201+
def sdiffstore(self, *args, **kwargs):
202+
return self.client.sdiffstore(*args, **kwargs)
203+
204+
@omit_exception
205+
def sinter(self, *args, **kwargs):
206+
return self.client.sinter(*args, **kwargs)
207+
208+
@omit_exception
209+
def sinterstore(self, *args, **kwargs):
210+
return self.client.sinterstore(*args, **kwargs)
211+
212+
@omit_exception
213+
def sismember(self, *args, **kwargs):
214+
return self.client.sismember(*args, **kwargs)
215+
216+
@omit_exception
217+
def smembers(self, *args, **kwargs):
218+
return self.client.smembers(*args, **kwargs)
219+
220+
@omit_exception
221+
def smove(self, *args, **kwargs):
222+
return self.client.smove(*args, **kwargs)
223+
224+
@omit_exception
225+
def spop(self, *args, **kwargs):
226+
return self.client.spop(*args, **kwargs)
227+
228+
@omit_exception
229+
def srandmember(self, *args, **kwargs):
230+
return self.client.srandmember(*args, **kwargs)
231+
232+
@omit_exception
233+
def srem(self, *args, **kwargs):
234+
return self.client.srem(*args, **kwargs)
235+
236+
@omit_exception
237+
def sscan(self, *args, **kwargs):
238+
return self.client.sscan(*args, **kwargs)
239+
240+
@omit_exception
241+
def sscan_iter(self, *args, **kwargs):
242+
return self.client.sscan_iter(*args, **kwargs)
243+
244+
@omit_exception
245+
def smismember(self, *args, **kwargs):
246+
return self.client.smismember(*args, **kwargs)
247+
248+
@omit_exception
249+
def sunion(self, *args, **kwargs):
250+
return self.client.sunion(*args, **kwargs)
251+
252+
@omit_exception
253+
def sunionstore(self, *args, **kwargs):
254+
return self.client.sunionstore(*args, **kwargs)
255+
188256
@omit_exception
189257
def hset(self, *args, **kwargs):
190258
return self.client.hset(*args, **kwargs)

django_redis/client/default.py

Lines changed: 280 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,26 @@
33
import socket
44
from collections import OrderedDict
55
from contextlib import suppress
6-
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union
6+
from typing import (
7+
Any,
8+
Dict,
9+
Iterable,
10+
Iterator,
11+
List,
12+
Optional,
13+
Set,
14+
Tuple,
15+
Union,
16+
cast,
17+
)
718

819
from django.conf import settings
920
from django.core.cache.backends.base import DEFAULT_TIMEOUT, BaseCache, get_key_func
1021
from django.core.exceptions import ImproperlyConfigured
1122
from django.utils.module_loading import import_string
1223
from redis import Redis
1324
from redis.exceptions import ConnectionError, ResponseError, TimeoutError
14-
from redis.typing import AbsExpiryT, EncodableT, ExpiryT, KeyT
25+
from redis.typing import AbsExpiryT, EncodableT, ExpiryT, KeyT, PatternT
1526

1627
from django_redis import pool
1728
from django_redis.exceptions import CompressorError, ConnectionInterrupted
@@ -66,6 +77,14 @@ def __init__(self, server, params: Dict[str, Any], backend: BaseCache) -> None:
6677
def __contains__(self, key: KeyT) -> bool:
6778
return self.has_key(key)
6879

80+
def _has_compression_enabled(self) -> bool:
81+
return (
82+
self._options.get(
83+
"COMPRESSOR", "django_redis.compressors.identity.IdentityCompressor"
84+
)
85+
!= "django_redis.compressors.identity.IdentityCompressor"
86+
)
87+
6988
def get_next_client_index(
7089
self, write: bool = True, tried: Optional[List[int]] = None
7190
) -> int:
@@ -778,6 +797,265 @@ def make_pattern(
778797

779798
return CacheKey(self._backend.key_func(pattern, prefix, version_str))
780799

800+
def sadd(
801+
self,
802+
key: KeyT,
803+
*values: Any,
804+
version: Optional[int] = None,
805+
client: Optional[Redis] = None,
806+
) -> int:
807+
if client is None:
808+
client = self.get_client(write=True)
809+
810+
key = self.make_key(key, version=version)
811+
encoded_values = [self.encode(value) for value in values]
812+
return int(client.sadd(key, *encoded_values))
813+
814+
def scard(
815+
self,
816+
key: KeyT,
817+
version: Optional[int] = None,
818+
client: Optional[Redis] = None,
819+
) -> int:
820+
if client is None:
821+
client = self.get_client(write=False)
822+
823+
key = self.make_key(key, version=version)
824+
return int(client.scard(key))
825+
826+
def sdiff(
827+
self,
828+
*keys: KeyT,
829+
version: Optional[int] = None,
830+
client: Optional[Redis] = None,
831+
) -> Set:
832+
if client is None:
833+
client = self.get_client(write=False)
834+
835+
nkeys = [self.make_key(key, version=version) for key in keys]
836+
return {self.decode(value) for value in client.sdiff(*nkeys)}
837+
838+
def sdiffstore(
839+
self,
840+
dest: KeyT,
841+
*keys: KeyT,
842+
version_dest: Optional[int] = None,
843+
version_keys: Optional[int] = None,
844+
client: Optional[Redis] = None,
845+
) -> int:
846+
if client is None:
847+
client = self.get_client(write=True)
848+
849+
dest = self.make_key(dest, version=version_dest)
850+
nkeys = [self.make_key(key, version=version_keys) for key in keys]
851+
return int(client.sdiffstore(dest, *nkeys))
852+
853+
def sinter(
854+
self,
855+
*keys: KeyT,
856+
version: Optional[int] = None,
857+
client: Optional[Redis] = None,
858+
) -> Set:
859+
if client is None:
860+
client = self.get_client(write=False)
861+
862+
nkeys = [self.make_key(key, version=version) for key in keys]
863+
return {self.decode(value) for value in client.sinter(*nkeys)}
864+
865+
def sinterstore(
866+
self,
867+
dest: KeyT,
868+
*keys: KeyT,
869+
version: Optional[int] = None,
870+
client: Optional[Redis] = None,
871+
) -> int:
872+
if client is None:
873+
client = self.get_client(write=True)
874+
875+
dest = self.make_key(dest, version=version)
876+
nkeys = [self.make_key(key, version=version) for key in keys]
877+
return int(client.sinterstore(dest, *nkeys))
878+
879+
def smismember(
880+
self,
881+
key: KeyT,
882+
*members,
883+
version: Optional[int] = None,
884+
client: Optional[Redis] = None,
885+
) -> List[bool]:
886+
if client is None:
887+
client = self.get_client(write=False)
888+
889+
key = self.make_key(key, version=version)
890+
encoded_members = [self.encode(member) for member in members]
891+
892+
return [bool(value) for value in client.smismember(key, *encoded_members)]
893+
894+
def sismember(
895+
self,
896+
key: KeyT,
897+
member: Any,
898+
version: Optional[int] = None,
899+
client: Optional[Redis] = None,
900+
) -> bool:
901+
if client is None:
902+
client = self.get_client(write=False)
903+
904+
key = self.make_key(key, version=version)
905+
member = self.encode(member)
906+
return bool(client.sismember(key, member))
907+
908+
def smembers(
909+
self,
910+
key: KeyT,
911+
version: Optional[int] = None,
912+
client: Optional[Redis] = None,
913+
) -> Set:
914+
if client is None:
915+
client = self.get_client(write=False)
916+
917+
key = self.make_key(key, version=version)
918+
return {self.decode(value) for value in client.smembers(key)}
919+
920+
def smove(
921+
self,
922+
source: KeyT,
923+
destination: KeyT,
924+
member: Any,
925+
version: Optional[int] = None,
926+
client: Optional[Redis] = None,
927+
) -> bool:
928+
if client is None:
929+
client = self.get_client(write=True)
930+
931+
source = self.make_key(source, version=version)
932+
destination = self.make_key(destination)
933+
member = self.encode(member)
934+
return bool(client.smove(source, destination, member))
935+
936+
def spop(
937+
self,
938+
key: KeyT,
939+
count: Optional[int] = None,
940+
version: Optional[int] = None,
941+
client: Optional[Redis] = None,
942+
) -> Union[Set, Any]:
943+
if client is None:
944+
client = self.get_client(write=True)
945+
946+
nkey = self.make_key(key, version=version)
947+
result = client.spop(nkey, count)
948+
if result is None:
949+
return None
950+
if isinstance(result, list):
951+
return {self.decode(value) for value in result}
952+
return self.decode(result)
953+
954+
def srandmember(
955+
self,
956+
key: KeyT,
957+
count: Optional[int] = None,
958+
version: Optional[int] = None,
959+
client: Optional[Redis] = None,
960+
) -> Union[List, Any]:
961+
if client is None:
962+
client = self.get_client(write=False)
963+
964+
key = self.make_key(key, version=version)
965+
result = client.srandmember(key, count)
966+
if result is None:
967+
return None
968+
if isinstance(result, list):
969+
return [self.decode(value) for value in result]
970+
return self.decode(result)
971+
972+
def srem(
973+
self,
974+
key: KeyT,
975+
*members: EncodableT,
976+
version: Optional[int] = None,
977+
client: Optional[Redis] = None,
978+
) -> int:
979+
if client is None:
980+
client = self.get_client(write=True)
981+
982+
key = self.make_key(key, version=version)
983+
nmembers = [self.encode(member) for member in members]
984+
return int(client.srem(key, *nmembers))
985+
986+
def sscan(
987+
self,
988+
key: KeyT,
989+
match: Optional[str] = None,
990+
count: Optional[int] = 10,
991+
version: Optional[int] = None,
992+
client: Optional[Redis] = None,
993+
) -> Set[Any]:
994+
if self._has_compression_enabled() and match:
995+
err_msg = "Using match with compression is not supported."
996+
raise ValueError(err_msg)
997+
998+
if client is None:
999+
client = self.get_client(write=False)
1000+
1001+
key = self.make_key(key, version=version)
1002+
1003+
cursor, result = client.sscan(
1004+
key,
1005+
match=cast(PatternT, self.encode(match)) if match else None,
1006+
count=count,
1007+
)
1008+
return {self.decode(value) for value in result}
1009+
1010+
def sscan_iter(
1011+
self,
1012+
key: KeyT,
1013+
match: Optional[str] = None,
1014+
count: Optional[int] = 10,
1015+
version: Optional[int] = None,
1016+
client: Optional[Redis] = None,
1017+
) -> Iterator[Any]:
1018+
if self._has_compression_enabled() and match:
1019+
err_msg = "Using match with compression is not supported."
1020+
raise ValueError(err_msg)
1021+
1022+
if client is None:
1023+
client = self.get_client(write=False)
1024+
1025+
key = self.make_key(key, version=version)
1026+
for value in client.sscan_iter(
1027+
key,
1028+
match=cast(PatternT, self.encode(match)) if match else None,
1029+
count=count,
1030+
):
1031+
yield self.decode(value)
1032+
1033+
def sunion(
1034+
self,
1035+
*keys: KeyT,
1036+
version: Optional[int] = None,
1037+
client: Optional[Redis] = None,
1038+
) -> Set:
1039+
if client is None:
1040+
client = self.get_client(write=False)
1041+
1042+
nkeys = [self.make_key(key, version=version) for key in keys]
1043+
return {self.decode(value) for value in client.sunion(*nkeys)}
1044+
1045+
def sunionstore(
1046+
self,
1047+
destination: Any,
1048+
*keys: KeyT,
1049+
version: Optional[int] = None,
1050+
client: Optional[Redis] = None,
1051+
) -> int:
1052+
if client is None:
1053+
client = self.get_client(write=True)
1054+
1055+
destination = self.make_key(destination, version=version)
1056+
encoded_keys = [self.make_key(key, version=version) for key in keys]
1057+
return int(client.sunionstore(destination, *encoded_keys))
1058+
7811059
def close(self) -> None:
7821060
close_flag = self._options.get(
7831061
"CLOSE_CONNECTION",

0 commit comments

Comments
 (0)