86 changes: 73 additions & 13 deletions merino/providers/suggest/finance/backends/polygon/backend.py
@@ -1,5 +1,7 @@
"""A wrapper for Polygon API interactions."""

import asyncio
from asyncio.queues import QueueFull
import itertools
import hashlib
import logging
@@ -79,6 +81,9 @@
"""
SCRIPT_ID_BULK_FETCH_TICKERS: str = "bulk_fetch_tickers"

# Type alias for parsed cached data
ParsedCachedData = list[Optional[Tuple[TickerSnapshot, int]]]


class PolygonBackend:
"""Backend that connects to the Polygon API."""
@@ -95,6 +100,8 @@ class PolygonBackend:
url_single_ticker_snapshot: str
url_single_ticker_overview: str
filemanager: PolygonFilemanager
cache_refresh_task: asyncio.Task
cache_refresh_task_queue: asyncio.Queue[list[str]]

def __init__(
self,
@@ -136,6 +143,9 @@ def __init__(
SCRIPT_ID_BULK_WRITE_TICKERS, LUA_SCRIPT_CACHE_BULK_WRITE_TICKERS
)

# NOTE: bounded so that put_nowait() can actually raise QueueFull under
# backpressure -- an unbounded queue never raises it; the size is illustrative.
self.cache_refresh_task_queue = asyncio.Queue(maxsize=1000)
self.cache_refresh_task = asyncio.create_task(self.refresh_ticker_cache_entries())
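Note: `asyncio.create_task` requires a running event loop, so this constructor must be called from async code or it raises `RuntimeError`. If the backend may also be built synchronously, one option is to start the consumer lazily; a minimal sketch, assuming `cache_refresh_task` is initialized to `None` in `__init__` (the `_ensure_refresh_task` helper is hypothetical, not part of this PR):

```python
import asyncio

# Sketch only: a lazy-start guard inside PolygonBackend (hypothetical helper).
def _ensure_refresh_task(self) -> None:
    """Schedule the refresh consumer once, from an async code path."""
    # create_task() needs a running event loop, so this must be invoked
    # from async code (e.g. the first request handler that needs it).
    if self.cache_refresh_task is None:
        self.cache_refresh_task = asyncio.create_task(
            self.refresh_ticker_cache_entries()
        )
```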

async def get_snapshots(self, tickers: list[str]) -> list[TickerSnapshot]:
"""Get snapshots for the list of tickers."""
# check the cache first.
@@ -168,12 +178,11 @@ def get_ticker_summary(
"""
return build_ticker_summary(snapshot, image_url)

async def get_snapshots_from_cache(
self, tickers: list[str]
) -> list[Optional[Tuple[TickerSnapshot, int]]]:
async def get_snapshots_from_cache(self, tickers: list[str]) -> ParsedCachedData:
"""Return snapshots from the cache with their respective TTLs in a list of tuples format."""
parsed_cached_data = []
cache_keys = []
try:
cache_keys = []
for ticker in tickers:
cache_keys.append(generate_cache_key_for_ticker(ticker))

@@ -184,14 +193,29 @@ async def get_snapshots_from_cache(
readonly=True,
)

if cached_data:
parsed_cached_data = self._parse_cached_data(cached_data)
return parsed_cached_data
# early exit with an empty list if the cache returned nothing usable
if cached_data is None or all(item is None for item in cached_data):
    return []

parsed_cached_data = self._parse_cached_data(cached_data)

# early exit if parsing produced no usable entries either
if all(item is None for item in parsed_cached_data):
    return []

# get the tickers whose snapshots are close to expiry
tickers_to_refresh = self._tickers_to_refresh(parsed_cached_data)

# queue them for background refresh, skipping the put entirely for empty batches
if tickers_to_refresh:
    self.cache_refresh_task_queue.put_nowait(tickers_to_refresh)

except CacheAdapterError as exc:
    logger.error(f"Failed to fetch snapshots from Redis: {exc}")
    # TODO @Herraj -- Propagate the error for circuit breaking as PolygonError.
    return []
except QueueFull:
    # Refresh is best-effort: the cached snapshots are still valid, so log and fall through.
    logger.error("Ticker refresh queue is full")

return parsed_cached_data
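Taken together, this read path behaves like stale-while-revalidate: callers immediately get whatever the cache holds, while entries nearing expiry are refreshed off the request path. An illustrative flow (the ticker names and TTL values here are assumptions, not from this PR):

```python
# Suppose ticker_ttl_sec is 300 and the cache holds AAPL with 90s left
# (90 < 150, so it qualifies for refresh) and MSFT with 290s left.
cached = await backend.get_snapshots_from_cache(["AAPL", "MSFT"])

# Both snapshots come back to the caller right away...
assert all(entry is not None for entry in cached)

# ...and ["AAPL"] now sits on the refresh queue for the background
# consumer to re-fetch and re-write, without blocking this request
# (assuming the consumer has not yet drained the queue).
assert backend.cache_refresh_task_queue.qsize() == 1
```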

async def fetch_ticker_snapshot(self, ticker: str) -> Any | None:
"""Make a request and fetch the snapshot for this single ticker."""
@@ -401,10 +425,31 @@ async def store_snapshots_in_cache(self, snapshots: list[TickerSnapshot]) -> None:
],
)

async def refresh_ticker_cache_entries(self) -> None:
    """Background consumer: read ticker batches from the refresh queue,
    fetch fresh snapshots from the upstream API, and write them back
    to the cache.
    """
    while True:
        # Block until a batch arrives; awaiting get() avoids busy-spinning
        # (get_nowait() would raise QueueEmpty in a tight loop).
        tickers: list[str] = await self.cache_refresh_task_queue.get()
        try:
            snapshots = await self.get_snapshots(tickers)

            # Skip the write when nothing came back; `continue` (not `return`)
            # keeps the consumer alive for subsequent batches. The store
            # method performs the same check, but this avoids a no-op call.
            if not snapshots:
                continue

            await self.store_snapshots_in_cache(snapshots)
        except Exception as exc:
            logger.warning(f"Error occurred while refreshing ticker snapshots: {exc}")
        finally:
            # notify the queue that this batch has been processed
            self.cache_refresh_task_queue.task_done()

# TODO @herraj add unit tests for this
def _parse_cached_data(
self, cached_data: list[bytes | None]
) -> list[Optional[Tuple[TickerSnapshot, int]]]:
def _parse_cached_data(self, cached_data: list[bytes | None]) -> ParsedCachedData:
"""Parse Redis output of the form [snapshot_json, ttl, snapshot_json, ttl, ...].
Each snapshot is JSON-decoded and validated into a `TickerSnapshot`,
and each TTL is converted to an int.
@@ -418,7 +463,7 @@ def _parse_cached_data(
if (len(cached_data) % 2) != 0:
return []

result: list[Optional[Tuple[TickerSnapshot, int]]] = []
result: ParsedCachedData = []

# every even index is a snapshot and odd index is its TTL
for snapshot, ttl in itertools.batched(cached_data, 2):
@@ -441,6 +486,21 @@

return result
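For reference, `itertools.batched` (Python 3.12+) is what pairs each snapshot with its TTL out of the flat Redis reply; a small self-contained illustration with made-up payloads:

```python
import itertools

# Redis replies with a flat list: [snapshot_json, ttl, snapshot_json, ttl, ...]
flat_reply = [b'{"ticker": "AAPL"}', b"287", b'{"ticker": "NFLX"}', b"113"]

# batched(iterable, 2) yields (snapshot, ttl) pairs from the flat list.
for snapshot_bytes, ttl_bytes in itertools.batched(flat_reply, 2):
    print(snapshot_bytes.decode(), int(ttl_bytes))
# {"ticker": "AAPL"} 287
# {"ticker": "NFLX"} 113
```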

def _tickers_to_refresh(self, parsed_cached_data: ParsedCachedData) -> list[str]:
    """Return the tickers whose cached snapshots have less than half of the
    configured TTL remaining and should therefore be refreshed.
    """
    tickers_to_refresh = []

    for snapshot_ttl_tuple in parsed_cached_data:
        if snapshot_ttl_tuple is not None:
            snapshot, ttl = snapshot_ttl_tuple

            # refresh once the entry has burned through half of its TTL
            if ttl < self.ticker_ttl_sec / 2:
                tickers_to_refresh.append(snapshot.ticker)

    return tickers_to_refresh
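Concretely, with the 300-second TTL used in the tests below, the refresh threshold works out to 150 seconds:

```python
TICKER_TTL_SEC = 300  # default write TTL, per the tests in this PR

# 140s remaining: below half the TTL, so the ticker is queued for refresh.
assert 140 < TICKER_TTL_SEC / 2
# 280s remaining: still fresh, so it is left alone.
assert not (280 < TICKER_TTL_SEC / 2)
```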

async def shutdown(self) -> None:
"""Close http client and cache connections."""
logger.info("Shutting down polygon backend")
@@ -25,6 +25,7 @@

from merino.providers.suggest.finance.backends.protocol import TickerSnapshot
from merino.providers.suggest.finance.backends.polygon import PolygonBackend
from merino.providers.suggest.finance.backends.polygon.utils import generate_cache_key_for_ticker

logger = logging.getLogger(__name__)

@@ -83,6 +84,29 @@ def fixture_polygon_parameters(
}


@pytest.fixture(name="polygon_factory")
def fixture_polygon_factory(mocker: MockerFixture, statsd_mock: Any, redis_client: Redis):
"""Return factory fixture to create Polygon backend parameters with overrides."""

def _polygon_parameters(**overrides: Any) -> dict[str, Any]:
params = {
"api_key": "api_key",
"metrics_client": statsd_mock,
"http_client": mocker.AsyncMock(spec=AsyncClient),
"metrics_sample_rate": 1,
"url_param_api_key": "apiKey",
"url_single_ticker_snapshot": URL_SINGLE_TICKER_SNAPSHOT,
"url_single_ticker_overview": URL_SINGLE_TICKER_OVERVIEW,
"gcs_uploader": mocker.MagicMock(),
"cache": RedisAdapter(redis_client),
"ticker_ttl_sec": TICKER_TTL_SEC,
}
params.update(overrides)
return params

return _polygon_parameters


@pytest.fixture(name="polygon")
def fixture_polygon(
polygon_parameters: dict[str, Any],
@@ -119,6 +143,14 @@ def fixture_ticker_snapshot_NFLX() -> TickerSnapshot:
)


async def set_redis_key_expiry(
redis_client: Redis, keys_and_expiry: list[tuple[str, int]]
) -> None:
"""Set redis cache key expiry (TTL seconds)."""
for key, ttl in keys_and_expiry:
await redis_client.expire(key, ttl)


@pytest.mark.asyncio
async def test_get_snapshots_from_cache_success(
mocker: MockerFixture,
@@ -195,3 +227,44 @@ async def test_get_snapshots_from_cache_raises_cache_error(
assert len(records) == 1
assert records[0].message.startswith("Failed to fetch snapshots from Redis: test cache error")
assert actual == []


@pytest.mark.asyncio
async def test_refresh_ticker_cache_entries_success(
polygon_factory: Any,
ticker_snapshot_AAPL: TickerSnapshot,
ticker_snapshot_NFLX: TickerSnapshot,
redis_client: Redis,
mocker: MockerFixture,
) -> None:
"""Test that refresh_ticker_cache_entries method successfully writes snapshots to cache with new TTL."""
polygon = PolygonBackend(**polygon_factory(cache=RedisAdapter(redis_client)))

# Mock the get_snapshots method to return the AAPL and NFLX snapshot fixtures.
get_snapshots_mock = mocker.patch.object(
polygon, "get_snapshots", new_callable=mocker.AsyncMock
)
get_snapshots_mock.return_value = [ticker_snapshot_AAPL, ticker_snapshot_NFLX]

expected = [(ticker_snapshot_AAPL, TICKER_TTL_SEC), (ticker_snapshot_NFLX, TICKER_TTL_SEC)]

# write to cache (this method writes with the default 300 sec TTL)
await polygon.store_snapshots_in_cache([ticker_snapshot_AAPL, ticker_snapshot_NFLX])

# manually modify the TTL for the above cache entries to 100 instead of 300
cache_keys = [generate_cache_key_for_ticker(ticker) for ticker in ["AAPL", "NFLX"]]
await set_redis_key_expiry(redis_client, [(key, 100) for key in cache_keys])

# enqueue the tickers and wait for the background consumer to drain the
# queue -- this refreshes the entries and resets their TTL to 300
await polygon.cache_refresh_task_queue.put(["AAPL", "NFLX"])
await polygon.cache_refresh_task_queue.join()

actual = await polygon.get_snapshots_from_cache(["AAPL", "NFLX"])

assert actual == expected
Contributor Author:
These tests would need to be refactored since I refactored the source method to use a queue instead 😅

@@ -312,8 +312,7 @@ async def test_get_snapshots_success(
"""
tickers = ["AAPL", "MSFT", "TSLA"]

# Mocking the fetch_ticker_snapshot method to return single_ticker_snapshot_response fixture for two of the calls.
# Returns None for one of the calls.
# Mocking the fetch_ticker_snapshot method to return single_ticker_snapshot_response fixture for all three calls.
fetch_mock = mocker.patch.object(
polygon,
"fetch_ticker_snapshot",