Skip to content

Commit d22d551

Browse files
Optimize how HttpProtocolHandler delegates to the core plugins (#925)
* Add `protocols` abstract static method to `HttpProtocolHandlerBase` which defines which HTTP specification is followed by the core plugin
* lint
* Fix tests
* Lint fixes
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 498a1bb commit d22d551

File tree

13 files changed

+203
-129
lines changed

13 files changed

+203
-129
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2232,7 +2232,7 @@ usage: -m [-h] [--enable-events] [--enable-conn-pool] [--threadless]
22322232
[--filtered-url-regex-config FILTERED_URL_REGEX_CONFIG]
22332233
[--cloudflare-dns-mode CLOUDFLARE_DNS_MODE]
22342234

2235-
proxy.py v2.4.0rc5.dev11+ga872675.d20211225
2235+
proxy.py v2.4.0rc5.dev26+gb2b1bdc.d20211230
22362236

22372237
options:
22382238
-h, --help show this help message and exit

docs/conf.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,11 @@
219219
# -- Options for linkcheck builder -------------------------------------------
220220

221221
linkcheck_ignore = [
222-
r'http://localhost:\d+/', # local URLs
222+
# local URLs
223+
r'http://localhost:\d+/',
224+
# GHA sees "403 Client Error: Forbidden for url:"
225+
# while the URL actually works
226+
r'https://developers.cloudflare.com/',
223227
]
224228
linkcheck_workers = 25
225229

proxy/http/handler.py

Lines changed: 83 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,21 @@
1616
import logging
1717
import selectors
1818

19-
from typing import Tuple, List, Union, Optional, Dict, Any
19+
from typing import Tuple, List, Type, Union, Optional, Dict, Any
2020

21-
from .plugin import HttpProtocolHandlerPlugin
22-
from .parser import HttpParser, httpParserStates, httpParserTypes
23-
from .exception import HttpProtocolException
24-
25-
from ..common.types import Readables, Writables
21+
from ..common.flag import flags
2622
from ..common.utils import wrap_socket
2723
from ..core.base import BaseTcpServerHandler
24+
from ..common.types import Readables, Writables
2825
from ..core.connection import TcpClientConnection
29-
from ..common.flag import flags
3026
from ..common.constants import DEFAULT_CLIENT_RECVBUF_SIZE, DEFAULT_KEY_FILE
3127
from ..common.constants import DEFAULT_SELECTOR_SELECT_TIMEOUT, DEFAULT_TIMEOUT
3228

29+
from .exception import HttpProtocolException
30+
from .plugin import HttpProtocolHandlerPlugin
31+
from .responses import BAD_REQUEST_RESPONSE_PKT
32+
from .parser import HttpParser, httpParserStates, httpParserTypes
33+
3334

3435
logger = logging.getLogger(__name__)
3536

@@ -78,31 +79,24 @@ def __init__(self, *args: Any, **kwargs: Any):
7879
self.selector: Optional[selectors.DefaultSelector] = None
7980
if not self.flags.threadless:
8081
self.selector = selectors.DefaultSelector()
81-
self.plugins: Dict[str, HttpProtocolHandlerPlugin] = {}
82+
self.plugin: Optional[HttpProtocolHandlerPlugin] = None
8283

8384
##
8485
# initialize, is_inactive, shutdown, get_events, handle_events
8586
# overrides Work class definitions.
8687
##
8788

8889
def initialize(self) -> None:
89-
"""Optionally upgrades connection to HTTPS, set ``conn`` in non-blocking mode and initializes plugins."""
90+
"""Optionally upgrades connection to HTTPS,
91+
sets ``conn`` in non-blocking mode and initializes
92+
HTTP protocol plugins.
93+
"""
9094
conn = self._optionally_wrap_socket(self.work.connection)
9195
conn.setblocking(False)
9296
# Update client connection reference if connection was wrapped
9397
if self._encryption_enabled():
9498
self.work = TcpClientConnection(conn=conn, addr=self.work.addr)
95-
if b'HttpProtocolHandlerPlugin' in self.flags.plugins:
96-
for klass in self.flags.plugins[b'HttpProtocolHandlerPlugin']:
97-
instance: HttpProtocolHandlerPlugin = klass(
98-
self.uid,
99-
self.flags,
100-
self.work,
101-
self.request,
102-
self.event_queue,
103-
self.upstream_conn_pool,
104-
)
105-
self.plugins[instance.name()] = instance
99+
# self._initialize_plugins()
106100
logger.debug('Handling connection %s' % self.work.address)
107101

108102
def is_inactive(self) -> bool:
@@ -120,8 +114,8 @@ def shutdown(self) -> None:
120114
if self.selector and self.work.has_buffer():
121115
self._flush()
122116
# Invoke plugin.on_client_connection_close
123-
for plugin in self.plugins.values():
124-
plugin.on_client_connection_close()
117+
if self.plugin:
118+
self.plugin.on_client_connection_close()
125119
logger.debug(
126120
'Closing client connection %s has buffer %s' %
127121
(self.work.address, self.work.has_buffer()),
@@ -153,8 +147,8 @@ async def get_events(self) -> Dict[int, int]:
153147
# Get default client events
154148
events: Dict[int, int] = await super().get_events()
155149
# HttpProtocolHandlerPlugin.get_descriptors
156-
for plugin in self.plugins.values():
157-
plugin_read_desc, plugin_write_desc = plugin.get_descriptors()
150+
if self.plugin:
151+
plugin_read_desc, plugin_write_desc = self.plugin.get_descriptors()
158152
for rfileno in plugin_read_desc:
159153
if rfileno not in events:
160154
events[rfileno] = selectors.EVENT_READ
@@ -179,17 +173,17 @@ async def handle_events(
179173
if teardown:
180174
return True
181175
# Invoke plugin.write_to_descriptors
182-
for plugin in self.plugins.values():
183-
teardown = await plugin.write_to_descriptors(writables)
176+
if self.plugin:
177+
teardown = await self.plugin.write_to_descriptors(writables)
184178
if teardown:
185179
return True
186180
# Read from ready to read sockets
187181
teardown = await self.handle_readables(readables)
188182
if teardown:
189183
return True
190184
# Invoke plugin.read_from_descriptors
191-
for plugin in self.plugins.values():
192-
teardown = await plugin.read_from_descriptors(readables)
185+
if self.plugin:
186+
teardown = await self.plugin.read_from_descriptors(readables)
193187
if teardown:
194188
return True
195189
return False
@@ -209,33 +203,13 @@ def handle_data(self, data: memoryview) -> Optional[bool]:
209203
# apply custom logic to handle request data sent after 1st
210204
# valid request.
211205
if self.request.state != httpParserStates.COMPLETE:
212-
# Parse http request
213-
#
214-
# TODO(abhinavsingh): Remove .tobytes after parser is
215-
# memoryview compliant
216-
self.request.parse(data.tobytes())
217-
if self.request.is_complete:
218-
# Invoke plugin.on_request_complete
219-
for plugin in self.plugins.values():
220-
upgraded_sock = plugin.on_request_complete()
221-
if isinstance(upgraded_sock, ssl.SSLSocket):
222-
logger.debug(
223-
'Updated client conn to %s', upgraded_sock,
224-
)
225-
self.work._conn = upgraded_sock
226-
for plugin_ in self.plugins.values():
227-
if plugin_ != plugin:
228-
plugin_.client._conn = upgraded_sock
229-
elif isinstance(upgraded_sock, bool) and upgraded_sock is True:
230-
return True
206+
if self._parse_first_request(data):
207+
return True
231208
else:
232209
# HttpProtocolHandlerPlugin.on_client_data
233210
# Can raise HttpProtocolException to tear down the connection
234-
for plugin in self.plugins.values():
235-
optional_data = plugin.on_client_data(data)
236-
if optional_data is None:
237-
break
238-
data = optional_data
211+
if self.plugin:
212+
data = self.plugin.on_client_data(data) or data
239213
except HttpProtocolException as e:
240214
logger.info('HttpProtocolException: %s' % e)
241215
response: Optional[memoryview] = e.response(self.request)
@@ -248,17 +222,13 @@ async def handle_writables(self, writables: Writables) -> bool:
248222
if self.work.connection.fileno() in writables and self.work.has_buffer():
249223
logger.debug('Client is write ready, flushing...')
250224
self.last_activity = time.time()
251-
252225
# TODO(abhinavsingh): This hook could just reside within server recv block
253226
# instead of invoking when flushed to client.
254227
#
255228
# Invoke plugin.on_response_chunk
256229
chunk = self.work.buffer
257-
for plugin in self.plugins.values():
258-
chunk = plugin.on_response_chunk(chunk)
259-
if chunk is None:
260-
break
261-
230+
if self.plugin:
231+
chunk = self.plugin.on_response_chunk(chunk)
262232
try:
263233
# Call super() for client flush
264234
teardown = await super().handle_writables(writables)
@@ -305,6 +275,61 @@ async def handle_readables(self, readables: Readables) -> bool:
305275
# Internal methods
306276
##
307277

278+
def _initialize_plugin(
279+
self,
280+
klass: Type['HttpProtocolHandlerPlugin'],
281+
) -> HttpProtocolHandlerPlugin:
282+
"""Initializes passed HTTP protocol handler plugin class."""
283+
return klass(
284+
self.uid,
285+
self.flags,
286+
self.work,
287+
self.request,
288+
self.event_queue,
289+
self.upstream_conn_pool,
290+
)
291+
292+
def _discover_plugin_klass(self, protocol: int) -> Optional[Type['HttpProtocolHandlerPlugin']]:
293+
"""Discovers and returns the HTTP handler plugin class matching the given protocol."""
294+
if b'HttpProtocolHandlerPlugin' in self.flags.plugins:
295+
for klass in self.flags.plugins[b'HttpProtocolHandlerPlugin']:
296+
k: Type['HttpProtocolHandlerPlugin'] = klass
297+
if protocol in k.protocols():
298+
return k
299+
return None
300+
301+
def _parse_first_request(self, data: memoryview) -> bool:
302+
# Parse http request
303+
#
304+
# TODO(abhinavsingh): Remove .tobytes after parser is
305+
# memoryview compliant
306+
self.request.parse(data.tobytes())
307+
if not self.request.is_complete:
308+
return False
309+
# Discover which HTTP handler plugin is capable of
310+
# handling the current incoming request
311+
klass = self._discover_plugin_klass(
312+
self.request.http_handler_protocol,
313+
)
314+
if klass is None:
315+
# No matching protocol class found.
316+
# Return bad request response and
317+
# close the connection.
318+
self.work.queue(BAD_REQUEST_RESPONSE_PKT)
319+
return True
320+
assert klass is not None
321+
self.plugin = self._initialize_plugin(klass)
322+
# Invoke plugin.on_request_complete
323+
output = self.plugin.on_request_complete()
324+
if isinstance(output, bool):
325+
return output
326+
assert isinstance(output, ssl.SSLSocket)
327+
logger.debug(
328+
'Updated client conn to %s', output,
329+
)
330+
self.work._conn = output
331+
return False
332+
308333
def _encryption_enabled(self) -> bool:
309334
return self.flags.keyfile is not None and \
310335
self.flags.certfile is not None

proxy/http/parser/parser.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from ..url import Url
2424
from ..methods import httpMethods
25+
from ..protocols import httpProtocols
2526
from ..exception import HttpProtocolException
2627

2728
from .protocol import ProxyProtocol
@@ -153,11 +154,10 @@ def set_url(self, url: bytes) -> None:
153154
self._url = Url.from_bytes(url)
154155
self._set_line_attributes()
155156

156-
def has_host(self) -> bool:
157-
"""Returns whether host line attribute was parsed or set.
158-
159-
NOTE: Host field WILL be None for incoming local WebServer requests."""
160-
return self.host is not None
157+
@property
158+
def http_handler_protocol(self) -> int:
159+
"""Returns `HttpProtocols` that this request belongs to."""
160+
return httpProtocols.HTTP_PROXY if self.host is not None else httpProtocols.WEB_SERVER
161161

162162
@property
163163
def is_complete(self) -> bool:

proxy/http/plugin.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
from abc import ABC, abstractmethod
1515
from typing import Tuple, List, Union, Optional, TYPE_CHECKING
1616

17-
from .parser import HttpParser
18-
1917
from ..common.types import Readables, Writables
2018
from ..core.event import EventQueue
2119
from ..core.connection import TcpClientConnection
2220

21+
from .parser import HttpParser
22+
2323
if TYPE_CHECKING:
2424
from ..core.connection import UpstreamConnectionPool
2525

@@ -52,7 +52,7 @@ def __init__(
5252
flags: argparse.Namespace,
5353
client: TcpClientConnection,
5454
request: HttpParser,
55-
event_queue: EventQueue,
55+
event_queue: Optional[EventQueue],
5656
upstream_conn_pool: Optional['UpstreamConnectionPool'] = None,
5757
):
5858
self.uid: str = uid
@@ -63,12 +63,10 @@ def __init__(
6363
self.upstream_conn_pool = upstream_conn_pool
6464
super().__init__()
6565

66-
def name(self) -> str:
67-
"""A unique name for your plugin.
68-
69-
Defaults to name of the class. This helps plugin developers to directly
70-
access a specific plugin by its name."""
71-
return self.__class__.__name__
66+
@staticmethod
67+
@abstractmethod
68+
def protocols() -> List[int]:
69+
raise NotImplementedError()
7270

7371
@abstractmethod
7472
def get_descriptors(self) -> Tuple[List[int], List[int]]:

proxy/http/protocols.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
proxy.py
4+
~~~~~~~~
5+
⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
6+
Network monitoring, controls & Application development, testing, debugging.
7+
8+
:copyright: (c) 2013-present by Abhinav Singh and contributors.
9+
:license: BSD, see LICENSE for more details.
10+
11+
.. spelling::
12+
13+
http
14+
iterable
15+
"""
16+
from typing import NamedTuple
17+
18+
19+
HttpProtocols = NamedTuple(
20+
'HttpProtocols', [
21+
# Web server handling HTTP/1.0, HTTP/1.1, HTTP/2, HTTP/3
22+
# over plain text or encrypted connection with clients
23+
('WEB_SERVER', int),
24+
# Proxies handling HTTP/1.0, HTTP/1.1, HTTP/2 protocols
25+
# over plain text connection or encrypted connection
26+
# with clients
27+
('HTTP_PROXY', int),
28+
# Proxies handling SOCKS4, SOCKS4a, SOCKS5 protocol
29+
('SOCKS_PROXY', int),
30+
],
31+
)
32+
33+
httpProtocols = HttpProtocols(1, 2, 3)

0 commit comments

Comments
 (0)