22
33import json
44import re
5- from dataclasses import dataclass , field
6- from functools import partial
5+ from dataclasses import dataclass
76from logging import getLogger
8- from typing import Callable , Optional
7+ from typing import Optional
98
109from cql2 import Expr
11- from starlette .datastructures import MutableHeaders , State
10+ from starlette .datastructures import MutableHeaders
1211from starlette .requests import Request
1312from starlette .types import ASGIApp , Message , Receive , Scope , Send
1413
@@ -39,32 +38,25 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
3938
4039 request = Request (scope )
4140
42- get_cql2_filter : Callable [[], Optional [Expr ]] = partial (
43- getattr , request .state , self .state_key , None
44- )
41+ cql2_filter : Optional [Expr ] = getattr (request .state , self .state_key , None )
42+
43+ if not cql2_filter :
44+ return await self .app (scope , receive , send )
4545
4646 # Handle POST, PUT, PATCH
4747 if request .method in ["POST" , "PUT" , "PATCH" ]:
48- return await self .app (
49- scope ,
50- Cql2RequestBodyAugmentor (
51- receive = receive ,
52- state = request .state ,
53- get_cql2_filter = get_cql2_filter ,
54- ),
55- send ,
48+ req_body_handler = Cql2RequestBodyAugmentor (
49+ app = self .app ,
50+ cql2_filter = cql2_filter ,
5651 )
57-
58- cql2_filter = get_cql2_filter ()
59- if not cql2_filter :
60- return await self .app (scope , receive , send )
52+ return await req_body_handler (scope , receive , send )
6153
6254 if re .match (r"^/collections/([^/]+)/items/([^/]+)$" , request .url .path ):
63- return await self .app (
64- scope ,
65- receive ,
66- Cql2ResponseBodyValidator (cql2_filter = cql2_filter , send = send ),
55+ res_body_validator = Cql2ResponseBodyValidator (
56+ app = self .app ,
57+ cql2_filter = cql2_filter ,
6758 )
59+ return await res_body_validator (scope , send , receive )
6860
6961 scope ["query_string" ] = filters .append_qs_filter (request .url .query , cql2_filter )
7062 return await self .app (scope , receive , send )
@@ -74,88 +66,115 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
7466class Cql2RequestBodyAugmentor :
7567 """Handler to augment the request body with a CQL2 filter."""
7668
77- receive : Receive
78- state : State
79- get_cql2_filter : Callable [[], Optional [Expr ]]
80-
81- async def __call__ (self ) -> Message :
82- """Process a request body and augment with a CQL2 filter if available."""
83- message = await self .receive ()
84- if message ["type" ] != "http.request" :
85- return message
86-
87- # NOTE: Can only get cql2 filter _after_ calling self.receive()
88- cql2_filter = self .get_cql2_filter ()
89- if not cql2_filter :
90- return message
69+ app : ASGIApp
70+ cql2_filter : Expr
9171
72+ async def __call__ (self , scope : Scope , receive : Receive , send : Send ) -> None :
73+ """Augment the request body with a CQL2 filter."""
74+ body = b""
75+ more_body = True
76+
77+ # Read the body
78+ while more_body :
79+ message = await receive ()
80+ if message ["type" ] == "http.request" :
81+ body += message .get ("body" , b"" )
82+ more_body = message .get ("more_body" , False )
83+
84+ # Modify body
9285 try :
93- body = json .loads (message . get ( " body" , b"{}" ) )
86+ body = json .loads (body )
9487 except json .JSONDecodeError as e :
9588 logger .warning ("Failed to parse request body as JSON" )
9689 # TODO: Return a 400 error
9790 raise e
9891
99- new_body = filters .append_body_filter (body , cql2_filter )
100- message ["body" ] = json .dumps (new_body ).encode ("utf-8" )
101- return message
92+ # Augment the body
93+ assert isinstance (body , dict ), "Request body must be a JSON object"
94+ new_body = json .dumps (
95+ filters .append_body_filter (body , self .cql2_filter )
96+ ).encode ("utf-8" )
97+
98+ # Patch content-length in the headers
99+ headers = dict (scope ["headers" ])
100+ headers [b"content-length" ] = str (len (new_body )).encode ("latin1" )
101+ scope ["headers" ] = list (headers .items ())
102+
103+ async def new_receive ():
104+ return {
105+ "type" : "http.request" ,
106+ "body" : new_body ,
107+ "more_body" : False ,
108+ }
109+
110+ await self .app (scope , new_receive , send )
102111
103112
104113@dataclass
105114class Cql2ResponseBodyValidator :
106115 """Handler to validate response body with CQL2."""
107116
108- send : Send
117+ app : ASGIApp
109118 cql2_filter : Expr
110- initial_message : Optional [Message ] = field (init = False )
111- body : bytes = field (init = False , default_factory = bytes )
112119
113- async def __call__ (self , message : Message ) -> None :
120+ async def __call__ (self , scope : Scope , send : Send , receive : Receive ) -> None :
114121 """Process a response message and apply filtering if needed."""
115- if message ["type" ] == "http.response.start" :
116- self .initial_message = message
117- return
122+ if scope ["type" ] != "http" :
123+ return await self .app (scope , send , receive )
124+
125+ body = b""
126+ initial_message : Optional [Message ] = None
127+
128+ async def _send_error_response (status : int , message : str ) -> None :
129+ """Send an error response with the given status and message."""
130+ assert initial_message , "Initial message not set"
131+ error_body = json .dumps ({"message" : message }).encode ("utf-8" )
132+ headers = MutableHeaders (scope = initial_message )
133+ headers ["content-length" ] = str (len (error_body ))
134+ initial_message ["status" ] = status
135+ await send (initial_message )
136+ await send (
137+ {
138+ "type" : "http.response.body" ,
139+ "body" : error_body ,
140+ "more_body" : False ,
141+ }
142+ )
118143
119- if message ["type" ] == "http.response.body" :
120- assert self .initial_message , "Initial message not set"
144+ async def buffered_send (message : Message ) -> None :
145+ """Process a response message and apply filtering if needed."""
146+ nonlocal body
147+ nonlocal initial_message
121148
122- self .body += message ["body" ]
149+ if message ["type" ] == "http.response.start" :
150+ initial_message = message
151+ return
152+
153+ assert initial_message , "Initial message not set"
154+
155+ body += message ["body" ]
123156 if message .get ("more_body" ):
124157 return
125158
126159 try :
127- body_json = json .loads (self . body )
160+ body_json = json .loads (body )
128161 except json .JSONDecodeError :
129162 logger .warning ("Failed to parse response body as JSON" )
130- await self . _send_error_response (502 , "Not found" )
163+ await _send_error_response (502 , "Not found" )
131164 return
132165
133166 logger .debug (
134167 "Applying %s filter to %s" , self .cql2_filter .to_text (), body_json
135168 )
136169 if self .cql2_filter .matches (body_json ):
137- await self . send (self . initial_message )
138- return await self . send (
170+ await send (initial_message )
171+ return await send (
139172 {
140173 "type" : "http.response.body" ,
141174 "body" : json .dumps (body_json ).encode ("utf-8" ),
142175 "more_body" : False ,
143176 }
144177 )
145- return await self ._send_error_response (404 , "Not found" )
146-
147- async def _send_error_response (self , status : int , message : str ) -> None :
148- """Send an error response with the given status and message."""
149- assert self .initial_message , "Initial message not set"
150- error_body = json .dumps ({"message" : message }).encode ("utf-8" )
151- headers = MutableHeaders (scope = self .initial_message )
152- headers ["content-length" ] = str (len (error_body ))
153- self .initial_message ["status" ] = status
154- await self .send (self .initial_message )
155- await self .send (
156- {
157- "type" : "http.response.body" ,
158- "body" : error_body ,
159- "more_body" : False ,
160- }
161- )
178+ return await _send_error_response (404 , "Not found" )
179+
180+ return await self .app (scope , receive , buffered_send )
0 commit comments