Skip to content

Commit f8caff2

Browse files
committed
[DISCO-3754] Add background task to refresh snapshot cache
1 parent a2c7654 commit f8caff2

File tree

3 files changed

+97
-2
lines changed

3 files changed

+97
-2
lines changed

merino/providers/suggest/finance/backends/polygon/backend.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""A wrapper for Polygon API interactions."""
22

3+
import asyncio
34
import itertools
45
import hashlib
56
import logging
@@ -401,6 +402,28 @@ async def store_snapshots_in_cache(self, snapshots: list[TickerSnapshot]) -> Non
401402
],
402403
)
403404

405+
async def refresh_ticker_cache_entries(
406+
self, tickers: list[str], *, await_store: bool = False
407+
) -> None:
408+
"""Refresh ticker snapshot in cache. Fetches new snapshot from upstream API and
409+
fires a background task to write it to the cache.
410+
411+
Note: Only awaits the cache write process if await_store is true. Used only for testing.
412+
"""
413+
snapshots = await self.get_snapshots(tickers)
414+
415+
# early exit if no snapshots returned.
416+
if len(snapshots) == 0:
417+
return
418+
419+
# this parameter is only used for testing.
420+
if await_store:
421+
await self.store_snapshots_in_cache(snapshots)
422+
else:
423+
task = asyncio.create_task(self.store_snapshots_in_cache(snapshots))
424+
# consume/log
425+
task.add_done_callback(lambda t: t.exception())
426+
404427
# TODO @herraj add unit tests for this
405428
def _parse_cached_data(
406429
self, cached_data: list[bytes | None]

tests/integration/providers/suggest/finance/backends/test_polygon.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
from merino.providers.suggest.finance.backends.protocol import TickerSnapshot
2727
from merino.providers.suggest.finance.backends.polygon import PolygonBackend
28+
from merino.providers.suggest.finance.backends.polygon.utils import generate_cache_key_for_ticker
2829

2930
logger = logging.getLogger(__name__)
3031

@@ -83,6 +84,29 @@ def fixture_polygon_parameters(
8384
}
8485

8586

87+
@pytest.fixture(name="polygon_factory")
88+
def fixture_polygon_factory(mocker: MockerFixture, statsd_mock: Any, redis_client: Redis):
89+
"""Return factory fixture to create Polygon backend parameters with overrides."""
90+
91+
def _polygon_parameters(**overrides: Any) -> dict[str, Any]:
92+
params = {
93+
"api_key": "api_key",
94+
"metrics_client": statsd_mock,
95+
"http_client": mocker.AsyncMock(spec=AsyncClient),
96+
"metrics_sample_rate": 1,
97+
"url_param_api_key": "apiKey",
98+
"url_single_ticker_snapshot": URL_SINGLE_TICKER_SNAPSHOT,
99+
"url_single_ticker_overview": URL_SINGLE_TICKER_OVERVIEW,
100+
"gcs_uploader": mocker.MagicMock(),
101+
"cache": RedisAdapter(redis_client),
102+
"ticker_ttl_sec": TICKER_TTL_SEC,
103+
}
104+
params.update(overrides)
105+
return params
106+
107+
return _polygon_parameters
108+
109+
86110
@pytest.fixture(name="polygon")
87111
def fixture_polygon(
88112
polygon_parameters: dict[str, Any],
@@ -119,6 +143,14 @@ def fixture_ticker_snapshot_NFLX() -> TickerSnapshot:
119143
)
120144

121145

146+
async def set_redis_key_expiry(
147+
redis_client: Redis, keys_and_expiry: list[tuple[str, int]]
148+
) -> None:
149+
"""Set redis cache key expiry (TTL seconds)."""
150+
for key, ttl in keys_and_expiry:
151+
await redis_client.expire(key, ttl)
152+
153+
122154
@pytest.mark.asyncio
123155
async def test_get_snapshots_from_cache_success(
124156
mocker: MockerFixture,
@@ -195,3 +227,44 @@ async def test_get_snapshots_from_cache_raises_cache_error(
195227
assert len(records) == 1
196228
assert records[0].message.startswith("Failed to fetch snapshots from Redis: test cache error")
197229
assert actual == []
230+
231+
232+
@pytest.mark.asyncio
233+
async def test_refresh_ticker_cache_entries_success(
234+
polygon_factory,
235+
ticker_snapshot_AAPL: TickerSnapshot,
236+
ticker_snapshot_NFLX,
237+
redis_client: Redis,
238+
mocker,
239+
) -> None:
240+
"""Test that refresh_ticker_cache_entries method successfully writes snapshots to cache with new TTL."""
241+
polygon = PolygonBackend(**polygon_factory(cache=RedisAdapter(redis_client)))
242+
243+
# Mocking the get_snapshots method to return AAPL and NFLX snapshots fixtures for 2 calls.
244+
get_snapshots_mock = mocker.patch.object(
245+
polygon, "get_snapshots", new_callable=mocker.AsyncMock
246+
)
247+
get_snapshots_mock.return_value = [ticker_snapshot_AAPL, ticker_snapshot_NFLX]
248+
249+
expected = [(ticker_snapshot_AAPL, TICKER_TTL_SEC), (ticker_snapshot_NFLX, TICKER_TTL_SEC)]
250+
251+
# write to cache (this method writes with the default 300 sec TTL)
252+
await polygon.store_snapshots_in_cache([ticker_snapshot_AAPL, ticker_snapshot_NFLX])
253+
254+
# manually modify the TTL for the above cache entries to 100 instead of 300
255+
cache_keys = []
256+
for key in ["AAPL", "NFLX"]:
257+
cache_keys.append(generate_cache_key_for_ticker(key))
258+
await set_redis_key_expiry(redis_client, [(cache_keys[0], 100), (cache_keys[1], 100)])
259+
260+
# refresh the cache entries -- this should reset the TTL to 300
261+
# forcing the await here otherwise this task finishes after test execution
262+
await polygon.refresh_ticker_cache_entries(["AAPL", "NFLX"], await_store=True)
263+
264+
actual = await polygon.get_snapshots_from_cache(["AAPL", "NFLX"])
265+
266+
assert actual is not None
267+
assert actual == expected
268+
269+
assert actual[0] == expected[0]
270+
assert actual[1] == expected[1]

tests/unit/providers/suggest/finance/backends/test_polygon.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -312,8 +312,7 @@ async def test_get_snapshots_success(
312312
"""
313313
tickers = ["AAPL", "MSFT", "TSLA"]
314314

315-
# Mocking the fetch_ticker_snapshot method to return single_ticker_snapshot_response fixture for two of the calls.
316-
# Returns None for one of the calls.
315+
# Mocking the fetch_ticker_snapshot method to return single_ticker_snapshot_response fixture for all three calls.
317316
fetch_mock = mocker.patch.object(
318317
polygon,
319318
"fetch_ticker_snapshot",

0 commit comments

Comments
 (0)