11"""Middleware to apply CQL2 filters."""
22
33import json
4- from dataclasses import dataclass
4+ import re
5+ from dataclasses import dataclass , field
6+ from functools import partial
57from logging import getLogger
8+ from typing import Callable , Optional
69
10+ from cql2 import Expr
11+ from starlette .datastructures import MutableHeaders , State
712from starlette .requests import Request
813from starlette .types import ASGIApp , Message , Receive , Scope , Send
914
@@ -17,7 +22,6 @@ class ApplyCql2FilterMiddleware:
1722 """Middleware to apply the Cql2Filter to the request."""
1823
1924 app : ASGIApp
20-
2125 state_key : str = "cql2_filter"
2226
2327 async def __call__ (self , scope : Scope , receive : Receive , send : Send ) -> None :
@@ -27,34 +31,123 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
2731
2832 request = Request (scope )
2933
30- if request .method == "GET" :
31- cql2_filter = getattr (request .state , self .state_key , None )
32- if cql2_filter :
33- scope ["query_string" ] = filters .append_qs_filter (
34- request .url .query , cql2_filter
35- )
34+ get_cql2_filter : Callable [[], Optional [Expr ]] = partial (
35+ getattr , request .state , self .state_key , None
36+ )
37+
38+ # Handle POST, PUT, PATCH
39+ if request .method in ["POST" , "PUT" , "PATCH" ]:
40+ return await self .app (
41+ scope ,
42+ Cql2RequestBodyAugmentor (
43+ receive = receive ,
44+ state = request .state ,
45+ get_cql2_filter = get_cql2_filter ,
46+ ),
47+ send ,
48+ )
49+
50+ cql2_filter = get_cql2_filter ()
51+ if not cql2_filter :
3652 return await self .app (scope , receive , send )
3753
38- elif request .method in ["POST" , "PUT" , "PATCH" ]:
39-
40- async def receive_and_apply_filter () -> Message :
41- message = await receive ()
42- if message ["type" ] != "http.request" :
43- return message
44-
45- cql2_filter = getattr (request .state , self .state_key , None )
46- if cql2_filter :
47- try :
48- body = json .loads (message .get ("body" , b"{}" ))
49- except json .JSONDecodeError as e :
50- logger .warning ("Failed to parse request body as JSON" )
51- # TODO: Return a 400 error
52- raise e
54+ if re .match (r"^/collections/([^/]+)/items/([^/]+)$" , request .url .path ):
55+ return await self .app (
56+ scope ,
57+ receive ,
58+ Cql2ResponseBodyValidator (cql2_filter = cql2_filter , send = send ),
59+ )
5360
54- new_body = filters .append_body_filter (body , cql2_filter )
55- message ["body" ] = json .dumps (new_body ).encode ("utf-8" )
56- return message
61+ scope ["query_string" ] = filters .append_qs_filter (request .url .query , cql2_filter )
62+ return await self .app (scope , receive , send )
5763
58- return await self .app (scope , receive_and_apply_filter , send )
5964
60- return await self .app (scope , receive , send )
65+ @dataclass (frozen = True )
66+ class Cql2RequestBodyAugmentor :
67+ """Handler to augment the request body with a CQL2 filter."""
68+
69+ receive : Receive
70+ state : State
71+ get_cql2_filter : Callable [[], Optional [Expr ]]
72+
73+ async def __call__ (self ) -> Message :
74+ """Process a request body and augment with a CQL2 filter if available."""
75+ message = await self .receive ()
76+ if message ["type" ] != "http.request" :
77+ return message
78+
79+ # NOTE: Can only get cql2 filter _after_ calling self.receive()
80+ cql2_filter = self .get_cql2_filter ()
81+ if not cql2_filter :
82+ return message
83+
84+ try :
85+ body = json .loads (message .get ("body" , b"{}" ))
86+ except json .JSONDecodeError as e :
87+ logger .warning ("Failed to parse request body as JSON" )
88+ # TODO: Return a 400 error
89+ raise e
90+
91+ new_body = filters .append_body_filter (body , cql2_filter )
92+ message ["body" ] = json .dumps (new_body ).encode ("utf-8" )
93+ return message
94+
95+
96+ @dataclass
97+ class Cql2ResponseBodyValidator :
98+ """Handler to validate response body with CQL2."""
99+
100+ send : Send
101+ cql2_filter : Expr
102+ initial_message : Optional [Message ] = field (init = False )
103+ body : bytes = field (init = False , default_factory = bytes )
104+
105+ async def __call__ (self , message : Message ) -> None :
106+ """Process a response message and apply filtering if needed."""
107+ if message ["type" ] == "http.response.start" :
108+ self .initial_message = message
109+ return
110+
111+ if message ["type" ] == "http.response.body" :
112+ assert self .initial_message , "Initial message not set"
113+
114+ self .body += message ["body" ]
115+ if message .get ("more_body" ):
116+ return
117+
118+ try :
119+ body_json = json .loads (self .body )
120+ except json .JSONDecodeError :
121+ logger .warning ("Failed to parse response body as JSON" )
122+ await self ._send_error_response (502 , "Not found" )
123+ return
124+
125+ logger .debug (
126+ "Applying %s filter to %s" , self .cql2_filter .to_text (), body_json
127+ )
128+ if self .cql2_filter .matches (body_json ):
129+ await self .send (self .initial_message )
130+ return await self .send (
131+ {
132+ "type" : "http.response.body" ,
133+ "body" : json .dumps (body_json ).encode ("utf-8" ),
134+ "more_body" : False ,
135+ }
136+ )
137+ return await self ._send_error_response (404 , "Not found" )
138+
139+ async def _send_error_response (self , status : int , message : str ) -> None :
140+ """Send an error response with the given status and message."""
141+ assert self .initial_message , "Initial message not set"
142+ error_body = json .dumps ({"message" : message }).encode ("utf-8" )
143+ headers = MutableHeaders (scope = self .initial_message )
144+ headers ["content-length" ] = str (len (error_body ))
145+ self .initial_message ["status" ] = status
146+ await self .send (self .initial_message )
147+ await self .send (
148+ {
149+ "type" : "http.response.body" ,
150+ "body" : error_body ,
151+ "more_body" : False ,
152+ }
153+ )
0 commit comments