Skip to content

Commit d6e6077

Browse files
authored
Add DEFAULT_MAX_SEND_SIZE and handle SSLWantWriteError errors when dispatching to upstream servers (#368)
1 parent 529580b commit d6e6077

File tree

3 files changed

+11
-7
lines changed

3 files changed

+11
-7
lines changed

proxy/common/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,4 @@
7474
DEFAULT_TIMEOUT = 10
7575
DEFAULT_VERSION = False
7676
DEFAULT_HTTP_PORT = 80
77+
DEFAULT_MAX_SEND_SIZE = 16 * 1024

proxy/core/connection/connection.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from abc import ABC, abstractmethod
1515
from typing import NamedTuple, Optional, Union, List
1616

17-
from ...common.constants import DEFAULT_BUFFER_SIZE
17+
from ...common.constants import DEFAULT_BUFFER_SIZE, DEFAULT_MAX_SEND_SIZE
1818

1919
logger = logging.getLogger(__name__)
2020

@@ -82,11 +82,11 @@ def flush(self) -> int:
8282
"""Users must handle BrokenPipeError exceptions"""
8383
if not self.has_buffer():
8484
return 0
85-
mv = self.buffer[0]
86-
sent: int = self.send(mv.tobytes())
85+
mv = self.buffer[0].tobytes()
86+
sent: int = self.send(mv[:DEFAULT_MAX_SEND_SIZE])
8787
if sent == len(mv):
8888
self.buffer.pop(0)
8989
else:
90-
self.buffer[0] = memoryview(mv.tobytes()[sent:])
90+
self.buffer[0] = memoryview(mv[sent:])
9191
logger.debug('flushed %d bytes to %s' % (sent, self.tag))
9292
return sent

proxy/http/proxy/server.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,15 @@ def write_to_descriptors(self, w: List[Union[int, HasFileno]]) -> bool:
8989
logger.debug('Server is write ready, flushing buffer')
9090
try:
9191
self.server.flush()
92+
except ssl.SSLWantWriteError:
93+
logger.warning('SSLWantWriteError while trying to flush to server, will retry')
94+
return False
9295
except BrokenPipeError:
9396
logger.error(
9497
'BrokenPipeError when flushing buffer for server')
9598
return True
96-
except OSError:
97-
logger.error('OSError when flushing buffer to server')
99+
except OSError as e:
100+
logger.exception('OSError when flushing buffer to server', exc_info=e)
98101
return True
99102
return False
100103

@@ -207,7 +210,6 @@ def on_client_data(self, raw: memoryview) -> Optional[memoryview]:
207210
if self.request.state == httpParserStates.COMPLETE and (
208211
self.request.method != httpMethods.CONNECT or
209212
self.flags.tls_interception_enabled()):
210-
211213
if self.pipeline_request is not None and \
212214
self.pipeline_request.is_connection_upgrade():
213215
# Previous pipelined request was a WebSocket
@@ -219,6 +221,7 @@ def on_client_data(self, raw: memoryview) -> Optional[memoryview]:
219221
if self.pipeline_request is None:
220222
self.pipeline_request = HttpParser(
221223
httpParserTypes.REQUEST_PARSER)
224+
222225
# TODO(abhinavsingh): Remove .tobytes after parser is
223226
# memoryview compliant
224227
self.pipeline_request.parse(raw.tobytes())

0 commit comments

Comments
 (0)