1616from proxy .http import Url
1717from proxy .core .base import TcpUpstreamConnectionHandler
1818from proxy .http .parser import HttpParser
19- from proxy .http .server import HttpWebServerBasePlugin , httpProtocolTypes
19+ from proxy .http .server import HttpWebServerBasePlugin
2020from proxy .common .utils import text_
2121from proxy .http .exception import HttpProtocolException
2222from proxy .common .constants import (
2323 HTTPS_PROTO , DEFAULT_HTTP_PORT , DEFAULT_HTTPS_PORT ,
2424 DEFAULT_REVERSE_PROXY_ACCESS_LOG_FORMAT ,
2525)
26+ from ...common .types import Readables , Writables , Descriptors
2627
2728
2829if TYPE_CHECKING : # pragma: no cover
@@ -44,6 +45,11 @@ def __init__(self, *args: Any, **kwargs: Any):
4445 self .uid , self .flags , self .client , self .event_queue , self .upstream_conn_pool ,
4546 )
4647 self .plugins .append (plugin )
48+ self ._upstream_proxy_pass : Optional [str ] = None
49+
50+ def do_upgrade (self , request : HttpParser ) -> bool :
51+ """Signal web protocol handler to not upgrade websocket requests by default."""
52+ return False
4753
4854 def handle_upstream_data (self , raw : memoryview ) -> None :
4955 # TODO: Parse response and implement plugin hook per parsed response object
@@ -54,8 +60,8 @@ def routes(self) -> List[Tuple[int, str]]:
5460 r = []
5561 for plugin in self .plugins :
5662 for route in plugin .regexes ():
57- r . append (( httpProtocolTypes . HTTP , route ))
58- r .append ((httpProtocolTypes . HTTPS , route ))
63+ for proto in plugin . protocols ():
64+ r .append ((proto , route ))
5965 return r
6066
6167 def handle_request (self , request : HttpParser ) -> None :
@@ -66,59 +72,123 @@ def handle_request(self, request: HttpParser) -> None:
6672 raise HttpProtocolException ('before_routing closed connection' )
6773 request = r
6874
75+ needs_upstream = False
76+
6977 # routes
7078 for plugin in self .plugins :
7179 for route in plugin .routes ():
80+ # Static routes
7281 if isinstance (route , tuple ):
7382 pattern = re .compile (route [0 ])
7483 if pattern .match (text_ (request .path )):
7584 self .choice = Url .from_bytes (
7685 random .choice (route [1 ]),
7786 )
7887 break
88+ # Dynamic routes
7989 elif isinstance (route , str ):
8090 pattern = re .compile (route )
8191 if pattern .match (text_ (request .path )):
82- self .choice = plugin .handle_route (request , pattern )
92+ choice = plugin .handle_route (request , pattern )
93+ if isinstance (choice , Url ):
94+ self .choice = choice
95+ needs_upstream = True
96+ self ._upstream_proxy_pass = str (self .choice )
97+ elif isinstance (choice , memoryview ):
98+ self .client .queue (choice )
99+ self ._upstream_proxy_pass = '{0} bytes' .format (len (choice ))
100+ else :
101+ self .upstream = choice
102+ self ._upstream_proxy_pass = '{0}:{1}' .format (
103+ * self .upstream .addr ,
104+ )
83105 break
84106 else :
85107 raise ValueError ('Invalid route' )
86108
87- assert self .choice and self .choice .hostname
88- port = self .choice .port or \
89- DEFAULT_HTTP_PORT \
90- if self .choice .scheme == b'http' \
91- else DEFAULT_HTTPS_PORT
92- self .initialize_upstream (text_ (self .choice .hostname ), port )
93- assert self .upstream
94- try :
95- self .upstream .connect ()
96- if self .choice .scheme == HTTPS_PROTO :
97- self .upstream .wrap (
98- text_ (
99- self .choice .hostname ,
109+ if needs_upstream :
110+ assert self .choice and self .choice .hostname
111+ port = (
112+ self .choice .port or DEFAULT_HTTP_PORT
113+ if self .choice .scheme == b'http'
114+ else DEFAULT_HTTPS_PORT
115+ )
116+ self .initialize_upstream (text_ (self .choice .hostname ), port )
117+ assert self .upstream
118+ try :
119+ self .upstream .connect ()
120+ if self .choice .scheme == HTTPS_PROTO :
121+ self .upstream .wrap (
122+ text_ (
123+ self .choice .hostname ,
124+ ),
125+ as_non_blocking = True ,
126+ ca_file = self .flags .ca_file ,
127+ )
128+ request .path = self .choice .remainder
129+ self .upstream .queue (memoryview (request .build ()))
130+ except ConnectionRefusedError :
131+ raise HttpProtocolException ( # pragma: no cover
132+ 'Connection refused by upstream server {0}:{1}' .format (
133+ text_ (self .choice .hostname ),
134+ port ,
100135 ),
101- as_non_blocking = True ,
102- ca_file = self .flags .ca_file ,
103136 )
104- request .path = self .choice .remainder
105- self .upstream .queue (memoryview (request .build ()))
106- except ConnectionRefusedError :
107- raise HttpProtocolException ( # pragma: no cover
108- 'Connection refused by upstream server {0}:{1}' .format (
109- text_ (self .choice .hostname ), port ,
110- ),
111- )
112137
113138 def on_client_connection_close (self ) -> None :
114139 if self .upstream and not self .upstream .closed :
115140 logger .debug ('Closing upstream server connection' )
116141 self .upstream .close ()
117142 self .upstream = None
118143
144+ def on_client_data (
145+ self ,
146+ request : HttpParser ,
147+ raw : memoryview ,
148+ ) -> Optional [memoryview ]:
149+ if request .is_websocket_upgrade :
150+ assert self .upstream
151+ self .upstream .queue (raw )
152+ return raw
153+
119154 def on_access_log (self , context : Dict [str , Any ]) -> Optional [Dict [str , Any ]]:
120- context .update ({
121- 'upstream_proxy_pass' : str (self .choice ) if self .choice else None ,
122- })
123- logger .info (DEFAULT_REVERSE_PROXY_ACCESS_LOG_FORMAT .format_map (context ))
155+ context .update (
156+ {
157+ 'upstream_proxy_pass' : self ._upstream_proxy_pass ,
158+ },
159+ )
160+ log_handled = False
161+ for plugin in self .plugins :
162+ ctx = plugin .on_access_log (context )
163+ if ctx is None :
164+ log_handled = True
165+ break
166+ context = ctx
167+ if not log_handled :
168+ logger .info (DEFAULT_REVERSE_PROXY_ACCESS_LOG_FORMAT .format_map (context ))
124169 return None
170+
171+ async def get_descriptors (self ) -> Descriptors :
172+ r , w = await super ().get_descriptors ()
173+ # TODO(abhinavsingh): We need to keep a mapping of plugin and
174+ # descriptors registered by them, so that within write/read blocks
175+ # we can invoke the right plugin callbacks.
176+ for plugin in self .plugins :
177+ plugin_read_desc , plugin_write_desc = await plugin .get_descriptors ()
178+ r .extend (plugin_read_desc )
179+ w .extend (plugin_write_desc )
180+ return r , w
181+
182+ async def read_from_descriptors (self , r : Readables ) -> bool :
183+ for plugin in self .plugins :
184+ teardown = await plugin .read_from_descriptors (r )
185+ if teardown :
186+ return True
187+ return await super ().read_from_descriptors (r )
188+
189+ async def write_to_descriptors (self , w : Writables ) -> bool :
190+ for plugin in self .plugins :
191+ teardown = await plugin .write_to_descriptors (w )
192+ if teardown :
193+ return True
194+ return await super ().write_to_descriptors (w )
0 commit comments