|
6 | 6 | import re |
7 | 7 |
|
8 | 8 | from fastapi import Request, Response |
| 9 | +from fastapi.encoders import jsonable_encoder |
9 | 10 | from redis.asyncio import Redis, ConnectionPool |
10 | | -from sqlalchemy.orm import class_mapper, DeclarativeBase |
11 | 11 | from fastapi import FastAPI |
12 | 12 | from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint |
13 | 13 |
|
14 | | -from app.core.exceptions import CacheIdentificationInferenceError, InvalidRequestError |
| 14 | +from app.core.exceptions import CacheIdentificationInferenceError, InvalidRequestError, InvalidOutputTypeError |
15 | 15 |
|
16 | 16 | # --------------- server side caching --------------- |
17 | 17 |
|
18 | 18 | pool: ConnectionPool | None = None |
19 | 19 | client: Redis | None = None |
20 | 20 |
|
21 | | -def _serialize_sqlalchemy_object(obj: DeclarativeBase) -> Dict[str, Any]: |
22 | | - """ |
23 | | - Serialize a SQLAlchemy DeclarativeBase object to a dictionary. |
24 | | -
|
25 | | - Parameters |
26 | | - ---------- |
27 | | - obj: DeclarativeBase |
28 | | - The SQLAlchemy DeclarativeBase object to be serialized. |
29 | | - |
30 | | - Returns |
31 | | - ------- |
32 | | - Dict[str, Any] |
33 | | - A dictionary containing the serialized attributes of the object. |
34 | | - |
35 | | - Note |
36 | | - ---- |
37 | | - - Datetime objects are converted to ISO 8601 string format. |
38 | | - - UUID objects are converted to strings before serializing to JSON. |
39 | | - """ |
40 | | - if isinstance(obj, DeclarativeBase): |
41 | | - data = {} |
42 | | - for column in class_mapper(obj.__class__).columns: |
43 | | - value = getattr(obj, column.name) |
44 | | - |
45 | | - if isinstance(value, datetime): |
46 | | - value = value.isoformat() |
47 | | - |
48 | | - if isinstance(value, UUID): |
49 | | - value = str(value) |
50 | | - |
51 | | - data[column.name] = value |
52 | | - return data |
53 | | - |
54 | | - |
55 | 21 | def _infer_resource_id(kwargs: Dict[str, Any], resource_id_type: Union[type, str]) -> Union[None, int, str]: |
56 | 22 | """ |
57 | 23 | Infer the resource ID from a dictionary of keyword arguments. |
@@ -236,46 +202,43 @@ async def sample_endpoint(request: Request, resource_id: int): |
236 | 202 | This decorator caches the response data of the endpoint function using a unique cache key. |
237 | 203 | The cached data is retrieved for GET requests, and the cache is invalidated for other types of requests. |
238 | 204 |
|
239 | | - Note: |
240 | | - - For caching lists of objects, ensure that the response is a list of objects, and the decorator will handle caching accordingly. |
| 205 | + Note |
| 206 | + ---- |
241 | 207 | - resource_id_type is used only if resource_id is not passed. |
242 | 208 | """ |
243 | 209 | def wrapper(func: Callable) -> Callable: |
244 | 210 | @functools.wraps(func) |
245 | 211 | async def inner(request: Request, *args, **kwargs) -> Response: |
| 212 | + if "output_type" in kwargs.keys() and kwargs["output_type"] == list: |
| 213 | + raise InvalidOutputTypeError |
| 214 | + |
246 | 215 | if resource_id_name: |
247 | 216 | resource_id = kwargs[resource_id_name] |
248 | 217 | else: |
249 | 218 | resource_id = _infer_resource_id(kwargs=kwargs, resource_id_type=resource_id_type) |
250 | 219 |
|
251 | 220 | formatted_key_prefix = _format_prefix(key_prefix, kwargs) |
252 | 221 | cache_key = f"{formatted_key_prefix}:{resource_id}" |
253 | | - |
254 | 222 | if request.method == "GET": |
255 | 223 | if to_invalidate_extra: |
256 | 224 | raise InvalidRequestError |
257 | 225 |
|
258 | 226 | cached_data = await client.get(cache_key) |
259 | 227 | if cached_data: |
| 228 | + print("cache hit") |
260 | 229 | return json.loads(cached_data.decode()) |
261 | | - |
| 230 | + |
262 | 231 | result = await func(request, *args, **kwargs) |
263 | 232 |
|
264 | 233 | if request.method == "GET": |
265 | | - if to_invalidate_extra: |
266 | | - raise InvalidRequestError |
| 234 | + serializable_data = jsonable_encoder(result) |
| 235 | + serialized_data = json.dumps(serializable_data) |
267 | 236 |
|
268 | | - if isinstance(result, list): |
269 | | - serialized_data = json.dumps( |
270 | | - [_serialize_sqlalchemy_object(obj) for obj in result] |
271 | | - ) |
272 | | - else: |
273 | | - serialized_data = json.dumps( |
274 | | - _serialize_sqlalchemy_object(result) |
275 | | - ) |
276 | | - |
277 | 237 | await client.set(cache_key, serialized_data) |
278 | 238 | await client.expire(cache_key, expiration) |
| 239 | + |
| 240 | + serialized_data = json.loads(serialized_data) |
| 241 | + |
279 | 242 | else: |
280 | 243 | await client.delete(cache_key) |
281 | 244 | if to_invalidate_extra: |
|
0 commit comments