Skip to content

Commit 0a87644

Browse files
committed
adjust sync commands to work with shard client
by moving `get_server` logic to `get_client`, and adjusting operation methods, same methods can be used with shard as well (cherry picked from commit 21a1ee9)
1 parent 8ed6a04 commit 0a87644

File tree

2 files changed

+114
-419
lines changed

2 files changed

+114
-419
lines changed

django_valkey/base_client.py

Lines changed: 66 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -204,15 +204,16 @@ class ClientCommands(Generic[Backend]):
204204
def __contains__(self, key: KeyT) -> bool:
205205
return self.has_key(key)
206206

207-
def _get_client(self, write=True, tried=None, client=None):
207+
def _get_client(self, write=True, tried=None, client=None, **kwargs):
208208
if client:
209209
return client
210-
return self.get_client(write=write, tried=tried)
210+
return self.get_client(write=write, tried=tried, **kwargs)
211211

212212
def get_client(
213213
self: BaseClient,
214214
write: bool = True,
215215
tried: List[int] | None = None,
216+
**kwargs,
216217
) -> Backend | Any:
217218
"""
218219
Method used for obtain a raw valkey client.
@@ -402,10 +403,10 @@ def get(
402403
403404
Returns decoded value if key is found, the default if not.
404405
"""
405-
client = self._get_client(write=False, client=client)
406-
407406
key = self.make_key(key, version=version)
408407

408+
client = self._get_client(write=False, client=client, key=key)
409+
409410
try:
410411
value = client.get(key)
411412
except _main_exceptions as e:
@@ -422,10 +423,10 @@ def persist(
422423
version: int | None = None,
423424
client: Backend | Any | None = None,
424425
) -> bool:
425-
client = self._get_client(write=True, client=client)
426-
427426
key = self.make_key(key, version=version)
428427

428+
client = self._get_client(write=True, client=client, key=key)
429+
429430
return client.persist(key)
430431

