2121 r"http://www.opengis.net/spec/cql2/1.0/conf/basic-cql2" ,
2222 r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-text" ,
2323 r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-json" ,
24- r"http://www.opengis.net/spec/ogcapi-features-3/1.0/conf/features-filter" ,
25- r"https://api.stacspec.org/v1\.\d+\.\d+(?:-[\w\.]+)?/item-search#filter" ,
2624)
2725@dataclass (frozen = True )
2826class ApplyCql2FilterMiddleware :
@@ -31,6 +29,11 @@ class ApplyCql2FilterMiddleware:
3129 app : ASGIApp
3230 state_key : str = "cql2_filter"
3331
32+ single_record_endpoints = [
33+ r"^/collections/([^/]+)/items/([^/]+)$" ,
34+ r"^/collections/([^/]+)$" ,
35+ ]
36+
3437 async def __call__ (self , scope : Scope , receive : Receive , send : Send ) -> None :
3538 """Add the Cql2Filter to the request."""
3639 if scope ["type" ] != "http" :
@@ -52,11 +55,9 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
5255 return await req_body_handler (scope , receive , send )
5356
5457 # Handle single record requests (ie non-filterable endpoints)
55- single_record_endpoints = [
56- r"^/collections/([^/]+)/items/([^/]+)$" ,
57- r"^/collections/([^/]+)$" ,
58- ]
59- if any (re .match (expr , request .url .path ) for expr in single_record_endpoints ):
58+ if any (
59+ re .match (expr , request .url .path ) for expr in self .single_record_endpoints
60+ ):
6061 res_body_validator = Cql2ResponseBodyValidator (
6162 app = self .app ,
6263 cql2_filter = cql2_filter ,
@@ -130,18 +131,22 @@ async def __call__(self, scope: Scope, send: Send, receive: Receive) -> None:
130131 body = b""
131132 initial_message : Optional [Message ] = None
132133
133- async def _send_error_response (status : int , message : str ) -> None :
134+ async def _send_error_response (status : int , code : str , message : str ) -> None :
134135 """Send an error response with the given status and message."""
135136 assert initial_message , "Initial message not set"
136- error_body = json .dumps ({"message" : message }).encode ("utf-8" )
137+ response_dict = {
138+ "code" : code ,
139+ "description" : message ,
140+ }
141+ response_bytes = json .dumps (response_dict ).encode ("utf-8" )
137142 headers = MutableHeaders (scope = initial_message )
138- headers ["content-length" ] = str (len (error_body ))
143+ headers ["content-length" ] = str (len (response_bytes ))
139144 initial_message ["status" ] = status
140145 await send (initial_message )
141146 await send (
142147 {
143148 "type" : "http.response.body" ,
144- "body" : error_body ,
149+ "body" : response_bytes ,
145150 "more_body" : False ,
146151 }
147152 )
@@ -150,40 +155,48 @@ async def buffered_send(message: Message) -> None:
150155 """Process a response message and apply filtering if needed."""
151156 nonlocal body
152157 nonlocal initial_message
158+ initial_message = initial_message or message
159+ # NOTE: to avoid data-leak, we process 404s so their responses are the same as rejected 200s
160+ should_process = initial_message ["status" ] in [200 , 404 ]
161+
162+ if not should_process :
163+ return await send (message )
153164
154165 if message ["type" ] == "http.response.start" :
155- initial_message = message
166+ # Hold off on sending response headers until we've validated the response body
156167 return
157168
158- assert initial_message , "Initial message not set"
159-
160169 body += message ["body" ]
161170 if message .get ("more_body" ):
162171 return
163172
164173 try :
165174 body_json = json .loads (body )
166175 except json .JSONDecodeError :
167- logger .warning ("Failed to parse response body as JSON" )
168- await _send_error_response (502 , "Not found" )
176+ msg = "Failed to parse response body as JSON"
177+ logger .warning (msg )
178+ await _send_error_response (status = 502 , code = "ParseError" , message = msg )
169179 return
170180
171- logger .debug (
172- "Applying %s filter to %s" , self .cql2_filter .to_text (), body_json
173- )
174181 try :
175- if self .cql2_filter .matches (body_json ):
176- await send (initial_message )
177- return await send (
178- {
179- "type" : "http.response.body" ,
180- "body" : json .dumps (body_json ).encode ("utf-8" ),
181- "more_body" : False ,
182- }
183- )
182+ cql2_matches = self .cql2_filter .matches (body_json )
184183 except Exception as e :
184+ cql2_matches = False
185185 logger .warning ("Failed to apply filter: %s" , e )
186186
187- return await _send_error_response (404 , "Not found" )
187+ if cql2_matches :
188+ logger .debug ("Response matches filter, returning record" )
189+ await send (initial_message )
190+ return await send (
191+ {
192+ "type" : "http.response.body" ,
193+ "body" : json .dumps (body_json ).encode ("utf-8" ),
194+ "more_body" : False ,
195+ }
196+ )
197+ logger .debug ("Response did not match filter, returning 404" )
198+ return await _send_error_response (
199+ status = 404 , code = "NotFoundError" , message = "Record not found."
200+ )
188201
189202 return await self .app (scope , receive , buffered_send )
0 commit comments