Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ I leveraged sanic [middleware](https://sanic.dev/en/guide/basics/middleware.md#l

With the current testing setup using threading, the maximum number of concurrent active requests is ~5 on my machine due to request response timing. I'm not aware of a clean way to test different environment-level config values and assert different responses for the same test case... That being said, I temporarily set the configuration to `CONCURRENT_REQUESTS_MAX=2` and ran the tests by hand to verify the correct (503) response. Could probably achieve this with a script, passing the variable to Make, i.e. `CONCURRENT_REQUESTS_MAX=5 make test`, but adding conditional asserts based on an env variable feels clunky.

## Redis Serialization Protocol
Implemented separately as a simple asyncio TCP server, but using the same ClientCache, configured similarly. The specification stated it needs to handle `GET`. In testing with redis-py, it also needs to handle `CLIENT SETINFO` commands on the initial call. For now, the server just responds OK, but it could be extended to track connected client history and support other `CLIENT` commands. This server implementation also tracks the number of connected clients, incrementing/decrementing the count appropriately in the main connection handler.

# Improvements
- refactor cache implementation to make use of threading native types to allow several worker processes.
- ghcr to improve ci time, though not significant for this small project.
Expand All @@ -67,3 +70,6 @@ Sun
- 2 hr refactoring/tidying up
- 2 hr docs, ci
- 2 hr impl/debug configurable max concurrent requests (incl. digging around in sanic docs)

Mon
- 2.5hr reading docs + simple resp server impl working locally with redis-py
2 changes: 1 addition & 1 deletion proxy/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ COPY ./ /code/

RUN pip install .

CMD ["python", "-m", "proxy.server"]
CMD ["python", "-m", "proxy.server_http"]
8 changes: 8 additions & 0 deletions proxy/Dockerfile.resp
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
FROM python:3.10

WORKDIR /code

# Install the whole package (mirrors proxy/Dockerfile for the HTTP server):
# server_resp imports proxy.cache and proxy.serialization, so copying a single
# file and running a bare module cannot work.
COPY ./ /code/

RUN pip install .

CMD ["python", "-m", "proxy.server_resp"]
3 changes: 2 additions & 1 deletion proxy/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ dev = ["black", "isort", "ruff"]
where = ["src"]

[project.scripts]
server = "proxy:server.main"
server_http = "proxy.server_http:main"
server_resp = "proxy.server_resp:main"

[tool.black]
line-length = 120
Expand Down
20 changes: 14 additions & 6 deletions proxy/src/proxy/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import redis.asyncio as redis


class LocalCache(Protocol):
class ICache(Protocol):
"""Cache Interface"""

def get(self, key: Any, default: Optional[Any] = None) -> Optional[Any]:
Expand All @@ -27,7 +27,7 @@ class TTLLRUCacheConfig:
ttl_seconds: int = 60


class TTLLRUCache(LocalCache):
class TTLLRUCache(ICache):
"""
A thread-safe TTL LRU cache implementation.
NOTE: Currently doesn't support shared access by multiple workers.
Expand Down Expand Up @@ -101,22 +101,30 @@ def __is_expired(self, key: Any) -> bool:
return time_ns() > self.__expiry_map[key]


class ClientCache:
def __init__(self, host, port, cache: LocalCache):
class ClientCache(ICache):
def __init__(self, host, port, cache: ICache):
self.__cache = cache
self.__pool = redis.ConnectionPool.from_url(url=f"redis://{host}:{port}", decode_responses=True)
self.__client = redis.Redis(connection_pool=self.__pool)

async def get(self, key: Any) -> Any:
async def get(self, key: Any, default: Optional[Any] = None) -> Optional[Any]:
timer = time_ns()
if val := await self.__cache.get(key):
logger.debug(f"Cache hit time: {(time_ns() - timer) / 1e6} ms")
logger.debug(f"Cache hit for key: {key}, value: {val}")
return val

logger.info(f"Cache miss for key: {key}")
value = await self.__client.get(key)
try:
value = await self.__client.get(key)
except redis.ConnectionError as e:
logger.error(f"Redis connection error: {e}")
return default

logger.debug(f"Cache miss time: {(time_ns() - timer) / 1e6} ms")
if not value:
logger.debug(f"Key: {key} not found in redis")
return default
logger.debug(f"Retrieved value from redis for key: {key}, value: {value}")

if value:
Expand Down
52 changes: 52 additions & 0 deletions proxy/src/proxy/serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# https://redis.io/docs/latest/develop/reference/protocol-spec/
# Every RESP frame begins with a one-byte type marker: simple, bulk, or aggregate.
RESP_TYPE_BYTE_MAP = {
    "simple_string": b"+",
    "simple_error": b"-",
    "integer": b":",
    "bulk_string": b"$",
    "array": b"*",
    "null": b"_",
    "boolean": b"#",
    "double": b",",
    "big_num": b"(",
    "bulk_error": b"!",
    "verbatim_string": b"=",
    "map": b"%",
    # fixed: per the RESP3 spec, attributes begin with '|' ('`' is not a RESP type byte)
    "attribute": b"|",
    "set": b"~",
    "push": b">",
}

# Every RESP frame terminates with CRLF.
CRLF = "\r\n"

class SerializationError(Exception):
    """Raised when a payload does not conform to the expected RESP wire format."""

    pass

async def deserialize_resp(request):
    """Deserialize a RESP request into a list of command elements.

    Commands arrive as an array header ("*<count>") followed by bulk
    strings ("$<length>" header line, then the payload line).

    Args:
        request: raw bytes read from the client socket.

    Returns:
        List of decoded bulk-string elements, e.g. ["GET", "foo"].

    Raises:
        SerializationError: if the payload does not start with a RESP array.
    """
    lines = request.decode().strip().split("\r\n")

    if not lines or not lines[0].startswith("*"):
        raise SerializationError("Invalid RESP protocol format")

    # Parse the full element count — may be multi-digit (e.g. "*12").
    # The original `int(lines[0][1])` read only the first digit.
    num_elements = int(lines[0][1:])

    elements = []
    i = 1
    while i < len(lines) and len(elements) < num_elements:
        if lines[i].startswith("$"):
            # "$<length>" header: the payload is the following line.
            elements.append(lines[i + 1])
            i += 2
        else:
            i += 1

    return elements

async def serialize_resp(value):
    """Serialize a value as a RESP bulk string.

    Args:
        value: string payload, or None for a cache/redis miss.

    Returns:
        The RESP wire encoding: "$-1\\r\\n" (null bulk string) for None,
        otherwise "$<length>\\r\\n<value>\\r\\n".

    TODO(kevinc): handle other primitive types, for now just bulk strings.
    """
    # `value is None`, not `not value`: an empty string is a valid
    # zero-length bulk string ("$0\r\n\r\n"), not a null reply.
    if value is None:
        return "$-1\r\n"
    return f"${len(value)}\r\n{value}\r\n"
File renamed without changes.
83 changes: 83 additions & 0 deletions proxy/src/proxy/server_resp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Redis serialization protocol (RESP3) server
# https://redis.io/docs/latest/develop/reference/protocol-spec/
import asyncio
import logging
from asyncio import Lock
from dataclasses import dataclass

from proxy.cache import ICache, ClientCache, TTLLRUCache, TTLLRUCacheConfig
from proxy.serialization import CRLF, deserialize_resp, serialize_resp


@dataclass
class RESPServerConfig:
    """Configuration for the RESP proxy server."""

    host: str = "0.0.0.0"  # bind address; 0.0.0.0 listens on all interfaces
    port: int = 6379  # default redis port
    max_clients: int = 100  # connections beyond this are refused and closed

class RESPProxyServer:
    """Asyncio TCP server speaking a minimal subset of RESP.

    Serves GET requests from the backing cache and caps the number of
    concurrently connected clients at ``config.max_clients``.
    """

    def __init__(self, config: RESPServerConfig, cache: ICache):
        logging.basicConfig(level=logging.DEBUG, format="%(asctime)s %(levelname)s %(message)s")
        self.logger = logging.getLogger("RESP Server")
        self.config = config
        self.__cache = cache
        self.__lock = Lock()  # guards the connected-client counter
        self.__clients = 0

    async def run(self):
        """Bind the TCP listener and serve until cancelled."""
        server = await asyncio.start_server(self.handle_resp_request, host=self.config.host, port=self.config.port)
        addr = server.sockets[0].getsockname()
        self.logger.info(f"Starting RESP server on {addr}")
        async with server:
            await server.serve_forever()

    async def handle_resp_request(self, s_reader, s_writer):
        """Per-connection handler: enforce the client cap, then serve requests until EOF."""
        async with self.__lock:
            if self.__clients >= self.config.max_clients:
                self.logger.warning("Max clients reached, closing connection")
                s_writer.close()
                return
            self.__clients += 1
        client_addr = s_writer.get_extra_info("peername")
        self.logger.info(f"New connection from {client_addr}")

        try:
            while True:
                # NOTE(review): assumes one full RESP command fits in a single
                # 4096-byte read — confirm for large payloads.
                request = await s_reader.read(4096)
                if not request:
                    break
                cmd = await deserialize_resp(request)
                self.logger.debug(f"cmd: {cmd}")
                out = await self.handle_command(cmd)
                s_writer.write(out.encode())
                await s_writer.drain()

        except Exception as e:
            self.logger.error(f"{e}")
        finally:
            # Decrement only for connections that were actually admitted
            # (rejected ones returned before the try block).
            async with self.__lock:
                self.__clients -= 1
            s_writer.close()
            await s_writer.wait_closed()

    async def handle_command(self, cmd_arr: list):
        """Dispatch a parsed command array and return the serialized RESP reply."""
        if not cmd_arr:
            return f"-ERR empty command{CRLF}"

        if cmd_arr[0] == "CLIENT":
            # respond OK to get past redis-py's initial CLIENT SETINFO handshake
            if len(cmd_arr) > 1 and cmd_arr[1] == "SETINFO":
                return f"+OK{CRLF}"

        elif cmd_arr[0] == "GET":
            # Fixed: removed the TMP hard-coded 'bar' short-circuit and
            # restored the lookup against the backing cache.
            val = await self.__cache.get(cmd_arr[1])
            return await serialize_resp(val)

        return f"-ERR unimplemented command '{cmd_arr[0]}'{CRLF}"

if __name__ == "__main__":
    # Entry point: wire the two-level cache (local TTL-LRU in front of redis,
    # via ClientCache) into the RESP server and run it until interrupted.
    from proxy.config import REDIS_HOST, REDIS_PORT, CACHE_MAX_KEYS, CACHE_TTL_SEC
    cache = TTLLRUCache(TTLLRUCacheConfig(max_size=CACHE_MAX_KEYS, ttl_seconds=CACHE_TTL_SEC))
    client_cache = ClientCache(REDIS_HOST, REDIS_PORT, cache)
    server = RESPProxyServer(RESPServerConfig(), client_cache)
    asyncio.run(server.run())