431432
def expire(
@@ -438,10 +439,10 @@ def expire(
438439
if timeout is DEFAULT_TIMEOUT:
439440
timeout = self._backend.default_timeout # type: ignore
440441

441-
client = self._get_client(write=True, client=client)
442-
443442
key = self.make_key(key, version=version)
444443

444+
client = self._get_client(write=True, client=client, key=key)
445+
445446
# for some strange reason mypy complains,
446447
# saying that timeout type is float | timedelta
447448
return client.expire(key, timeout) # type: ignore
@@ -457,10 +458,10 @@ def expire_at(
457458
Set an expiry flag on a ``key`` to ``when``, which can be represented
458459
as an integer indicating unix time or a Python datetime object.
459460
"""
460-
client = self._get_client(write=True, client=client)
461-
462461
key = self.make_key(key, version=version)
463462

463+
client = self._get_client(write=True, client=client, key=key)
464+
464465
return client.expireat(key, when)
465466

466467
def pexpire(
@@ -473,10 +474,10 @@ def pexpire(
473474
if timeout is DEFAULT_TIMEOUT:
474475
timeout = self._backend.default_timeout # type: ignore
475476

476-
client = self._get_client(write=True, client=client)
477-
478477
key = self.make_key(key, version=version)
479478

479+
client = self._get_client(write=True, client=client, key=key)
480+
480481
# TODO: see if the casting is necessary
481482
# for some strange reason mypy complains,
482483
# saying that timeout type is float | timedelta
@@ -493,10 +494,10 @@ def pexpire_at(
493494
Set an expiry flag on a ``key`` to ``when``, which can be represented
494495
as an integer indicating unix time or a Python datetime object.
495496
"""
496-
client = self._get_client(write=True, client=client)
497-
498497
key = self.make_key(key, version=version)
499498

499+
client = self._get_client(write=True, client=client, key=key)
500+
500501
return client.pexpireat(key, when)
501502

502503
def get_lock(
@@ -511,9 +512,10 @@ def get_lock(
511512
lock_class=None,
512513
thread_local: bool = True,
513514
) -> "Lock":
514-
client = self._get_client(write=True, client=client)
515-
516515
key = self.make_key(key, version=version)
516+
517+
client = self._get_client(write=True, client=client, key=key)
518+
517519
return client.lock(
518520
key,
519521
timeout=timeout,
@@ -537,10 +539,12 @@ def delete(
537539
"""
538540
Remove a key from the cache.
539541
"""
540-
client = self._get_client(write=True, client=client)
542+
key = self.make_key(key, version=version, prefix=prefix)
543+
544+
client = self._get_client(write=True, client=client, key=key)
541545

542546
try:
543-
return client.delete(self.make_key(key, version=version, prefix=prefix))
547+
return client.delete(key)
544548
except _main_exceptions as e:
545549
raise ConnectionInterrupted(connection=client) from e
546550

@@ -717,10 +721,10 @@ def _incr(
717721
client: Backend | Any | None = None,
718722
ignore_key_check: bool = False,
719723
) -> int:
720-
client = self._get_client(write=True, client=client)
721-
722724
key = self.make_key(key, version=version)
723725

726+
client = self._get_client(write=True, client=client, key=key)
727+
724728
try:
725729
try:
726730
# if key expired after exists check, then we get
@@ -806,9 +810,10 @@ def ttl(
806810
Executes TTL valkey command and return the "time-to-live" of specified key.
807811
If key is a non-volatile key, it returns None.
808812
"""
809-
client = self._get_client(write=False, client=client)
810-
811813
key = self.make_key(key, version=version)
814+
815+
client = self._get_client(write=False, client=client, key=key)
816+
812817
if not client.exists(key):
813818
return 0
814819

@@ -832,9 +837,9 @@ def pttl(
832837
Executes PTTL valkey command and return the "time-to-live" of specified key.
833838
If key is a non-volatile key, it returns None.
834839
"""
835-
client = self._get_client(write=False, client=client)
836-
837840
key = self.make_key(key, version=version)
841+
client = self._get_client(write=False, client=client, key=key)
842+
838843
if not client.exists(key):
839844
return 0
840845

@@ -857,10 +862,9 @@ def has_key(
857862
"""
858863
Test if key exists.
859864
"""
860-
861-
client = self._get_client(write=False, client=client)
862-
863865
key = self.make_key(key, version=version)
866+
867+
client = self._get_client(write=False, client=client, key=key)
864868
try:
865869
return client.exists(key) == 1
866870
except _main_exceptions as e:
@@ -912,9 +916,10 @@ def sadd(
912916
version: int | None = None,
913917
client: Backend | Any | None = None,
914918
) -> int:
915-
client = self._get_client(write=True, client=client)
916-
917919
key = self.make_key(key, version=version)
920+
921+
client = self._get_client(write=True, client=client, key=key)
922+
918923
encoded_values = [self.encode(value) for value in values]
919924
return client.sadd(key, *encoded_values)
920925

@@ -924,9 +929,10 @@ def scard(
924929
version: int | None = None,
925930
client: Backend | Any | None = None,
926931
) -> int:
927-
client = self._get_client(write=False, client=client)
928-
929932
key = self.make_key(key, version=version)
933+
934+
client = self._get_client(write=False, client=client, key=key)
935+
930936
return client.scard(key)
931937

932938
def sdiff(
@@ -985,9 +991,10 @@ def smismember(
985991
version: int | None = None,
986992
client: Backend | Any | None = None,
987993
) -> List[bool]:
988-
client = self._get_client(write=False, client=client)
989-
990994
key = self.make_key(key, version=version)
995+
996+
client = self._get_client(write=False, client=client, key=key)
997+
991998
encoded_members = [self.encode(member) for member in members]
992999

9931000
return [bool(value) for value in client.smismember(key, *encoded_members)]
@@ -999,9 +1006,10 @@ def sismember(
9991006
version: int | None = None,
10001007
client: Backend | Any | None = None,
10011008
) -> bool:
1002-
client = self._get_client(write=False, client=client)
1003-
10041009
key = self.make_key(key, version=version)
1010+
1011+
client = self._get_client(write=False, client=client, key=key)
1012+
10051013
member = self.encode(member)
10061014
return bool(client.sismember(key, member))
10071015

@@ -1011,9 +1019,10 @@ def smembers(
10111019
version: int | None = None,
10121020
client: Backend | Any | None = None,
10131021
) -> Set[Any]:
1014-
client = self._get_client(write=False, client=client)
1015-
10161022
key = self.make_key(key, version=version)
1023+
1024+
client = self._get_client(write=False, client=client, key=key)
1025+
10171026
return {self.decode(value) for value in client.smembers(key)}
10181027

10191028
def smove(
@@ -1024,10 +1033,11 @@ def smove(
10241033
version: int | None = None,
10251034
client: Backend | Any | None = None,
10261035
) -> bool:
1027-
client = self._get_client(write=True, client=client)
1028-
10291036
source = self.make_key(source, version=version)
10301037
destination = self.make_key(destination)
1038+
1039+
client = self._get_client(write=True, client=client, key=source)
1040+
10311041
member = self.encode(member)
10321042
return client.smove(source, destination, member)
10331043

@@ -1038,9 +1048,10 @@ def spop(
10381048
version: int | None = None,
10391049
client: Backend | Any | None = None,
10401050
) -> Set | Any:
1041-
client = self._get_client(write=True, client=client)
1042-
10431051
nkey = self.make_key(key, version=version)
1052+
1053+
client = self._get_client(write=True, client=client, key=nkey)
1054+
10441055
result = client.spop(nkey, count)
10451056
return self._decode_iterable_result(result)
10461057

@@ -1051,9 +1062,10 @@ def srandmember(
10511062
version: int | None = None,
10521063
client: Backend | Any | None = None,
10531064
) -> List | Any:
1054-
client = self._get_client(write=False, client=client)
1055-
10561065
key = self.make_key(key, version=version)
1066+
1067+
client = self._get_client(write=False, client=client, key=key)
1068+
10571069
result = client.srandmember(key, count)
10581070
return self._decode_iterable_result(result, convert_to_set=False)
10591071

@@ -1064,9 +1076,10 @@ def srem(
10641076
version: int | None = None,
10651077
client: Backend | Any | None = None,
10661078
) -> int:
1067-
client = self._get_client(write=True, client=client)
1068-
10691079
key = self.make_key(key, version=version)
1080+
1081+
client = self._get_client(write=True, client=client, key=key)
1082+
10701083
nmembers = [self.encode(member) for member in members]
10711084
return client.srem(key, *nmembers)
10721085

@@ -1082,10 +1095,10 @@ def sscan(
10821095
err_msg = "Using match with compression is not supported."
10831096
raise ValueError(err_msg)
10841097

1085-
client = self._get_client(write=False, client=client)
1086-
10871098
key = self.make_key(key, version=version)
10881099

1100+
client = self._get_client(write=False, client=client, key=key)
1101+
10891102
cursor, result = client.sscan(
10901103
key,
10911104
match=cast(PatternT, self.encode(match)) if match else None,
@@ -1105,9 +1118,10 @@ def sscan_iter(
11051118
err_msg = "Using match with compression is not supported."
11061119
raise ValueError(err_msg)
11071120

1108-
client = self._get_client(write=False, client=client)
1109-
11101121
key = self.make_key(key, version=version)
1122+
1123+
client = self._get_client(write=False, client=client, key=key)
1124+
11111125
for value in client.sscan_iter(
11121126
key,
11131127
match=cast(PatternT, self.encode(match)) if match else None,
@@ -1170,9 +1184,10 @@ def touch(
11701184
if timeout is DEFAULT_TIMEOUT:
11711185
timeout = self._backend.default_timeout
11721186

1173-
client = self._get_client(write=True, client=client)
1174-
11751187
key = self.make_key(key, version=version)
1188+
1189+
client = self._get_client(write=True, client=client, key=key)
1190+
11761191
if timeout is None:
11771192
return bool(client.persist(key))
11781193

0 commit comments

Comments
 (0)