1616import logging
1717import selectors
1818
19- from typing import Tuple , List , Union , Optional , Dict , Any
19+ from typing import Tuple , List , Type , Union , Optional , Dict , Any
2020
21- from .plugin import HttpProtocolHandlerPlugin
22- from .parser import HttpParser , httpParserStates , httpParserTypes
23- from .exception import HttpProtocolException
24-
25- from ..common .types import Readables , Writables
21+ from ..common .flag import flags
2622from ..common .utils import wrap_socket
2723from ..core .base import BaseTcpServerHandler
24+ from ..common .types import Readables , Writables
2825from ..core .connection import TcpClientConnection
29- from ..common .flag import flags
3026from ..common .constants import DEFAULT_CLIENT_RECVBUF_SIZE , DEFAULT_KEY_FILE
3127from ..common .constants import DEFAULT_SELECTOR_SELECT_TIMEOUT , DEFAULT_TIMEOUT
3228
29+ from .exception import HttpProtocolException
30+ from .plugin import HttpProtocolHandlerPlugin
31+ from .responses import BAD_REQUEST_RESPONSE_PKT
32+ from .parser import HttpParser , httpParserStates , httpParserTypes
33+
3334
3435logger = logging .getLogger (__name__ )
3536
@@ -78,31 +79,24 @@ def __init__(self, *args: Any, **kwargs: Any):
7879 self .selector : Optional [selectors .DefaultSelector ] = None
7980 if not self .flags .threadless :
8081 self .selector = selectors .DefaultSelector ()
81- self .plugins : Dict [ str , HttpProtocolHandlerPlugin ] = {}
82+ self .plugin : Optional [ HttpProtocolHandlerPlugin ] = None
8283
8384 ##
8485 # initialize, is_inactive, shutdown, get_events, handle_events
8586 # overrides Work class definitions.
8687 ##
8788
8889 def initialize (self ) -> None :
89- """Optionally upgrades connection to HTTPS, set ``conn`` in non-blocking mode and initializes plugins."""
90+ """Optionally upgrades connection to HTTPS,
91+ sets ``conn`` in non-blocking mode and initializes
92+ HTTP protocol plugins.
93+ """
9094 conn = self ._optionally_wrap_socket (self .work .connection )
9195 conn .setblocking (False )
9296 # Update client connection reference if connection was wrapped
9397 if self ._encryption_enabled ():
9498 self .work = TcpClientConnection (conn = conn , addr = self .work .addr )
95- if b'HttpProtocolHandlerPlugin' in self .flags .plugins :
96- for klass in self .flags .plugins [b'HttpProtocolHandlerPlugin' ]:
97- instance : HttpProtocolHandlerPlugin = klass (
98- self .uid ,
99- self .flags ,
100- self .work ,
101- self .request ,
102- self .event_queue ,
103- self .upstream_conn_pool ,
104- )
105- self .plugins [instance .name ()] = instance
99+ # self._initialize_plugins()
106100 logger .debug ('Handling connection %s' % self .work .address )
107101
108102 def is_inactive (self ) -> bool :
@@ -120,8 +114,8 @@ def shutdown(self) -> None:
120114 if self .selector and self .work .has_buffer ():
121115 self ._flush ()
122116 # Invoke plugin.on_client_connection_close
123- for plugin in self .plugins . values () :
124- plugin .on_client_connection_close ()
117+ if self .plugin :
118+ self . plugin .on_client_connection_close ()
125119 logger .debug (
126120 'Closing client connection %s has buffer %s' %
127121 (self .work .address , self .work .has_buffer ()),
@@ -153,8 +147,8 @@ async def get_events(self) -> Dict[int, int]:
153147 # Get default client events
154148 events : Dict [int , int ] = await super ().get_events ()
155149 # HttpProtocolHandlerPlugin.get_descriptors
156- for plugin in self .plugins . values () :
157- plugin_read_desc , plugin_write_desc = plugin .get_descriptors ()
150+ if self .plugin :
151+ plugin_read_desc , plugin_write_desc = self . plugin .get_descriptors ()
158152 for rfileno in plugin_read_desc :
159153 if rfileno not in events :
160154 events [rfileno ] = selectors .EVENT_READ
@@ -179,17 +173,17 @@ async def handle_events(
179173 if teardown :
180174 return True
181175 # Invoke plugin.write_to_descriptors
182- for plugin in self .plugins . values () :
183- teardown = await plugin .write_to_descriptors (writables )
176+ if self .plugin :
177+ teardown = await self . plugin .write_to_descriptors (writables )
184178 if teardown :
185179 return True
186180 # Read from ready to read sockets
187181 teardown = await self .handle_readables (readables )
188182 if teardown :
189183 return True
190184 # Invoke plugin.read_from_descriptors
191- for plugin in self .plugins . values () :
192- teardown = await plugin .read_from_descriptors (readables )
185+ if self .plugin :
186+ teardown = await self . plugin .read_from_descriptors (readables )
193187 if teardown :
194188 return True
195189 return False
@@ -209,33 +203,13 @@ def handle_data(self, data: memoryview) -> Optional[bool]:
209203 # apply custom logic to handle request data sent after 1st
210204 # valid request.
211205 if self .request .state != httpParserStates .COMPLETE :
212- # Parse http request
213- #
214- # TODO(abhinavsingh): Remove .tobytes after parser is
215- # memoryview compliant
216- self .request .parse (data .tobytes ())
217- if self .request .is_complete :
218- # Invoke plugin.on_request_complete
219- for plugin in self .plugins .values ():
220- upgraded_sock = plugin .on_request_complete ()
221- if isinstance (upgraded_sock , ssl .SSLSocket ):
222- logger .debug (
223- 'Updated client conn to %s' , upgraded_sock ,
224- )
225- self .work ._conn = upgraded_sock
226- for plugin_ in self .plugins .values ():
227- if plugin_ != plugin :
228- plugin_ .client ._conn = upgraded_sock
229- elif isinstance (upgraded_sock , bool ) and upgraded_sock is True :
230- return True
206+ if self ._parse_first_request (data ):
207+ return True
231208 else :
232209 # HttpProtocolHandlerPlugin.on_client_data
233210 # Can raise HttpProtocolException to tear down the connection
234- for plugin in self .plugins .values ():
235- optional_data = plugin .on_client_data (data )
236- if optional_data is None :
237- break
238- data = optional_data
211+ if self .plugin :
212+ data = self .plugin .on_client_data (data ) or data
239213 except HttpProtocolException as e :
240214 logger .info ('HttpProtocolException: %s' % e )
241215 response : Optional [memoryview ] = e .response (self .request )
@@ -248,17 +222,13 @@ async def handle_writables(self, writables: Writables) -> bool:
248222 if self .work .connection .fileno () in writables and self .work .has_buffer ():
249223 logger .debug ('Client is write ready, flushing...' )
250224 self .last_activity = time .time ()
251-
252225 # TODO(abhinavsingh): This hook could just reside within server recv block
253226 # instead of invoking when flushed to client.
254227 #
255228 # Invoke plugin.on_response_chunk
256229 chunk = self .work .buffer
257- for plugin in self .plugins .values ():
258- chunk = plugin .on_response_chunk (chunk )
259- if chunk is None :
260- break
261-
230+ if self .plugin :
231+ chunk = self .plugin .on_response_chunk (chunk )
262232 try :
263233 # Call super() for client flush
264234 teardown = await super ().handle_writables (writables )
@@ -305,6 +275,61 @@ async def handle_readables(self, readables: Readables) -> bool:
305275 # Internal methods
306276 ##
307277
278+ def _initialize_plugin (
279+ self ,
280+ klass : Type ['HttpProtocolHandlerPlugin' ],
281+ ) -> HttpProtocolHandlerPlugin :
282+ """Initializes passed HTTP protocol handler plugin class."""
283+ return klass (
284+ self .uid ,
285+ self .flags ,
286+ self .work ,
287+ self .request ,
288+ self .event_queue ,
289+ self .upstream_conn_pool ,
290+ )
291+
292+ def _discover_plugin_klass (self , protocol : int ) -> Optional [Type ['HttpProtocolHandlerPlugin' ]]:
293+ """Discovers and return matching HTTP handler plugin matching protocol."""
294+ if b'HttpProtocolHandlerPlugin' in self .flags .plugins :
295+ for klass in self .flags .plugins [b'HttpProtocolHandlerPlugin' ]:
296+ k : Type ['HttpProtocolHandlerPlugin' ] = klass
297+ if protocol in k .protocols ():
298+ return k
299+ return None
300+
301+ def _parse_first_request (self , data : memoryview ) -> bool :
302+ # Parse http request
303+ #
304+ # TODO(abhinavsingh): Remove .tobytes after parser is
305+ # memoryview compliant
306+ self .request .parse (data .tobytes ())
307+ if not self .request .is_complete :
308+ return False
309+ # Discover which HTTP handler plugin is capable of
310+ # handling the current incoming request
311+ klass = self ._discover_plugin_klass (
312+ self .request .http_handler_protocol ,
313+ )
314+ if klass is None :
315+ # No matching protocol class found.
316+ # Return bad request response and
317+ # close the connection.
318+ self .work .queue (BAD_REQUEST_RESPONSE_PKT )
319+ return True
320+ assert klass is not None
321+ self .plugin = self ._initialize_plugin (klass )
322+ # Invoke plugin.on_request_complete
323+ output = self .plugin .on_request_complete ()
324+ if isinstance (output , bool ):
325+ return output
326+ assert isinstance (output , ssl .SSLSocket )
327+ logger .debug (
328+ 'Updated client conn to %s' , output ,
329+ )
330+ self .work ._conn = output
331+ return False
332+
308333 def _encryption_enabled (self ) -> bool :
309334 return self .flags .keyfile is not None and \
310335 self .flags .certfile is not None
0 commit comments