Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions scitt_emulator/policy_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@
from celery import Celery, current_app as celery_current_app
from celery.result import AsyncResult
from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
from pydantic import (
BaseModel,
PlainSerializer,
Expand Down Expand Up @@ -1722,6 +1723,13 @@ async def startup_fastapi_app_policy_engine_context(
yield state


class ContextVarMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
fastapi_current_app.set(request.app)
fastapi_current_request.set(request)
response = await call_next(request)
return response

def make_fastapi_app(
*,
context: Optional[Dict[str, Any]] = None,
Expand All @@ -1732,6 +1740,7 @@ def make_fastapi_app(
context,
),
)
app.add_middleware(ContextVarMiddleware)

@app.get("/rate_limit")
async def route_policy_engine_status(
Expand All @@ -1751,8 +1760,6 @@ async def route_policy_engine_status(
fastapi_request: Request,
) -> PolicyEngineStatus:
global celery_app
fastapi_current_app.set(fastapi_request.app)
fastapi_current_request.set(fastapi_request)
async with fastapi_request.state.no_celery_async_results_lock:
request_task = AsyncResult(request_id, app=celery_app)
request_task_state = request_task.state
Expand Down Expand Up @@ -1787,8 +1794,6 @@ async def route_request(
request: PolicyEngineRequest,
fastapi_request: Request,
) -> PolicyEngineStatus:
fastapi_current_app.set(fastapi_request.app)
fastapi_current_request.set(fastapi_request)
# TODO Handle when submitted.status cases
request_status = PolicyEngineStatus(
status=PolicyEngineStatuses.SUBMITTED,
Expand Down