@@ -29,6 +29,35 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
2929 if scope ["type" ] != "http" :
3030 return await self .app (scope , receive , send )
3131
32+ request = Request (scope )
33+
34+ filter_builder = self ._get_filter (request .url .path )
35+ if not filter_builder :
36+ return await self .app (scope , receive , send )
37+
38+ async def set_filter (body : Optional [dict ] = None ) -> None :
39+ cql2_filter = await filter_builder (
40+ {
41+ "req" : {
42+ "path" : request .url .path ,
43+ "method" : request .method ,
44+ "query_params" : dict (request .query_params ),
45+ "path_params" : requests .extract_variables (request .url .path ),
46+ "headers" : dict (request .headers ),
47+ "body" : body ,
48+ },
49+ ** request .state ._state ,
50+ }
51+ )
52+ cql2_filter .validate ()
53+ scope ["state" ][FILTER_STATE_KEY ] = cql2_filter
54+
55+ # For GET requests, we can build the filter immediately
56+ # NOTE: It appears that FastAPI will not call receive function for GET requests
57+ if request .method == "GET" :
58+ await set_filter ()
59+ return await self .app (scope , receive , send )
60+
3261 total_body = b""
3362
3463 async def receive_build_filter () -> Message :
@@ -38,26 +67,7 @@ async def receive_build_filter() -> Message:
3867 total_body += message .get ("body" , b"" )
3968
4069 if not message .get ("more_body" ):
41- request = Request (scope )
42- filter_builder = self ._get_filter (request .url .path )
43- if filter_builder :
44- cql2_filter = await filter_builder (
45- {
46- "req" : {
47- "path" : request .url .path ,
48- "method" : request .method ,
49- "query_params" : dict (request .query_params ),
50- "path_params" : requests .extract_variables (
51- request .url .path
52- ),
53- "headers" : dict (request .headers ),
54- "body" : json .loads (total_body ),
55- },
56- ** request .state ._state ,
57- }
58- )
59- cql2_filter .validate ()
60- scope ["state" ][FILTER_STATE_KEY ] = cql2_filter
70+ await set_filter (json .loads (total_body ) if total_body else None )
6171 return message
6272
6373 return await self .app (scope , receive_build_filter , send )
@@ -87,31 +97,35 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
8797 if scope ["type" ] != "http" :
8898 return await self .app (scope , receive , send )
8999
90- async def apply_filter () -> Message :
91- message = await receive ()
92- request = Request (scope )
93- cql2_filter = getattr (request .state , FILTER_STATE_KEY , None )
94- if not cql2_filter :
95- logger .debug ("No cql2 filter found on message." )
96- return message
100+ request = Request (scope )
97101
98- if request .method == "GET" :
99- query = filters .insert_qs_filter (qs = query , filter = cql2_filter )
100- # Get the original query string
101- original_qs = scope ["query_string" ].decode ("utf-8" )
102- # Apply the filter to query string
103- new_qs = filters .append_qs_filter (original_qs , cql2_filter )
104- # Update the scope with new query string
105- # scope["query_string"] = new_qs.encode("utf-8")
106- elif request .method in ["POST" , "PUT" , "PATCH" ]:
107- # TODO: Apply the filter to the body
108- message ["body" ] = json .dumps (
109- filters .append_body_filter (
110- body = json .loads (message .get ("body" , "{}" )),
111- filter = cql2_filter ,
112- )
113- ).encode ("utf-8" )
102+ if request .method == "GET" :
103+ cql2_filter = scope ["state" ].get (FILTER_STATE_KEY )
104+ if cql2_filter :
105+ scope ["query_string" ] = filters .append_qs_filter (
106+ request .url .query , cql2_filter
107+ )
108+ return await self .app (scope , receive , send )
109+ elif request .method in ["POST" , "PUT" , "PATCH" ]:
110+ # For methods with bodies, we need to wrap receive to modify the body
111+ async def receive_with_filter () -> Message :
112+ message = await receive ()
113+ if message ["type" ] != "http.request" :
114+ return message
115+
116+ cql2_filter = scope ["state" ].get (FILTER_STATE_KEY )
117+ if cql2_filter :
118+ try :
119+ body = message .get ("body" , b"{}" )
120+ except json .JSONDecodeError as e :
121+ logger .warning ("Failed to parse request body as JSON" )
122+ # TODO: Return a 400 error
123+ raise e
124+
125+ new_body = filters .append_body_filter (json .loads (body ), cql2_filter )
126+ message ["body" ] = json .dumps (new_body ).encode ("utf-8" )
127+ return message
114128
115- return message
129+ return await self . app ( scope , receive_with_filter , send )
116130
117- return await self .app (scope , apply_filter , send )
131+ return await self .app (scope , receive , send )
0 commit comments