Skip to content

Commit 37e585b

Browse files
committed
cache decorator fixed, now you can either pass a resource_id or let it infer it. Cache applied to some posts and users endpoints
1 parent 6e66c64 commit 37e585b

File tree

4 files changed

+159
-18
lines changed

4 files changed

+159
-18
lines changed

src/app/api/v1/posts.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from app.crud.crud_posts import crud_posts
1212
from app.crud.crud_users import crud_users
1313
from app.api.exceptions import privileges_exception
14+
from app.core.cache import cache
1415

1516
router = fastapi.APIRouter(tags=["posts"])
1617

@@ -36,8 +37,9 @@ async def write_post(
3637

3738

3839
@router.get("/{username}/posts", response_model=List[PostRead])
40+
@cache(key_prefix="{username}_posts", resource_id_name="username")
3941
async def read_posts(
40-
request: Request,
42+
request: Request,
4143
username: str,
4244
db: Annotated[AsyncSession, Depends(async_get_db)]
4345
):
@@ -50,6 +52,7 @@ async def read_posts(
5052

5153

5254
@router.get("/{username}/post/{id}", response_model=PostRead)
55+
@cache(key_prefix="{username}_post_cache")
5356
async def read_post(
5457
request: Request,
5558
username: str,
@@ -68,8 +71,9 @@ async def read_post(
6871

6972

7073
@router.patch("/{username}/post/{id}", response_model=PostRead)
74+
@cache("{username}_post_cache", resource_id_name="id")
7175
async def patch_post(
72-
request: Request,
76+
request: Request,
7377
username: str,
7478
id: int,
7579
values: PostUpdate,

src/app/api/v1/users.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,11 @@ async def read_users_me(
4444
):
4545
return current_user
4646

47+
from app.core.cache import cache
48+
4749

4850
@router.get("/user/{username}", response_model=UserRead)
51+
@cache("user_cache", resource_id_type=str)
4952
async def read_user(request: Request, username: str, db: Annotated[AsyncSession, Depends(async_get_db)]):
5053
db_user = await crud_users.get(db=db, username=username, is_deleted=False)
5154
if db_user is None:

src/app/core/cache.py

Lines changed: 146 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,170 @@
1-
from typing import Callable
1+
from typing import Callable, Union, List, Dict, Any
22
import functools
3+
import json
4+
from uuid import UUID
5+
from datetime import datetime
36

47
from fastapi import Request, Response
5-
from fastapi.responses import JSONResponse
6-
import redis.asyncio as redis
78
from redis.asyncio import Redis, ConnectionPool
8-
from fastapi.responses import JSONResponse
9+
from sqlalchemy.orm import class_mapper, DeclarativeBase
10+
11+
from app.core.exceptions import CacheIdentificationInferenceError
912

1013
pool: ConnectionPool | None = None
1114
client: Redis | None = None
1215

13-
def cache(key_prefix: str, expiration: int = 3600) -> Callable:
16+
def _serialize_sqlalchemy_object(obj: DeclarativeBase) -> Dict[str, Any]:
17+
"""
18+
Serialize a SQLAlchemy DeclarativeBase object to a dictionary.
19+
20+
Parameters
21+
----------
22+
obj: DeclarativeBase
23+
The SQLAlchemy DeclarativeBase object to be serialized.
24+
25+
Returns
26+
-------
27+
Dict[str, Any]
28+
A dictionary containing the serialized attributes of the object.
29+
30+
Note
31+
----
32+
- Datetime objects are converted to ISO 8601 string format.
33+
- UUID objects are converted to strings before serializing to JSON.
34+
"""
35+
if isinstance(obj, DeclarativeBase):
36+
data = {}
37+
for column in class_mapper(obj.__class__).columns:
38+
value = getattr(obj, column.name)
39+
40+
if isinstance(value, datetime):
41+
value = value.isoformat()
42+
43+
if isinstance(value, UUID):
44+
value = str(value)
45+
46+
data[column.name] = value
47+
return data
48+
49+
50+
def _infer_resource_id(kwargs: Dict[str, Any], resource_id_type: Union[type, str]) -> Union[None, int, str]:
51+
"""
52+
Infer the resource ID from a dictionary of keyword arguments.
53+
54+
Parameters
55+
----------
56+
kwargs: Dict[str, Any]
57+
A dictionary of keyword arguments.
58+
resource_id_type: Union[type, str]
59+
The expected type of the resource ID, which can be an integer (int) or a string (str).
60+
61+
Returns
62+
-------
63+
Union[None, int, str]
64+
The inferred resource ID. If it cannot be inferred or does not match the expected type, it returns None.
65+
66+
Note
67+
----
68+
- When `resource_id_type` is 'int', the function looks for an argument with the key 'id'.
69+
- When `resource_id_type` is 'str', it attempts to infer the resource ID as a string.
70+
"""
71+
resource_id = None
72+
for arg_name, arg_value in kwargs.items():
73+
if isinstance(arg_value, resource_id_type):
74+
if (resource_id_type is int) and ("id" in arg_name):
75+
resource_id = arg_value
76+
77+
elif (resource_id_type is int) and ("id" not in arg_name):
78+
pass
79+
80+
elif resource_id_type is str:
81+
resource_id = arg_value
82+
83+
if resource_id is None:
84+
raise CacheIdentificationInferenceError
85+
86+
return resource_id
87+
88+
89+
def cache(key_prefix: str, resource_id_name: Any = None, expiration: int = 3600, resource_id_type: Union[type, List[type]] = int) -> Callable:
90+
"""
91+
Cache decorator for FastAPI endpoints.
92+
93+
This decorator allows you to cache the results of FastAPI endpoint functions, improving response times and reducing the load on the application by storing and retrieving data in a cache.
94+
95+
Parameters
96+
----------
97+
key_prefix: str
98+
A unique prefix to identify the cache key.
99+
resource_id: Any, optional
100+
The resource ID to be used in cache key generation. If not provided, it will be inferred from the endpoint's keyword arguments.
101+
expiration: int, optional
102+
The expiration time for cached data in seconds. Defaults to 3600 seconds (1 hour).
103+
resource_id_type: Union[type, List[type]], optional
104+
The expected type of the resource ID. This can be a single type (e.g., int) or a list of types (e.g., [int, str]). Defaults to int.
105+
106+
Returns
107+
-------
108+
Callable
109+
A decorator function that can be applied to FastAPI endpoints.
110+
111+
Example usage
112+
-------------
113+
114+
```python
115+
from fastapi import FastAPI, Request
116+
from my_module import cache # Replace with your actual module and imports
117+
118+
app = FastAPI()
119+
120+
# Define a sample endpoint with caching
121+
@app.get("/sample/{resource_id}")
122+
@cache(key_prefix="sample_data", expiration=3600, resource_id_type=int)
123+
async def sample_endpoint(request: Request, resource_id: int):
124+
# Your endpoint logic here
125+
return {"data": "your_data"}
126+
```
127+
128+
This decorator caches the response data of the endpoint function using a unique cache key.
129+
The cached data is retrieved for GET requests, and the cache is invalidated for other types of requests.
130+
131+
Note:
132+
- For caching lists of objects, ensure that the response is a list of objects, and the decorator will handle caching accordingly.
133+
- resource_id_type is used only if resource_id is not passed.
134+
"""
14135
def wrapper(func: Callable) -> Callable:
15136
@functools.wraps(func)
16137
async def inner(request: Request, *args, **kwargs) -> Response:
17-
resource_id = args[0] # Assuming the resource ID is the first argument
138+
if resource_id_name:
139+
resource_id = kwargs[resource_id_name]
140+
else:
141+
resource_id = _infer_resource_id(kwargs=kwargs, resource_id_type=resource_id_type)
142+
18143
cache_key = f"{key_prefix}:{resource_id}"
19144

20145
if request.method == "GET":
21-
# Check if the data exists in the cache for GET requests
22146
cached_data = await client.get(cache_key)
23147
if cached_data:
24-
# If data exists in the cache, return it
25-
return JSONResponse(content=cached_data.decode(), status_code=200)
148+
return json.loads(cached_data.decode())
26149

27-
# Call the original function for both all types of requests
28150
result = await func(request, *args, **kwargs)
29-
151+
30152
if request.method == "GET":
31-
# Store the result in the cache for GET requests with the specified expiration time
32-
await client.set(cache_key, result, expire=expiration)
153+
if isinstance(result, list):
154+
serialized_data = json.dumps(
155+
[_serialize_sqlalchemy_object(obj) for obj in result]
156+
)
157+
else:
158+
serialized_data = json.dumps(
159+
_serialize_sqlalchemy_object(result)
160+
)
161+
162+
await client.set(cache_key, serialized_data)
163+
await client.expire(cache_key, expiration)
33164
else:
34-
# Invalidate the cache for other types of requests
35-
await redis.delete(cache_key)
165+
await client.delete(cache_key)
36166

37-
return JSONResponse(content=result, status_code=200)
167+
return result
38168

39169
return inner
40170

src/app/core/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
class CacheIdentificationInferenceError(Exception):
2+
def __init__(self, message="Could not infer id for resource being cached."):
3+
self.message = message
4+
super().__init__(self.message)

0 commit comments

Comments
 (0)