Skip to content

Commit 916b2ce

Browse files
committed
Breakout middleware into separate files
1 parent 341bdc6 commit 916b2ce

File tree

6 files changed

+56
-38
lines changed

6 files changed

+56
-38
lines changed

src/stac_auth_proxy/app.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,16 @@
1010

1111
from fastapi import FastAPI
1212

13-
from .auth import OpenIdConnectAuth
14-
from .config import Settings
15-
from .handlers import ReverseProxyHandler
1613
from .middleware import (
14+
OpenApiMiddleware,
1715
AddProcessTimeHeaderMiddleware,
1816
EnforceAuthMiddleware,
19-
OpenApiMiddleware,
2017
)
2118

19+
from .auth import OpenIdConnectAuth
20+
from .config import Settings
21+
from .handlers import ReverseProxyHandler
22+
2223
logger = logging.getLogger(__name__)
2324

2425

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from fastapi import Request, Response
2+
from starlette.middleware.base import BaseHTTPMiddleware
3+
4+
5+
import time
6+
7+
8+
class AddProcessTimeHeaderMiddleware(BaseHTTPMiddleware):
9+
"""Middleware to add a header with the process time to the response."""
10+
11+
async def dispatch(self, request: Request, call_next) -> Response:
12+
"""Add a header with the process time to the response."""
13+
start_time = time.perf_counter()
14+
response = await call_next(request)
15+
process_time = time.perf_counter() - start_time
16+
response.headers["X-Process-Time"] = f"{process_time:.3f}"
17+
return response
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# TODO
2+
from fastapi import Request, Response
3+
from starlette.middleware.base import BaseHTTPMiddleware
4+
5+
6+
class EnforceAuthMiddleware(BaseHTTPMiddleware):
7+
"""Middleware to enforce authentication."""
8+
9+
async def dispatch(self, request: Request, call_next) -> Response:
10+
"""Enforce authentication."""
11+
return await call_next(request)

src/stac_auth_proxy/middleware.py renamed to src/stac_auth_proxy/middleware/OpenApiMiddleware.py

Lines changed: 6 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,13 @@
1-
"""Custom middleware."""
1+
from stac_auth_proxy.config import EndpointMethods
2+
from stac_auth_proxy.utils.requests import dict_to_bytes
23

3-
from dataclasses import dataclass
4-
from typing import Any
5-
import json
6-
import time
7-
8-
from fastapi import Request, Response
9-
from starlette.middleware.base import BaseHTTPMiddleware
10-
from starlette.types import ASGIApp, Message, Scope, Receive, Send
114

12-
from .config import EndpointMethods
5+
from starlette.types import ASGIApp, Message, Receive, Scope, Send
136

147

15-
class AddProcessTimeHeaderMiddleware(BaseHTTPMiddleware):
16-
"""Middleware to add a header with the process time to the response."""
17-
18-
async def dispatch(self, request: Request, call_next) -> Response:
19-
"""Add a header with the process time to the response."""
20-
start_time = time.perf_counter()
21-
response = await call_next(request)
22-
process_time = time.perf_counter() - start_time
23-
response.headers["X-Process-Time"] = f"{process_time:.3f}"
24-
return response
8+
import json
9+
from dataclasses import dataclass
10+
from typing import Any
2511

2612

2713
@dataclass(frozen=True)
@@ -71,17 +57,3 @@ def augment_spec(self, openapi_spec) -> dict[str, Any]:
7157
{self.oidc_auth_scheme_name: []}
7258
)
7359
return openapi_spec
74-
75-
76-
# TODO
77-
class EnforceAuthMiddleware(BaseHTTPMiddleware):
78-
"""Middleware to enforce authentication."""
79-
80-
async def dispatch(self, request: Request, call_next) -> Response:
81-
"""Enforce authentication."""
82-
return await call_next(request)
83-
84-
85-
def dict_to_bytes(d: dict) -> bytes:
86-
"""Convert a dictionary to a body."""
87-
return json.dumps(d, separators=(",", ":")).encode("utf-8")
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
"""Custom middleware."""
2+
3+
from .OpenApiMiddleware import OpenApiMiddleware
4+
from .AddProcessTimeHeaderMiddleware import AddProcessTimeHeaderMiddleware
5+
from .EnforceAuthMiddleware import EnforceAuthMiddleware
6+
7+
__all__ = [
8+
OpenApiMiddleware,
9+
AddProcessTimeHeaderMiddleware,
10+
EnforceAuthMiddleware,
11+
]

src/stac_auth_proxy/utils/requests.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Utility functions for working with HTTP requests."""
22

3+
import json
34
import re
45
from urllib.parse import urlparse
56

@@ -29,3 +30,8 @@ def extract_variables(url: str) -> dict:
2930
pattern = r"^/collections/(?P<collection_id>[^/]+)(?:/(?:items|bulk_items)(?:/(?P<item_id>[^/]+))?)?/?$"
3031
match = re.match(pattern, path)
3132
return {k: v for k, v in match.groupdict().items() if v} if match else {}
33+
34+
35+
def dict_to_bytes(d: dict) -> bytes:
36+
"""Convert a dictionary to a body."""
37+
return json.dumps(d, separators=(",", ":")).encode("utf-8")

0 commit comments

Comments
 (0)