diff --git a/merino/providers/suggest/finance/backends/polygon/backend.py b/merino/providers/suggest/finance/backends/polygon/backend.py
index 26054a5cd..2fa4808a0 100644
--- a/merino/providers/suggest/finance/backends/polygon/backend.py
+++ b/merino/providers/suggest/finance/backends/polygon/backend.py
@@ -1,5 +1,7 @@
 """A wrapper for Polygon API interactions."""
 
+import asyncio
+from asyncio.queues import QueueFull
 import itertools
 import hashlib
 import logging
@@ -79,6 +81,9 @@
 """
 SCRIPT_ID_BULK_FETCH_TICKERS: str = "bulk_fetch_tickers"
 
+# Type alias for parsed cached data
+ParsedCachedData = list[Optional[Tuple[TickerSnapshot, int]]]
+
 
 class PolygonBackend:
     """Backend that connects to the Polygon API."""
@@ -95,6 +100,8 @@ class PolygonBackend:
     url_single_ticker_snapshot: str
     url_single_ticker_overview: str
     filemanager: PolygonFilemanager
+    cache_refresh_task: asyncio.Task
+    cache_refresh_task_queue: asyncio.Queue
 
     def __init__(
         self,
@@ -136,6 +143,9 @@ def __init__(
             SCRIPT_ID_BULK_WRITE_TICKERS, LUA_SCRIPT_CACHE_BULK_WRITE_TICKERS
         )
 
+        self.cache_refresh_task_queue = asyncio.Queue()
+        self.cache_refresh_task = asyncio.create_task(self.refresh_ticker_cache_entries())
+
     async def get_snapshots(self, tickers: list[str]) -> list[TickerSnapshot]:
         """Get snapshots for the list of tickers."""
         # check the cache first.
@@ -168,12 +178,11 @@ def get_ticker_summary(
         """
         return build_ticker_summary(snapshot, image_url)
 
-    async def get_snapshots_from_cache(
-        self, tickers: list[str]
-    ) -> list[Optional[Tuple[TickerSnapshot, int]]]:
+    async def get_snapshots_from_cache(self, tickers: list[str]) -> ParsedCachedData:
         """Return snapshots from the cache with their respective TTLs in a list of tuples format."""
+        parsed_cached_data: ParsedCachedData = []
+        cache_keys = []
         try:
-            cache_keys = []
             for ticker in tickers:
                 cache_keys.append(generate_cache_key_for_ticker(ticker))
 
@@ -184,14 +193,29 @@ async def get_snapshots_from_cache(
                 readonly=True,
             )
 
-            if cached_data:
-                parsed_cached_data = self._parse_cached_data(cached_data)
-                return parsed_cached_data
+            # Early exit with an empty list if the cached data is missing or every item is None.
+            if cached_data is None or all(item is None for item in cached_data):
+                return []
+
+            parsed_cached_data = self._parse_cached_data(cached_data)
+
+            # Early exit with an empty list if every parsed item is None.
+            if all(item is None for item in parsed_cached_data):
+                return []
+
+            # Collect the tickers whose cache entries are close to expiry.
+            tickers_to_refresh = self._tickers_to_refresh(parsed_cached_data)
+
+            # Hand the stale tickers off to the background refresh task.
+            if tickers_to_refresh:
+                self.cache_refresh_task_queue.put_nowait(tickers_to_refresh)
+
         except CacheAdapterError as exc:
             logger.error(f"Failed to fetch snapshots from Redis: {exc}")
-
             # TODO @Herraj -- Propagate the error for circuit breaking as PolygonError.
-            return []
+        except QueueFull:
+            logger.error("Ticker refresh queue is full")
+
+        return parsed_cached_data
 
     async def fetch_ticker_snapshot(self, ticker: str) -> Any | None:
         """Make a request and fetch the snapshot for this single ticker."""
@@ -401,10 +425,31 @@ async def store_snapshots_in_cache(self, snapshots: list[TickerSnapshot]) -> Non
             ],
         )
 
+    async def refresh_ticker_cache_entries(self) -> None:
+        """Consume ticker lists from the refresh queue, fetch fresh snapshots from the
+        upstream API, and write them back to the cache. Runs as a long-lived background task.
+        """
+        while True:
+            # Block until a list of tickers is enqueued by get_snapshots_from_cache.
+            tickers: list[str] = await self.cache_refresh_task_queue.get()
+            try:
+                snapshots = await self.get_snapshots(tickers)
+
+                # Skip the write if no snapshots were returned.
+                # The store method performs the same check.
+                if snapshots:
+                    await self.store_snapshots_in_cache(snapshots)
+            except Exception as exc:
+                logger.warning(f"Error occurred while refreshing ticker snapshots: {exc}")
+            finally:
+                # Notify the queue that this work item is done so join() can unblock.
+                self.cache_refresh_task_queue.task_done()
+
     # TODO @herraj add unit tests for this
-    def _parse_cached_data(
-        self, cached_data: list[bytes | None]
-    ) -> list[Optional[Tuple[TickerSnapshot, int]]]:
+    def _parse_cached_data(self, cached_data: list[bytes | None]) -> ParsedCachedData:
         """Parse Redis output of the form [snapshot_json, ttl, snapshot_json, ttl, ...].
 
         Each snapshot is JSON-decoded and validated into a `TickerSnapshot`,
         and each TTL is converted to an int.
@@ -418,7 +463,7 @@ def _parse_cached_data(
         if (len(cached_data) % 2) != 0:
             return []
 
-        result: list[Optional[Tuple[TickerSnapshot, int]]] = []
+        result: ParsedCachedData = []
 
         # every even index is a snapshot and odd index is its TTL
         for snapshot, ttl in itertools.batched(cached_data, 2):
@@ -441,6 +486,21 @@ def _parse_cached_data(
 
         return result
 
+    def _tickers_to_refresh(self, parsed_cached_data: ParsedCachedData) -> list[str]:
+        """Loop through the parsed cached data (list of tuples) and
+        return the tickers whose snapshots need to be refreshed.
+        """
+        tickers_to_refresh = []
+
+        for snapshot_ttl_tuple in parsed_cached_data:
+            if snapshot_ttl_tuple is not None:
+                snapshot, ttl = snapshot_ttl_tuple
+
+                # Refresh once an entry has burned through half of its TTL.
+                if ttl < self.ticker_ttl_sec / 2:
+                    tickers_to_refresh.append(snapshot.ticker)
+
+        return tickers_to_refresh
+
     async def shutdown(self) -> None:
         """Close http client and cache connections."""
         logger.info("Shutting down polygon backend")
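Taken together, the backend changes implement a soft-TTL refresh: every cache read checks how much TTL each entry has left, and any ticker below half of ticker_ttl_sec is handed to a long-lived consumer task that re-fetches and re-caches it without blocking the read path. The sketch below is a minimal, self-contained illustration of that pattern, not the backend itself: fetch_fresh and write_back are hypothetical stand-ins for get_snapshots and store_snapshots_in_cache, and the 300-second TTL mirrors the diff's default.

import asyncio

TICKER_TTL_SEC = 300


async def fetch_fresh(tickers: list[str]) -> list[str]:
    """Hypothetical stand-in for PolygonBackend.get_snapshots."""
    return tickers


async def write_back(snapshots: list[str]) -> None:
    """Hypothetical stand-in for PolygonBackend.store_snapshots_in_cache."""
    print(f"re-cached: {snapshots}")


def needs_refresh(ttl: int) -> bool:
    # Refresh once an entry has burned through half of its TTL.
    return ttl < TICKER_TTL_SEC / 2


async def refresh_consumer(queue: asyncio.Queue) -> None:
    while True:
        tickers: list[str] = await queue.get()  # block until work is enqueued
        try:
            snapshots = await fetch_fresh(tickers)
            if snapshots:
                await write_back(snapshots)
        finally:
            queue.task_done()  # lets queue.join() observe completion


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    consumer = asyncio.create_task(refresh_consumer(queue))

    # A cache read finds AAPL with 100s left -- below the 150s threshold.
    if needs_refresh(100):
        queue.put_nowait(["AAPL"])

    await queue.join()  # wait until the consumer has re-cached everything
    consumer.cancel()


asyncio.run(main())

Awaiting queue.get() (rather than polling get_nowait()) keeps the consumer idle until work arrives, and calling task_done() in a finally block guarantees queue.join() cannot hang on failures. One caveat worth flagging in review: the backend's consumer calls get_snapshots, which itself checks the cache first, so if that short-circuits on a still-present stale entry the refresh may re-cache stale data; routing the refresh through the upstream fetch path may be worth considering.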
diff --git a/tests/integration/providers/suggest/finance/backends/test_polygon.py b/tests/integration/providers/suggest/finance/backends/test_polygon.py
index 3741878e0..e3df90df3 100644
--- a/tests/integration/providers/suggest/finance/backends/test_polygon.py
+++ b/tests/integration/providers/suggest/finance/backends/test_polygon.py
@@ -25,6 +25,7 @@
 from merino.providers.suggest.finance.backends.protocol import TickerSnapshot
 from merino.providers.suggest.finance.backends.polygon import PolygonBackend
+from merino.providers.suggest.finance.backends.polygon.utils import generate_cache_key_for_ticker
 
 logger = logging.getLogger(__name__)
 
@@ -83,6 +84,29 @@
     }
 
 
+@pytest.fixture(name="polygon_factory")
+def fixture_polygon_factory(mocker: MockerFixture, statsd_mock: Any, redis_client: Redis):
+    """Return a factory that builds Polygon backend parameters with per-test overrides."""
+
+    def _polygon_parameters(**overrides: Any) -> dict[str, Any]:
+        params = {
+            "api_key": "api_key",
+            "metrics_client": statsd_mock,
+            "http_client": mocker.AsyncMock(spec=AsyncClient),
+            "metrics_sample_rate": 1,
+            "url_param_api_key": "apiKey",
+            "url_single_ticker_snapshot": URL_SINGLE_TICKER_SNAPSHOT,
+            "url_single_ticker_overview": URL_SINGLE_TICKER_OVERVIEW,
+            "gcs_uploader": mocker.MagicMock(),
+            "cache": RedisAdapter(redis_client),
+            "ticker_ttl_sec": TICKER_TTL_SEC,
+        }
+        params.update(overrides)
+        return params
+
+    return _polygon_parameters
+
+
 @pytest.fixture(name="polygon")
 def fixture_polygon(
     polygon_parameters: dict[str, Any],
@@ -119,6 +143,14 @@ def fixture_ticker_snapshot_NFLX() -> TickerSnapshot:
     )
 
 
+async def set_redis_key_expiry(
+    redis_client: Redis, keys_and_expiry: list[tuple[str, int]]
+) -> None:
+    """Set redis cache key expiry (TTL in seconds)."""
+    for key, ttl in keys_and_expiry:
+        await redis_client.expire(key, ttl)
+
+
 @pytest.mark.asyncio
 async def test_get_snapshots_from_cache_success(
     mocker: MockerFixture,
@@ -195,3 +227,44 @@
     assert len(records) == 1
     assert records[0].message.startswith("Failed to fetch snapshots from Redis: test cache error")
     assert actual == []
+
+
+@pytest.mark.asyncio
+async def test_refresh_ticker_cache_entries_success(
+    polygon_factory,
+    ticker_snapshot_AAPL: TickerSnapshot,
+    ticker_snapshot_NFLX: TickerSnapshot,
+    redis_client: Redis,
+    mocker: MockerFixture,
+) -> None:
+    """Test that the background refresh task rewrites snapshots to the cache with a fresh TTL."""
+    polygon = PolygonBackend(**polygon_factory(cache=RedisAdapter(redis_client)))
+
+    # Mock the get_snapshots method to return the AAPL and NFLX snapshot fixtures.
+    get_snapshots_mock = mocker.patch.object(
+        polygon, "get_snapshots", new_callable=mocker.AsyncMock
+    )
+    get_snapshots_mock.return_value = [ticker_snapshot_AAPL, ticker_snapshot_NFLX]
+
+    expected = [(ticker_snapshot_AAPL, TICKER_TTL_SEC), (ticker_snapshot_NFLX, TICKER_TTL_SEC)]
+
+    # Write to the cache (this method writes with the default 300 sec TTL).
+    await polygon.store_snapshots_in_cache([ticker_snapshot_AAPL, ticker_snapshot_NFLX])
+
+    # Manually lower the TTL of both cache entries from 300 to 100 seconds.
+    cache_keys = [generate_cache_key_for_ticker(ticker) for ticker in ["AAPL", "NFLX"]]
+    await set_redis_key_expiry(redis_client, [(cache_keys[0], 100), (cache_keys[1], 100)])
+
+    # Enqueue the tickers and wait for the background consumer to process them,
+    # which should reset the TTL to 300. Awaiting join() here keeps the refresh
+    # from completing after test execution.
+    polygon.cache_refresh_task_queue.put_nowait(["AAPL", "NFLX"])
+    await polygon.cache_refresh_task_queue.join()
+
+    actual = await polygon.get_snapshots_from_cache(["AAPL", "NFLX"])
+
+    assert actual == expected
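For reference, _parse_cached_data (exercised indirectly by the integration test above) consumes the interleaved [snapshot_json, ttl, snapshot_json, ttl, ...] reply produced by the bulk-fetch Lua script. Below is a minimal sketch of that parsing shape, assuming Python 3.12+ for itertools.batched and using raw JSON strings in place of the validated TickerSnapshot models.

import itertools


def parse(cached_data: list[bytes | None]) -> list[tuple[str, int] | None]:
    # A reply that does not pair every payload with a TTL is malformed.
    if len(cached_data) % 2 != 0:
        return []

    result: list[tuple[str, int] | None] = []
    # Every even index is a payload, every odd index is its TTL.
    for payload, ttl in itertools.batched(cached_data, 2):
        if payload is None or ttl is None:
            result.append(None)  # cache miss for this ticker
        else:
            result.append((payload.decode(), int(ttl)))
    return result


print(parse([b'{"ticker": "AAPL"}', b"120", None, None]))
# [('{"ticker": "AAPL"}', 120), None]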
diff --git a/tests/unit/providers/suggest/finance/backends/test_polygon.py b/tests/unit/providers/suggest/finance/backends/test_polygon.py
index f45828b56..007894424 100644
--- a/tests/unit/providers/suggest/finance/backends/test_polygon.py
+++ b/tests/unit/providers/suggest/finance/backends/test_polygon.py
@@ -312,8 +312,7 @@ async def test_get_snapshots_success(
     """
     tickers = ["AAPL", "MSFT", "TSLA"]
 
-    # Mocking the fetch_ticker_snapshot method to return single_ticker_snapshot_response fixture for two of the calls.
-    # Returns None for one of the calls.
+    # Mock the fetch_ticker_snapshot method to return the single_ticker_snapshot_response fixture for all three calls.
     fetch_mock = mocker.patch.object(
         polygon, "fetch_ticker_snapshot",
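One loose end worth noting: the diff starts a long-lived cache_refresh_task in __init__ (which assumes a running event loop), but the shutdown() shown in context only closes the HTTP client and cache connections. A common teardown pattern for such a task, sketched here under the assumption that cancelling mid-refresh is acceptable; cancel_background_task is an illustrative helper, not part of the backend.

import asyncio
import contextlib


async def cancel_background_task(task: asyncio.Task) -> None:
    """Cancel a long-lived consumer task and wait for it to unwind."""
    task.cancel()
    with contextlib.suppress(asyncio.CancelledError):
        await task


async def main() -> None:
    async def consumer_loop() -> None:
        while True:
            await asyncio.sleep(1)

    task = asyncio.create_task(consumer_loop())
    await cancel_background_task(task)
    print("refresh task cancelled cleanly")


asyncio.run(main())

In the backend this would amount to cancelling self.cache_refresh_task at the top of shutdown(), so a refresh can never write to an already-closed cache client.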