33import json
44import re
55from dataclasses import dataclass , field
6- from functools import partial
76from logging import getLogger
8- from typing import Callable , Optional
7+ from typing import Optional
98
109from cql2 import Expr
1110from starlette .datastructures import MutableHeaders , State
@@ -39,25 +38,60 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
3938
4039 request = Request (scope )
4140
42- get_cql2_filter : Callable [[], Optional [Expr ]] = partial (
43- getattr , request .state , self .state_key , None
44- )
41+ cql2_filter : Optional [Expr ] = getattr (request .state , self .state_key , None )
42+
43+ if not cql2_filter :
44+ return await self .app (scope , receive , send )
4545
4646 # Handle POST, PUT, PATCH
4747 if request .method in ["POST" , "PUT" , "PATCH" ]:
48- return await self .app (
49- scope ,
50- Cql2RequestBodyAugmentor (
51- receive = receive ,
52- state = request .state ,
53- get_cql2_filter = get_cql2_filter ,
54- ),
55- send ,
48+ body = b""
49+ more_body = True
50+ receive_ = receive
51+
52+ async def buffered_receive ():
53+ nonlocal body , more_body
54+ if more_body :
55+ message = await receive_ ()
56+ if message ["type" ] == "http.request" :
57+ body += message .get ("body" , b"" )
58+ more_body = message .get ("more_body" , False )
59+ return message
60+ return {"type" : "http.request" , "body" : b"" , "more_body" : False }
61+
62+ while more_body :
63+ await buffered_receive ()
64+
65+ # Modify body
66+ # modified_body = body + b"\nAppended content."
67+ try :
68+ body = json .loads (body )
69+ except json .JSONDecodeError as e :
70+ logger .warning ("Failed to parse request body as JSON" )
71+ # TODO: Return a 400 error
72+ raise e
73+
74+ new_body = json .dumps (filters .append_body_filter (body , cql2_filter )).encode (
75+ "utf-8"
5676 )
5777
58- cql2_filter = get_cql2_filter ()
59- if not cql2_filter :
60- return await self .app (scope , receive , send )
78+ # Override the receive function with a generator sending our modified body
79+ async def new_receive ():
80+ nonlocal new_body
81+ chunk = new_body
82+ new_body = b""
83+ return {
84+ "type" : "http.request" ,
85+ "body" : chunk ,
86+ "more_body" : False ,
87+ }
88+
89+ # Patch content-length in the headers
90+ headers = dict (scope ["headers" ])
91+ headers [b"content-length" ] = str (len (new_body )).encode ("latin1" )
92+ scope ["headers" ] = list (headers .items ())
93+
94+ await self .app (scope , new_receive , send )
6195
6296 if re .match (r"^/collections/([^/]+)/items/([^/]+)$" , request .url .path ):
6397 return await self .app (
@@ -76,27 +110,27 @@ class Cql2RequestBodyAugmentor:
76110
77111 receive : Receive
78112 state : State
79- get_cql2_filter : Callable [[], Optional [Expr ]]
113+ cql2_filter : Expr
114+ initial_message : Optional [Message ] = field (init = False )
80115
81116 async def __call__ (self ) -> Message :
82117 """Process a request body and augment with a CQL2 filter if available."""
83118 message = await self .receive ()
84- if message ["type" ] != "http.request" :
85- return message
119+ if message ["type" ] == "http.response.start" :
120+ self .initial_message = message
121+ return
86122
87- # NOTE: Can only get cql2 filter _after_ calling self.receive()
88- cql2_filter = self .get_cql2_filter ()
89- if not cql2_filter :
123+ if not message .get ("body" ):
90124 return message
91125
92126 try :
93- body = json .loads (message .get ("body" , b"{}" ))
127+ body = json .loads (message .get ("body" ))
94128 except json .JSONDecodeError as e :
95129 logger .warning ("Failed to parse request body as JSON" )
96130 # TODO: Return a 400 error
97131 raise e
98132
99- new_body = filters .append_body_filter (body , cql2_filter )
133+ new_body = filters .append_body_filter (body , self . cql2_filter )
100134 message ["body" ] = json .dumps (new_body ).encode ("utf-8" )
101135 return message
102136
@@ -116,33 +150,30 @@ async def __call__(self, message: Message) -> None:
116150 self .initial_message = message
117151 return
118152
119- if message ["type" ] == "http.response.body" :
120- assert self .initial_message , "Initial message not set"
153+ assert self .initial_message , "Initial message not set"
121154
122- self .body += message ["body" ]
123- if message .get ("more_body" ):
124- return
155+ self .body += message ["body" ]
156+ if message .get ("more_body" ):
157+ return
125158
126- try :
127- body_json = json .loads (self .body )
128- except json .JSONDecodeError :
129- logger .warning ("Failed to parse response body as JSON" )
130- await self ._send_error_response (502 , "Not found" )
131- return
132-
133- logger .debug (
134- "Applying %s filter to %s" , self .cql2_filter .to_text (), body_json
159+ try :
160+ body_json = json .loads (self .body )
161+ except json .JSONDecodeError :
162+ logger .warning ("Failed to parse response body as JSON" )
163+ await self ._send_error_response (502 , "Not found" )
164+ return
165+
166+ logger .debug ("Applying %s filter to %s" , self .cql2_filter .to_text (), body_json )
167+ if self .cql2_filter .matches (body_json ):
168+ await self .send (self .initial_message )
169+ return await self .send (
170+ {
171+ "type" : "http.response.body" ,
172+ "body" : json .dumps (body_json ).encode ("utf-8" ),
173+ "more_body" : False ,
174+ }
135175 )
136- if self .cql2_filter .matches (body_json ):
137- await self .send (self .initial_message )
138- return await self .send (
139- {
140- "type" : "http.response.body" ,
141- "body" : json .dumps (body_json ).encode ("utf-8" ),
142- "more_body" : False ,
143- }
144- )
145- return await self ._send_error_response (404 , "Not found" )
176+ return await self ._send_error_response (404 , "Not found" )
146177
147178 async def _send_error_response (self , status : int , message : str ) -> None :
148179 """Send an error response with the given status and message."""
0 commit comments