22
33import json
44import re
5- from dataclasses import dataclass , field
5+ from dataclasses import dataclass
66from logging import getLogger
77from typing import Optional
88
99from cql2 import Expr
10- from starlette .datastructures import MutableHeaders , State
10+ from starlette .datastructures import MutableHeaders
1111from starlette .requests import Request
1212from starlette .types import ASGIApp , Message , Receive , Scope , Send
1313
@@ -45,60 +45,18 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
4545
4646 # Handle POST, PUT, PATCH
4747 if request .method in ["POST" , "PUT" , "PATCH" ]:
48- body = b""
49- more_body = True
50- receive_ = receive
51-
52- async def buffered_receive ():
53- nonlocal body , more_body
54- if more_body :
55- message = await receive_ ()
56- if message ["type" ] == "http.request" :
57- body += message .get ("body" , b"" )
58- more_body = message .get ("more_body" , False )
59- return message
60- return {"type" : "http.request" , "body" : b"" , "more_body" : False }
61-
62- while more_body :
63- await buffered_receive ()
64-
65- # Modify body
66- # modified_body = body + b"\nAppended content."
67- try :
68- body = json .loads (body )
69- except json .JSONDecodeError as e :
70- logger .warning ("Failed to parse request body as JSON" )
71- # TODO: Return a 400 error
72- raise e
73-
74- new_body = json .dumps (filters .append_body_filter (body , cql2_filter )).encode (
75- "utf-8"
48+ req_body_handler = Cql2RequestBodyAugmentor (
49+ app = self .app ,
50+ cql2_filter = cql2_filter ,
7651 )
77-
78- # Override the receive function with a generator sending our modified body
79- async def new_receive ():
80- nonlocal new_body
81- chunk = new_body
82- new_body = b""
83- return {
84- "type" : "http.request" ,
85- "body" : chunk ,
86- "more_body" : False ,
87- }
88-
89- # Patch content-length in the headers
90- headers = dict (scope ["headers" ])
91- headers [b"content-length" ] = str (len (new_body )).encode ("latin1" )
92- scope ["headers" ] = list (headers .items ())
93-
94- await self .app (scope , new_receive , send )
52+ return await req_body_handler (scope , receive , send )
9553
9654 if re .match (r"^/collections/([^/]+)/items/([^/]+)$" , request .url .path ):
97- return await self .app (
98- scope ,
99- receive ,
100- Cql2ResponseBodyValidator (cql2_filter = cql2_filter , send = send ),
55+ res_body_validator = Cql2ResponseBodyValidator (
56+ app = self .app ,
57+ cql2_filter = cql2_filter ,
10158 )
59+ return await res_body_validator (scope , send , receive )
10260
10361 scope ["query_string" ] = filters .append_qs_filter (request .url .query , cql2_filter )
10462 return await self .app (scope , receive , send )
@@ -108,85 +66,115 @@ async def new_receive():
10866class Cql2RequestBodyAugmentor :
10967 """Handler to augment the request body with a CQL2 filter."""
11068
111- receive : Receive
112- state : State
69+ app : ASGIApp
11370 cql2_filter : Expr
114- initial_message : Optional [Message ] = field (init = False )
115-
116- async def __call__ (self ) -> Message :
117- """Process a request body and augment with a CQL2 filter if available."""
118- message = await self .receive ()
119- if message ["type" ] == "http.response.start" :
120- self .initial_message = message
121- return
122-
123- if not message .get ("body" ):
124- return message
12571
72+ async def __call__ (self , scope : Scope , receive : Receive , send : Send ) -> None :
73+ """Augment the request body with a CQL2 filter."""
74+ body = b""
75+ more_body = True
76+
77+ # Read the body
78+ while more_body :
79+ message = await receive ()
80+ if message ["type" ] == "http.request" :
81+ body += message .get ("body" , b"" )
82+ more_body = message .get ("more_body" , False )
83+
84+ # Modify body
12685 try :
127- body = json .loads (message . get ( " body" ) )
86+ body = json .loads (body )
12887 except json .JSONDecodeError as e :
12988 logger .warning ("Failed to parse request body as JSON" )
13089 # TODO: Return a 400 error
13190 raise e
13291
133- new_body = filters .append_body_filter (body , self .cql2_filter )
134- message ["body" ] = json .dumps (new_body ).encode ("utf-8" )
135- return message
92+ # Augment the body
93+ assert isinstance (body , dict ), "Request body must be a JSON object"
94+ new_body = json .dumps (
95+ filters .append_body_filter (body , self .cql2_filter )
96+ ).encode ("utf-8" )
97+
98+ # Patch content-length in the headers
99+ headers = dict (scope ["headers" ])
100+ headers [b"content-length" ] = str (len (new_body )).encode ("latin1" )
101+ scope ["headers" ] = list (headers .items ())
102+
103+ async def new_receive ():
104+ return {
105+ "type" : "http.request" ,
106+ "body" : new_body ,
107+ "more_body" : False ,
108+ }
109+
110+ await self .app (scope , new_receive , send )
136111
137112
138113@dataclass
139114class Cql2ResponseBodyValidator :
140115 """Handler to validate response body with CQL2."""
141116
142- send : Send
117+ app : ASGIApp
143118 cql2_filter : Expr
144- initial_message : Optional [Message ] = field (init = False )
145- body : bytes = field (init = False , default_factory = bytes )
146119
147- async def __call__ (self , message : Message ) -> None :
120+ async def __call__ (self , scope : Scope , send : Send , receive : Receive ) -> None :
148121 """Process a response message and apply filtering if needed."""
149- if message ["type" ] == "http.response.start" :
150- self .initial_message = message
151- return
152-
153- assert self .initial_message , "Initial message not set"
154-
155- self .body += message ["body" ]
156- if message .get ("more_body" ):
157- return
158-
159- try :
160- body_json = json .loads (self .body )
161- except json .JSONDecodeError :
162- logger .warning ("Failed to parse response body as JSON" )
163- await self ._send_error_response (502 , "Not found" )
164- return
165-
166- logger .debug ("Applying %s filter to %s" , self .cql2_filter .to_text (), body_json )
167- if self .cql2_filter .matches (body_json ):
168- await self .send (self .initial_message )
169- return await self .send (
122+ if scope ["type" ] != "http" :
123+ return await self .app (scope , send , receive )
124+
125+ body = b""
126+ initial_message : Optional [Message ] = None
127+
128+ async def _send_error_response (status : int , message : str ) -> None :
129+ """Send an error response with the given status and message."""
130+ assert initial_message , "Initial message not set"
131+ error_body = json .dumps ({"message" : message }).encode ("utf-8" )
132+ headers = MutableHeaders (scope = initial_message )
133+ headers ["content-length" ] = str (len (error_body ))
134+ initial_message ["status" ] = status
135+ await send (initial_message )
136+ await send (
170137 {
171138 "type" : "http.response.body" ,
172- "body" : json . dumps ( body_json ). encode ( "utf-8" ) ,
139+ "body" : error_body ,
173140 "more_body" : False ,
174141 }
175142 )
176- return await self ._send_error_response (404 , "Not found" )
177-
178- async def _send_error_response (self , status : int , message : str ) -> None :
179- """Send an error response with the given status and message."""
180- assert self .initial_message , "Initial message not set"
181- error_body = json .dumps ({"message" : message }).encode ("utf-8" )
182- headers = MutableHeaders (scope = self .initial_message )
183- headers ["content-length" ] = str (len (error_body ))
184- self .initial_message ["status" ] = status
185- await self .send (self .initial_message )
186- await self .send (
187- {
188- "type" : "http.response.body" ,
189- "body" : error_body ,
190- "more_body" : False ,
191- }
192- )
143+
144+ async def buffered_send (message : Message ) -> None :
145+ """Process a response message and apply filtering if needed."""
146+ nonlocal body
147+ nonlocal initial_message
148+
149+ if message ["type" ] == "http.response.start" :
150+ initial_message = message
151+ return
152+
153+ assert initial_message , "Initial message not set"
154+
155+ body += message ["body" ]
156+ if message .get ("more_body" ):
157+ return
158+
159+ try :
160+ body_json = json .loads (body )
161+ except json .JSONDecodeError :
162+ logger .warning ("Failed to parse response body as JSON" )
163+ await _send_error_response (502 , "Not found" )
164+ return
165+
166+ logger .debug (
167+ "Applying %s filter to %s" , self .cql2_filter .to_text (), body_json
168+ )
169+ if self .cql2_filter .matches (body_json ):
170+ await send (initial_message )
171+ return await send (
172+ {
173+ "type" : "http.response.body" ,
174+ "body" : json .dumps (body_json ).encode ("utf-8" ),
175+ "more_body" : False ,
176+ }
177+ )
178+ return await _send_error_response (404 , "Not found" )
179+
180+ return await self .app (scope , receive , buffered_send )
0 commit comments