1+ from logging import getLogger
12import json
23from dataclasses import dataclass
34from typing import Annotated , Callable , Optional
910from ..config import EndpointMethods
1011from ..utils import di , filters , requests
1112
13+ logger = getLogger (__name__ )
1214
1315FILTER_STATE_KEY = "cql2_filter"
1416
@@ -27,17 +29,38 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
2729 if scope ["type" ] != "http" :
2830 return await self .app (scope , receive , send )
2931
30- request = Request (scope )
31- filter_builder = self ._get_filter (request .url .path )
32- if filter_builder :
33- cql2_filter = await di .call_with_injected_dependencies (
34- func = filter_builder ,
35- request = request ,
36- )
37- cql2_filter .validate ()
38- scope ["state" ][FILTER_STATE_KEY ] = cql2_filter
32+ total_body = b""
33+
34+ async def receive_build_filter () -> Message :
35+ nonlocal total_body
3936
40- return await self .app (scope , receive , send )
37+ message = await receive ()
38+ total_body += message .get ("body" , b"" )
39+
40+ if not message .get ("more_body" ):
41+ request = Request (scope )
42+ filter_builder = self ._get_filter (request .url .path )
43+ if filter_builder :
44+ cql2_filter = await filter_builder (
45+ {
46+ "req" : {
47+ "path" : request .url .path ,
48+ "method" : request .method ,
49+ "query_params" : dict (request .query_params ),
50+ "path_params" : requests .extract_variables (
51+ request .url .path
52+ ),
53+ "headers" : dict (request .headers ),
54+ "body" : json .loads (total_body ),
55+ },
56+ ** request .state ._state ,
57+ }
58+ )
59+ cql2_filter .validate ()
60+ scope ["state" ][FILTER_STATE_KEY ] = cql2_filter
61+ return message
62+
63+ return await self .app (scope , receive_build_filter , send )
4164
4265 def _get_filter (self , path : str ) -> Optional [Callable [..., Expr ]]:
4366 """Get the CQL2 filter builder for the given path."""
@@ -55,51 +78,40 @@ def _get_filter(self, path: str) -> Optional[Callable[..., Expr]]:
5578
5679@dataclass (frozen = True )
5780class ApplyCql2FilterMiddleware :
58- """Middleware to add the OpenAPI spec to the response ."""
81+ """Middleware to apply the Cql2Filter to the request ."""
5982
6083 app : ASGIApp
6184
6285 async def __call__ (self , scope : Scope , receive : Receive , send : Send ) -> None :
6386 """Add the Cql2Filter to the request."""
64- request = Request (scope )
65- cql2_filter = request .state .get (FILTER_STATE_KEY )
66-
67- if scope ["type" ] != "http" or not cql2_filter :
87+ if scope ["type" ] != "http" :
6888 return await self .app (scope , receive , send )
6989
70- # Apply filter if applicable
90+ async def apply_filter () -> Message :
91+ message = await receive ()
92+ request = Request (scope )
93+ cql2_filter = getattr (request .state , FILTER_STATE_KEY , None )
94+ if not cql2_filter :
95+ logger .debug ("No cql2 filter found on message." )
96+ return message
7197
72- total_body = b""
73-
74- async def receive_with_filter (message : Message ):
75- query = request .url .query
76-
77- # TODO: How do we handle querystrings in middleware?
7898 if request .method == "GET" :
7999 query = filters .insert_qs_filter (qs = query , filter = cql2_filter )
80-
81- if message ["type" ] == "http.response.body" :
82- nonlocal total_body
83- total_body += message ["body" ]
84- if message ["more_body" ]:
85- return await receive ({** message , "body" : b"" })
86-
87- # TODO: Only on search, not on create or update...
88- if request .method in ["POST" , "PUT" ]:
89- return await receive (
90- {
91- "type" : "http.response.body" ,
92- "body" : requests .dict_to_bytes (
93- filters .append_body_filter (
94- json .loads (total_body ), cql2_filter
95- )
96- ),
97- "more_body" : False ,
98- }
100+ # Get the original query string
101+ original_qs = scope ["query_string" ].decode ("utf-8" )
102+ # Apply the filter to query string
103+ new_qs = filters .append_qs_filter (original_qs , cql2_filter )
104+ # Update the scope with new query string
105+ # scope["query_string"] = new_qs.encode("utf-8")
106+ elif request .method in ["POST" , "PUT" , "PATCH" ]:
107+ # TODO: Apply the filter to the body
108+ message ["body" ] = json .dumps (
109+ filters .append_body_filter (
110+ body = json .loads (message .get ("body" , "{}" )),
111+ filter = cql2_filter ,
99112 )
113+ ).encode ("utf-8" )
100114
101- return await receive (message )
102-
103- await receive (message )
115+ return message
104116
105- return await self .app (scope , receive_with_filter , send )
117+ return await self .app (scope , apply_filter , send )
0 commit comments