diff --git a/Pipfile b/Pipfile index 5ab4e79..bc00d24 100644 --- a/Pipfile +++ b/Pipfile @@ -27,6 +27,7 @@ pytest-cov = "==3.0.0" pytest-mock = "==3.6.1" flake8 = "==4.0.1" fastapi = "*" +fastapi_cache2 = {extras = ["redis"], version = "==0.2.1"} uvicorn = {extras = ["standard"], version = "*"} jinja2 = "*" psutil = "*" diff --git a/app/coder.py b/app/coder.py new file mode 100644 index 0000000..a84e87e --- /dev/null +++ b/app/coder.py @@ -0,0 +1,38 @@ +import json +from io import BytesIO +from fastapi.responses import StreamingResponse +from fastapi_cache.coder import Coder + + +class StreamingResponseCoder(Coder): + @classmethod + async def encode(cls, value: StreamingResponse) -> bytes: + # Extract serializable parts of StreamingResponse + if isinstance(value, StreamingResponse): + headers = dict(value.headers) + content = b"" + + # Properly await the async generator + async for chunk in value.body_iterator: + if not isinstance(chunk, bytes): + chunk = chunk.encode(value.charset) + content += chunk + + # Prepare data for serialization + data = { + "status_code": value.status_code, + "headers": headers, + "content": content.decode('utf-8'), # Convert to string for JSON serialization + } + return json.dumps(data).encode("utf-8") + else: + raise TypeError(f"Unsupported type: {type(value)}") + + @classmethod + async def decode(cls, value: bytes) -> StreamingResponse: + # Convert the cached data back into a StreamingResponse + data = json.loads(value.decode("utf-8")) + content = BytesIO(data["content"].encode("utf-8")) # Convert content back to bytes + headers = data["headers"] + status_code = data["status_code"] + return StreamingResponse(content=content, headers=headers, status_code=status_code) diff --git a/app/main.py b/app/main.py index de716ab..d9b115d 100644 --- a/app/main.py +++ b/app/main.py @@ -13,9 +13,13 @@ from fastapi import FastAPI, HTTPException, Request, UploadFile from fastapi import status as status from fastapi.responses import FileResponse, JSONResponse, StreamingResponse, HTMLResponse +from fastapi_cache import FastAPICache +from fastapi_cache.backends.inmemory import InMemoryBackend from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates +from fastapi_cache.decorator import cache from pydantic import HttpUrl +from app.coder import StreamingResponseCoder import app.mongodb as mongodb from app import ( @@ -53,6 +57,12 @@ templates = Jinja2Templates(directory="templates") +@api.on_event("startup") +async def on_startup(): + # Initialize FastAPICache with appropriate backend, coder, etc. + FastAPICache.init(InMemoryBackend(), coder=StreamingResponseCoder, prefix="fastapi-cache") + + @api.get("/build") def build_version() -> str: """Build version.""" @@ -146,6 +156,7 @@ async def system_info() -> Dict: @api.get("/download_template") +@cache(expire=86400, coder=StreamingResponseCoder) async def download_template( repo_zip_url: HttpUrl = "https://github.com/jdi-templates/" "jdi-light-testng-empty-template/archive/refs/heads/main.zip",