diff --git a/pyproject.toml b/pyproject.toml index 563ff86..2b2b2aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ readme = "README.md" requires-python = ">=3.11" [project.optional-dependencies] -server = ["fastapi", "uvicorn", "redis", "hiredis"] +server = ["fastapi", "uvicorn", "redis", "hiredis", "cachetools"] dev = [ "copier", "httpx", @@ -31,6 +31,9 @@ dev = [ "ruff", "tox-direct", "types-mock", + "cachetools", + "fastapi", + "uvicorn", ] [project.scripts] diff --git a/src/daq_config_server/client.py b/src/daq_config_server/client.py index bb32633..4f72379 100644 --- a/src/daq_config_server/client.py +++ b/src/daq_config_server/client.py @@ -1,26 +1,101 @@ +import operator from logging import Logger, getLogger from typing import Any import requests +from cachetools import TTLCache, cachedmethod from .constants import ENDPOINTS class ConfigServer: - def __init__(self, url: str, log: Logger | None = None) -> None: + def __init__( + self, + url: str, + log: Logger | None = None, + cache_size: int = 10, + cache_lifetime_s: int = 3600, + ) -> None: + """ + Initialize the ConfigServer client. + + Args: + url: Base URL of the config server. + log: Optional logger instance. + cache_size: Size of the cache (maximum number of items can be stored). + cache_lifetime_s: Lifetime of the cache (in seconds). + """ self._url = url.rstrip("/") self._log = log if log else getLogger("daq_config_server.client") + self._cache = TTLCache(maxsize=cache_size, ttl=cache_lifetime_s) def _get( self, endpoint: str, item: str | None = None, - ): - r = requests.get(self._url + endpoint + (f"/{item}" if item else "")) - return r.json() - - def read_unformatted_file(self, file_path: str) -> Any: - # After https://github.com/DiamondLightSource/daq-config-server/issues/67, we - # can get specific formats, and then have better typing on - # return values - return self._get(ENDPOINTS.CONFIG, file_path) + reset_cached_result: bool = False, + ) -> Any: + """ + Get data from the config server with cache management. + If a cached response doesn't already exist, makes a request to + the config server. + If reset_cached_result is true, remove the cache entry for that request and + make a new request + + Args: + endpoint: API endpoint. + item: Optional item identifier. + reset_cached_result: Whether to reset cache. + + Returns: + The response data. + """ + + if (endpoint, item) in self._cache and reset_cached_result: + del self._cache[(endpoint, item)] + return self._cached_get(endpoint, item) + + @cachedmethod(cache=operator.attrgetter("_cache")) + def _cached_get( + self, + endpoint: str, + item: str | None = None, + ) -> Any: + """ + Get data from the config server and cache it. + + Args: + endpoint: API endpoint. + item: Optional item identifier. + + Returns: + The response data. + """ + url = self._url + endpoint + (f"/{item}" if item else "") + + try: + r = requests.get(url) + r.raise_for_status() + data = r.json() + self._log.debug(f"Cache set for {endpoint}/{item}.") + return data + except requests.exceptions.HTTPError as e: + self._log.error(f"HTTP error: {e}") + raise + + def read_unformatted_file( + self, file_path: str, reset_cached_result: bool = False + ) -> Any: + """ + Read an unformatted file from the config server. + + Args: + file_path: Path to the file. + reset_cached_result: Whether to reset cache. + + Returns: + The file content. + """ + return self._get( + ENDPOINTS.CONFIG, file_path, reset_cached_result=reset_cached_result + ) diff --git a/tests/unit_tests/test_client.py b/tests/unit_tests/test_client.py index 79ec0c2..c8a3b09 100644 --- a/tests/unit_tests/test_client.py +++ b/tests/unit_tests/test_client.py @@ -1,18 +1,115 @@ +from time import sleep from unittest.mock import MagicMock, patch +import pytest +import requests from fastapi import status -from httpx import Response from daq_config_server.client import ConfigServer from daq_config_server.constants import ENDPOINTS -# More useful tests for the client are in tests/system_tests +def make_mock_response(json_value, status_code=200, raise_exc=None): + mock_response = MagicMock() + mock_response.json.return_value = json_value + mock_response.status_code = status_code + if raise_exc: + mock_response.raise_for_status.side_effect = raise_exc + else: + mock_response.raise_for_status.return_value = None + return mock_response + + @patch("daq_config_server.client.requests.get") def test_read_unformatted_file(mock_request: MagicMock): - mock_request.return_value = Response(status_code=status.HTTP_200_OK, json="test") + """Test that read_unformatted_file calls the correct endpoint and + returns the expected result.""" + mock_request.return_value = make_mock_response({"key": "value"}, status.HTTP_200_OK) file_path = "test" url = "url" server = ConfigServer(url) - server.read_unformatted_file(file_path) + result = server.read_unformatted_file(file_path) + assert result == {"key": "value"} mock_request.assert_called_once_with(url + ENDPOINTS.CONFIG + "/" + file_path) + + +@patch("daq_config_server.client.requests.get") +def test_read_unformatted_file_reading_cache(mock_request: MagicMock): + """Test cache behavior for read_unformatted_file.""" + mock_request.side_effect = [ + make_mock_response("1st_read", status.HTTP_200_OK), + make_mock_response("2nd_read", status.HTTP_200_OK), + make_mock_response("3rd_read", status.HTTP_200_OK), + ] + file_path = "test" + url = "url" + server = ConfigServer(url) + assert server.read_unformatted_file(file_path) == "1st_read" + assert server.read_unformatted_file(file_path) == "1st_read" # Cached + assert ( + server.read_unformatted_file(file_path, reset_cached_result=True) == "2nd_read" + ) + assert server.read_unformatted_file(file_path) == "2nd_read" # Cached + assert ( + server.read_unformatted_file(file_path, reset_cached_result=True) == "3rd_read" + ) + + +@patch("daq_config_server.client.requests.get") +def test_read_unformatted_file_reading_reset_cached_result_true_without_cache( + mock_request: MagicMock, +): + """Test repeated reset_cached_result=False disables cache for each call.""" + mock_request.side_effect = [ + make_mock_response("1st_read", status.HTTP_200_OK), + ] + file_path = "test" + url = "url" + server = ConfigServer(url) + assert ( + server.read_unformatted_file(file_path, reset_cached_result=True) == "1st_read" + ) + + +@patch("daq_config_server.client.requests.get") +def test_read_unformatted_file_reading_not_OK(mock_request: MagicMock): + """Test that a non-200 response raises a RequestException.""" + mock_request.return_value = make_mock_response( + "1st_read", status.HTTP_204_NO_CONTENT, raise_exc=requests.exceptions.HTTPError + ) + file_path = "test" + url = "url" + server = ConfigServer(url) + with pytest.raises(requests.exceptions.HTTPError): + server.read_unformatted_file(file_path) + + +@patch("daq_config_server.client.requests.get") +def test_read_unformatted_file_reading_cache_custom_size(mock_request: MagicMock): + mock_request.side_effect = [ + make_mock_response("1st_read", status.HTTP_200_OK), + make_mock_response("2nd_read", status.HTTP_200_OK), + make_mock_response("3rd_read", status.HTTP_200_OK), + ] + file_path = "test" + url = "url" + server = ConfigServer(url=url, cache_size=1) + assert server.read_unformatted_file(file_path) == "1st_read" + assert server.read_unformatted_file(file_path + "1") == "2nd_read" + assert server.read_unformatted_file(file_path) == "3rd_read" + + +@patch("daq_config_server.client.requests.get") +def test_read_unformatted_file_cache_custom_lifetime(mock_request: MagicMock): + mock_request.side_effect = [ + make_mock_response("1st_read", status.HTTP_200_OK), + make_mock_response("2nd_read", status.HTTP_200_OK), + make_mock_response("3rd_read", status.HTTP_200_OK), + ] + file_path = "test" + url = "url" + server = ConfigServer(url=url, cache_lifetime_s=0.1) # type: ignore + assert server.read_unformatted_file(file_path) == "1st_read" + assert server.read_unformatted_file(file_path) == "1st_read" + sleep(0.1) + assert server.read_unformatted_file(file_path) == "2nd_read"