diff --git a/pymemcache/test/test_client.py b/pymemcache/test/test_client.py index 9b38394..f8e2819 100644 --- a/pymemcache/test/test_client.py +++ b/pymemcache/test/test_client.py @@ -1511,18 +1511,11 @@ def increment(self, obj): class TestMockClient(ClientTestMixin, unittest.TestCase): - def make_client(self, mock_socket_values, **kwargs): - client = MockMemcacheClient("localhost", **kwargs) - client.sock = MockSocket(list(mock_socket_values)) - return client + def make_client(self, mock_socket_values=None, **kwargs): + return MockMemcacheClient("localhost", **kwargs) def test_get_found(self): - client = self.make_client( - [ - b"STORED\r\n", - b"VALUE key 0 5\r\nvalue\r\nEND\r\n", - ] - ) + client = self.make_client() result = client.set(b"key", b"value", noreply=False) result = client.get(b"key") assert result == b"value" @@ -1539,15 +1532,7 @@ def deserialize(self, key, value, flags): return json.loads(value.decode("UTF-8")) return value - client = self.make_client( - [ - b"STORED\r\n", - b"VALUE key1 0 5\r\nhello\r\nEND\r\n", - b"STORED\r\n", - b'VALUE key2 0 18\r\n{"hello": "world"}\r\nEND\r\n', - ], - serde=JsonSerde(), - ) + client = self.make_client(serde=JsonSerde()) result = client.set(b"key1", b"hello", noreply=False) result = client.get(b"key1") @@ -1557,6 +1542,59 @@ def deserialize(self, key, value, flags): result = client.get(b"key2") assert result == dict(hello="world") + def test_gets_not_found(self): + client = self.make_client() + result = client.gets(b"key") + assert result == (None, None) + + def test_gets_not_found_defaults(self): + client = self.make_client() + result = client.gets(b"key", default="foo", cas_default="bar") + assert result == ("foo", "bar") + + @mock.patch('time.time_ns', return_value=10) + def test_gets_found(self, _): + client = self.make_client() + result = client.set(b"key", b"value", noreply=False) + result = client.gets(b"key") + assert result == (b"value", b"10") + + def test_gets_many_none_found(self): + client = self.make_client([b"END\r\n"]) + result = client.gets_many([b"key1", b"key2"]) + assert result == {} + + @mock.patch('time.time_ns', return_value=11) + def test_gets_many_some_found(self, _): + client = self.make_client() + result = client.set(b"key1", b"value", noreply=False) + result = client.gets_many([b"key1", b"key2"]) + assert result == {b"key1": (b"value", b"11")} + + @mock.patch('time.time_ns', return_value=123) + def test_cas_stored(self, _): + client = self.make_client() + client.set(b"key", b"existing") + result = client.cas(b"key", b"value", b"123", noreply=False) + assert result is True + + result = client.get(b"key") + assert result == b"value" + + def test_cas_exists(self): + client = self.make_client() + client.set(b"key", b"existing") + result = client.cas(b"key", b"value", b"123", noreply=False) + assert result is False + + def test_cas_not_found(self): + client = self.make_client() + result = client.cas(b"key", b"value", b"123", noreply=False) + assert result is None + + result = client.get(b"key") + assert result is None + class TestPrefixedClient(ClientTestMixin, unittest.TestCase): def make_client(self, mock_socket_values, **kwargs): diff --git a/pymemcache/test/utils.py b/pymemcache/test/utils.py index 52b1732..4c750d7 100644 --- a/pymemcache/test/utils.py +++ b/pymemcache/test/utils.py @@ -38,6 +38,7 @@ def __init__( **kwargs, ): self._contents = {} + self._cas_ids = {} # maps keys to bytes CAS tokens def _serializer(key, value): if isinstance(value, str): @@ -68,6 +69,7 @@ def check_key(self, key): def clear(self): """Method used to clear/reset mock cache""" self._contents.clear() + self._cas_ids.clear() def get(self, key, default=None): key = self.check_key(key) @@ -92,6 +94,28 @@ def get_many(self, keys): get_multi = get_many + def gets(self, key, default=None, cas_default=None): + not_found = [] + + value = self.get(key, default=not_found) + if value is not_found: + return default, cas_default + + cas_token = self._cas_ids.setdefault(key, str(time.time_ns()).encode()) + return value, cas_token + + def gets_many(self, keys): + not_found = [] + + out = {} + for key in keys: + value, cas = self.gets(key, default=not_found) + if value is not not_found: + out[key] = (value, cas) + return out + + get_multi = get_many + def set(self, key, value, expire=0, noreply=True, flags=None): key = self.check_key(key) if isinstance(value, str) and not isinstance(value, bytes): @@ -106,6 +130,7 @@ def set(self, key, value, expire=0, noreply=True, flags=None): expire += time.time() self._contents[key] = expire, value, flags + self._cas_ids[key] = str(time.time_ns()).encode() return True def set_many(self, values, expire=0, noreply=True, flags=None): @@ -189,7 +214,7 @@ def stats(self, *_args): "stat_key_prefix": "", "umask": 0o644, "detail_enabled": False, - "cas_enabled": False, + "cas_enabled": True, "auth_enabled_sasl": False, "maxconns_fast": False, "slab_reassign": False, @@ -203,8 +228,19 @@ def replace(self, key, value, expire=0, noreply=True, flags=None): self.set(key, value, expire, noreply, flags=flags) return noreply or present - def cas(self, key, value, cas, expire=0, noreply=False, flags=None): - raise MemcacheClientError("CAS is not enabled for this instance") + def cas(self, key, value, cas_token, expire=0, noreply=False, **kwargs): + if not isinstance(cas_token, (int, str, bytes)): + raise MemcacheIllegalInputError(f'cas must be integer, string, or bytes, got bad value: {cas_token}') + + key = self.check_key(key) + + if key not in self._contents: + return True if noreply else None + + elif self._cas_ids.get(key) != cas_token: + return True if noreply else False + + return self.set(key, value, noreply=noreply, **kwargs) def touch(self, key, expire=0, noreply=True): current = self.get(key)