Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions backend/apps/common/extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""Strawberry extensions."""

import json

from django.conf import settings
from django.core.cache import cache
from strawberry.extensions.field_extension import FieldExtension


class CacheFieldExtension(FieldExtension):
"""Cache FieldExtension class."""

def __init__(self, cache_timeout: int | None = None, prefix: str | None = None):
"""Initialize the cache extension.

Args:
cache_timeout (int | None): The TTL for cache entries in seconds.
prefix (str | None): A prefix for the cache key.

"""
self.cache_timeout = cache_timeout or settings.GRAPHQL_RESOLVER_CACHE_TIME_SECONDS
self.prefix = prefix or settings.GRAPHQL_RESOLVER_CACHE_PREFIX

def generate_key(self, info, kwargs: dict) -> str:
"""Generate a unique cache key for a field.

Args:
info (Info): The Strawberry execution info.
kwargs (dict): The resolver's arguments.

Returns:
str: The unique cache key.

"""
args_str = json.dumps(kwargs, sort_keys=True)

return f"{self.prefix}:{info.path.typename}:{info.path.key}:{args_str}"

def resolve(self, next_, source, info, **kwargs):
"""Wrap the resolver to provide caching."""
cache_key = self.generate_key(info, kwargs)
if cached_result := cache.get(cache_key):
return cached_result

if result := next_(source, info, **kwargs):
cache.set(cache_key, result, self.cache_timeout)

return result
2 changes: 2 additions & 0 deletions backend/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ class Base(Configuration):
API_PAGE_SIZE = 100
API_CACHE_PREFIX = "api-response"
API_CACHE_TIME_SECONDS = 86400 # 24 hours.
GRAPHQL_RESOLVER_CACHE_PREFIX = "graphql-resolver"
GRAPHQL_RESOLVER_CACHE_TIME_SECONDS = 86400 # 24 hours.
NINJA_PAGINATION_CLASS = "apps.api.rest.v0.pagination.CustomPagination"
NINJA_PAGINATION_PER_PAGE = API_PAGE_SIZE

Expand Down
101 changes: 101 additions & 0 deletions backend/tests/apps/common/extensions_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from unittest.mock import MagicMock, patch

import pytest
from strawberry.types.info import Info

from apps.common.extensions import CacheFieldExtension


@pytest.mark.parametrize(
("typename", "key", "kwargs", "prefix", "expected_key"),
[
("UserNode", "name", {}, "p1", "p1:UserNode:name:{}"),
(
"RepositoryNode",
"issues",
{"limit": 10},
"p2",
"""p2:RepositoryNode:issues:{"limit": 10}""",
),
(
"RepositoryNode",
"issues",
{"limit": 10, "state": "open"},
"p3",
"""p3:RepositoryNode:issues:{"limit": 10, "state": "open"}""",
),
(
"RepositoryNode",
"issues",
{"state": "open", "limit": 10},
"p4",
"""p4:RepositoryNode:issues:{"limit": 10, "state": "open"}""",
),
],
)
def test_generate_key(typename, key, kwargs, prefix, expected_key):
"""Test cases for the generate_key method."""
mock_info = MagicMock(spec=Info)
mock_info.path.typename = typename
mock_info.path.key = key

extension = CacheFieldExtension(prefix=prefix)
assert extension.generate_key(mock_info, kwargs) == expected_key


class TestCacheFieldExtensionResolve:
"""Test cases for the resolve method of CacheFieldExtension."""

@pytest.fixture
def mock_info(self):
"""Return a mock Strawberry Info object."""
mock_info = MagicMock(spec=Info)
mock_info.path.typename = "TestType"
mock_info.path.key = "testField"
return mock_info

@patch("apps.common.extensions.cache")
def test_resolve_caches_result_on_miss(self, mock_cache, mock_info):
"""Test that the resolver caches the result on a cache miss."""
mock_cache.get.return_value = None
resolver_result = "some data"
next_ = MagicMock(return_value=resolver_result)
extension = CacheFieldExtension(cache_timeout=60)

result = extension.resolve(next_, source=None, info=mock_info)

assert result == resolver_result
mock_cache.get.assert_called_once()
next_.assert_called_once()
mock_cache.set.assert_called_once()
mock_cache.set.assert_called_with(mock_cache.get.call_args[0][0], resolver_result, 60)

@patch("apps.common.extensions.cache")
def test_resolve_returns_cached_result_on_hit(self, mock_cache, mock_info):
"""Test that the resolver returns the cached result on a cache hit."""
cached_result = "cached data"
mock_cache.get.return_value = cached_result
next_ = MagicMock()
extension = CacheFieldExtension()

result = extension.resolve(next_, source=None, info=mock_info)

assert result == cached_result
mock_cache.get.assert_called_once()
next_.assert_not_called()
mock_cache.set.assert_not_called()

@pytest.mark.parametrize("falsy_result", [None, [], {}, 0, False])
@patch("apps.common.extensions.cache")
def test_resolve_does_not_cache_falsy_result(self, mock_cache, falsy_result, mock_info):
"""Test that the resolver does not cache None or other falsy results."""
mock_cache.get.return_value = None
next_ = MagicMock(return_value=falsy_result)
extension = CacheFieldExtension()

result = extension.resolve(next_, source=None, info=mock_info)

assert result == falsy_result
mock_cache.get.assert_called_once()
next_.assert_called_once()
mock_cache.set.assert_not_called()