Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions fastapi_cache/backends/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ class RedisBackend(Backend):
def __init__(self, redis: Union["Redis[bytes]", "RedisCluster[bytes]"]):
self.redis = redis
self.is_cluster: bool = isinstance(redis, RedisCluster)
# Add driver identification for redis-py
self._add_driver_info()

async def get_with_ttl(self, key: str) -> Tuple[int, Optional[bytes]]:
async with self.redis.pipeline(transaction=not self.is_cluster) as pipe:
Expand All @@ -28,3 +30,37 @@ async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None
elif key:
return await self.redis.delete(key) # type: ignore[union-attr]
return 0

def _add_driver_info(self) -> None:
"""Add driver identification to Redis connection.

Uses DriverInfo class if available, or falls back to
lib_name/lib_version for older versions.
"""
from typing import Any

from fastapi_cache import __version__

# Get connection pool from the redis client
connection_pool: Any = getattr(self.redis, "connection_pool", None)
if connection_pool is None:
return

# Try to use DriverInfo class
try:
from redis import DriverInfo

driver_info = DriverInfo().add_upstream_driver("fastapi-cache", __version__)
connection_pool.connection_kwargs["driver_info"] = driver_info
except (ImportError, AttributeError):
# Fallback: use lib_name/lib_version
# Format: lib_name='redis-py(fastapi-cache_v{version})'
connection_pool.connection_kwargs["lib_name"] = f"redis-py(fastapi-cache_v{__version__})"
# lib_version should be the redis client version
try:
import redis

redis_version = redis.__version__
except (ImportError, AttributeError):
redis_version = "unknown"
connection_pool.connection_kwargs["lib_version"] = redis_version
149 changes: 149 additions & 0 deletions tests/test_redis_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
from typing import Any, Dict, Optional
from unittest.mock import MagicMock, patch

import pytest

from fastapi_cache.backends.redis import RedisBackend


class MockConnectionPool:
"""Mock Redis connection pool."""

def __init__(self) -> None:
self.connection_kwargs: Dict[str, Any] = {}


class MockRedisClient:
"""Mock Redis client."""

def __init__(self, has_pool: bool = True) -> None:
self.connection_pool: Optional[MockConnectionPool] = (
MockConnectionPool() if has_pool else None
)


@pytest.fixture
def mock_redis_client() -> MockRedisClient:
"""Create a mock Redis client with connection pool."""
return MockRedisClient(has_pool=True)


@pytest.fixture
def mock_redis_client_no_pool() -> MockRedisClient:
"""Create a mock Redis client without connection pool."""
return MockRedisClient(has_pool=False)


def test_add_driver_info_with_driver_info_class(mock_redis_client: MockRedisClient) -> None:
"""Test _add_driver_info when DriverInfo class is available."""
mock_driver_info_instance = MagicMock()
mock_driver_info_instance.add_upstream_driver.return_value = mock_driver_info_instance
mock_driver_info_class = MagicMock(return_value=mock_driver_info_instance)

with patch("redis.DriverInfo", mock_driver_info_class, create=True):
with patch("fastapi_cache.__version__", "0.2.2"):
RedisBackend(mock_redis_client) # type: ignore[arg-type]

# Verify DriverInfo was instantiated
mock_driver_info_class.assert_called_once()
mock_driver_info_instance.add_upstream_driver.assert_called_once_with(
"fastapi-cache", "0.2.2"
)

# Verify driver_info was set in connection_kwargs
assert "driver_info" in mock_redis_client.connection_pool.connection_kwargs # type: ignore[union-attr]
assert (
mock_redis_client.connection_pool.connection_kwargs["driver_info"] # type: ignore[union-attr]
== mock_driver_info_instance
)


def test_add_driver_info_fallback_without_driver_info(
mock_redis_client: MockRedisClient,
) -> None:
"""Test _add_driver_info fallback when DriverInfo is not available."""
with patch("redis.DriverInfo", side_effect=ImportError, create=True):
with patch("fastapi_cache.__version__", "0.2.2"):
with patch("redis.__version__", "5.0.0"):
RedisBackend(mock_redis_client) # type: ignore[arg-type]

# Verify fallback to lib_name/lib_version
assert "lib_name" in mock_redis_client.connection_pool.connection_kwargs # type: ignore[union-attr]
assert "lib_version" in mock_redis_client.connection_pool.connection_kwargs # type: ignore[union-attr]

lib_name = mock_redis_client.connection_pool.connection_kwargs["lib_name"] # type: ignore[union-attr]
assert lib_name == "redis-py(fastapi-cache_v0.2.2)"
assert (
mock_redis_client.connection_pool.connection_kwargs["lib_version"] # type: ignore[union-attr]
== "5.0.0"
)


def test_add_driver_info_fallback_unknown_redis_version(
mock_redis_client: MockRedisClient,
) -> None:
"""Test _add_driver_info fallback when redis version is unknown."""
with patch("redis.DriverInfo", side_effect=ImportError, create=True):
with patch("fastapi_cache.__version__", "0.2.2"):
# Delete __version__ from redis module to trigger AttributeError
import redis
original_version = getattr(redis, "__version__", None)
try:
if hasattr(redis, "__version__"):
delattr(redis, "__version__")

RedisBackend(mock_redis_client) # type: ignore[arg-type]

# Verify fallback with unknown version
assert "lib_version" in mock_redis_client.connection_pool.connection_kwargs # type: ignore[union-attr]
assert (
mock_redis_client.connection_pool.connection_kwargs["lib_version"] # type: ignore[union-attr]
== "unknown"
)
finally:
# Restore original version
if original_version is not None:
redis.__version__ = original_version # type: ignore[attr-defined]


def test_add_driver_info_no_connection_pool(
mock_redis_client_no_pool: MockRedisClient,
) -> None:
"""Test _add_driver_info when connection pool is not available."""
# Should not raise an error, just return early
backend = RedisBackend(mock_redis_client_no_pool) # type: ignore[arg-type]

# Verify no error was raised and backend was created
assert backend.redis == mock_redis_client_no_pool


def test_add_driver_info_attribute_error_fallback(
mock_redis_client: MockRedisClient,
) -> None:
"""Test _add_driver_info fallback when DriverInfo raises AttributeError."""
with patch("redis.DriverInfo", side_effect=AttributeError, create=True):
with patch("fastapi_cache.__version__", "0.2.2"):
with patch("redis.__version__", "4.5.0"):
RedisBackend(mock_redis_client) # type: ignore[arg-type]

# Verify fallback to lib_name/lib_version
assert "lib_name" in mock_redis_client.connection_pool.connection_kwargs # type: ignore[union-attr]
assert (
mock_redis_client.connection_pool.connection_kwargs["lib_version"] # type: ignore[union-attr]
== "4.5.0"
)


def test_redis_backend_is_cluster_false(mock_redis_client: MockRedisClient) -> None:
"""Test that is_cluster is False for regular Redis client."""
backend = RedisBackend(mock_redis_client) # type: ignore[arg-type]
assert backend.is_cluster is False


def test_redis_backend_initialization(mock_redis_client: MockRedisClient) -> None:
"""Test RedisBackend initialization."""
backend = RedisBackend(mock_redis_client) # type: ignore[arg-type]

assert backend.redis == mock_redis_client
assert backend.is_cluster is False