Skip to content

Commit a9beaf1

Browse files
committed
leakage
1 parent 050ac1c commit a9beaf1

File tree

10 files changed

+201
-31
lines changed

10 files changed

+201
-31
lines changed

proxy/core/base/tcp_server.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ class BaseTcpServerHandler(Work[T]):
104104
is ready to accept new data before flushing data to it.
105105
106106
Most importantly, BaseTcpServerHandler ensures that pending buffers
107-
to the client are flushed before connection is closed.
107+
to the client are flushed before connection is closed with the client.
108108
109109
Implementations must provide::
110110
@@ -170,9 +170,9 @@ async def handle_events(
170170
async def handle_writables(self, writables: Writables) -> bool:
171171
teardown = False
172172
if self.work.connection.fileno() in writables and self.work.has_buffer():
173-
logger.debug(
174-
'Flushing buffer to client {0}'.format(self.work.address),
175-
)
173+
# logger.debug(
174+
# 'Flushing buffer to client {0}'.format(self.work.address),
175+
# )
176176
self.work.flush(self.flags.max_sendbuf_size)
177177
if self.must_flush_before_shutdown is True and \
178178
not self.work.has_buffer():

proxy/core/connection/connection.py

Lines changed: 51 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,15 @@
1212
from abc import ABC, abstractmethod
1313
from typing import List, Union, Optional
1414

15+
from .leak import Leakage
1516
from .types import tcpConnectionTypes
1617
from ...common.types import TcpOrTlsSocket
1718
from ...common.constants import DEFAULT_BUFFER_SIZE, DEFAULT_MAX_SEND_SIZE
1819

1920

2021
logger = logging.getLogger(__name__)
2122

23+
EMPTY_MV = memoryview(b"")
2224

2325
class TcpConnectionUninitializedException(Exception):
2426
pass
@@ -34,12 +36,23 @@ class TcpConnection(ABC):
3436
a socket connection object.
3537
"""
3638

37-
def __init__(self, tag: int) -> None:
39+
def __init__(
40+
self,
41+
tag: int,
42+
flush_bw_in_bps: int = 512,
43+
recv_bw_in_bps: int = 512,
44+
) -> None:
3845
self.tag: str = 'server' if tag == tcpConnectionTypes.SERVER else 'client'
3946
self.buffer: List[memoryview] = []
4047
self.closed: bool = False
4148
self._reusable: bool = False
4249
self._num_buffer = 0
50+
self._flush_leakage = (
51+
Leakage(rate=flush_bw_in_bps) if flush_bw_in_bps > 0 else None
52+
)
53+
self._recv_leakage = (
54+
Leakage(rate=recv_bw_in_bps) if recv_bw_in_bps > 0 else None
55+
)
4356

4457
@property
4558
@abstractmethod
@@ -55,14 +68,19 @@ def recv(
5568
self, buffer_size: int = DEFAULT_BUFFER_SIZE,
5669
) -> Optional[memoryview]:
5770
"""Users must handle socket.error exceptions"""
71+
if self._recv_leakage is not None:
72+
allowed_bytes = self._recv_leakage.consume(buffer_size)
73+
if allowed_bytes == 0:
74+
return EMPTY_MV
75+
buffer_size = min(buffer_size, allowed_bytes)
5876
data: bytes = self.connection.recv(buffer_size)
59-
if len(data) == 0:
77+
size = len(data)
78+
if self._recv_leakage is not None:
79+
self._recv_leakage.putback(buffer_size - size)
80+
if size == 0:
6081
return None
61-
logger.debug(
62-
'received %d bytes from %s' %
63-
(len(data), self.tag),
64-
)
65-
# logger.info(data)
82+
logger.debug("received %d bytes from %s" % (size, self.tag))
83+
logger.info(data)
6684
return memoryview(data)
6785

6886
def close(self) -> bool:
@@ -75,6 +93,8 @@ def has_buffer(self) -> bool:
7593
return self._num_buffer != 0
7694

7795
def queue(self, mv: memoryview) -> None:
96+
if len(mv) == 0:
97+
return
7898
self.buffer.append(mv)
7999
self._num_buffer += 1
80100

@@ -83,21 +103,38 @@ def flush(self, max_send_size: Optional[int] = None) -> int:
83103
if not self.has_buffer():
84104
return 0
85105
mv = self.buffer[0]
106+
print(self.buffer)
107+
print(mv.tobytes())
86108
# TODO: Assemble multiple packets if total
87109
# size remains below max send size.
88110
max_send_size = max_send_size or DEFAULT_MAX_SEND_SIZE
89-
try:
90-
sent: int = self.send(mv[:max_send_size])
91-
except BlockingIOError:
92-
logger.warning('BlockingIOError when trying send to {0}'.format(self.tag))
93-
return 0
111+
allowed_bytes = (
112+
self._flush_leakage.consume(min(len(mv), max_send_size))
113+
if self._flush_leakage is not None
114+
else max_send_size
115+
)
116+
sent: int = 0
117+
if allowed_bytes > 0:
118+
try:
119+
sent = self.send(mv[:allowed_bytes])
120+
if self._flush_leakage is not None:
121+
self._flush_leakage.putback(allowed_bytes - sent)
122+
except BlockingIOError:
123+
logger.warning(
124+
"BlockingIOError when trying send to {0}".format(self.tag)
125+
)
126+
del mv
127+
return 0
128+
# if sent == 0:
129+
# return 0
94130
if sent == len(mv):
95131
self.buffer.pop(0)
96132
self._num_buffer -= 1
97133
else:
98134
self.buffer[0] = mv[sent:]
99-
logger.debug('flushed %d bytes to %s' % (sent, self.tag))
100-
# logger.info(mv[:sent].tobytes())
135+
# if sent > 0:
136+
logger.debug("flushed %d bytes to %s" % (sent, self.tag))
137+
logger.info(mv[:sent].tobytes())
101138
del mv
102139
return sent
103140

proxy/core/connection/leak.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
proxy.py
4+
~~~~~~~~
5+
⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
6+
Network monitoring, controls & Application development, testing, debugging.
7+
8+
:copyright: (c) 2013-present by Abhinav Singh and contributors.
9+
:license: BSD, see LICENSE for more details.
10+
"""
11+
import time
12+
13+
14+
class Leakage:
    """Leaky-bucket rate limiter measured in bytes per second.

    The bucket starts full (``tokens == rate``) and refills continuously
    at ``rate`` bytes/sec, capped at ``rate`` so at most one second worth
    of burst can accumulate.
    """

    def __init__(self, rate: int) -> None:
        """Initialize the bucket with a leak ``rate`` in bytes per second."""
        # Maximum number of tokens the bucket can hold (bytes per second).
        self.rate = rate
        # Initially start with a full bucket so the first consume can burst.
        self.tokens = rate
        # Use a monotonic clock: time.time() can jump backwards on wall-clock
        # (NTP) adjustments, which would stall or over-credit the refill.
        self.last_check = time.monotonic()

    def _refill(self) -> None:
        """Credit tokens for the time elapsed since the last check."""
        now = time.monotonic()
        elapsed = now - self.last_check
        # Add tokens proportional to elapsed time, capped at the rate to
        # enforce the bytes-per-second limit.
        self.tokens = min(self.tokens + int(elapsed * self.rate), self.rate)
        self.last_check = now  # Update the last check time

    def putback(self, tokens: int) -> None:
        """Return unused tokens, e.g. bytes reserved but not actually sent.

        Capped at ``rate`` so a putback cannot grow the bucket beyond one
        second worth of burst capacity (``_refill`` would clamp it anyway).
        """
        self.tokens = min(self.tokens + tokens, self.rate)

    def consume(self, amount: int) -> int:
        """Attempt to consume ``amount`` bytes from the bucket.

        Returns the number of bytes allowed to be sent, anywhere between
        0 and ``amount`` depending on available tokens.
        """
        self._refill()  # Refill the tokens before consumption
        allowed = min(amount, self.tokens)
        self.tokens -= allowed
        return allowed

proxy/core/work/fd/fd.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,15 @@ def work(self, *args: Any) -> None:
4141
publisher_id=self.__class__.__qualname__,
4242
)
4343
try:
44+
logger.debug("Initializing work#{0}".format(fileno))
4445
self.works[fileno].initialize()
4546
self._total += 1
4647
except Exception as e:
4748
logger.exception( # pragma: no cover
4849
'Exception occurred during initialization',
4950
exc_info=e,
5051
)
51-
self._cleanup(fileno)
52+
self._cleanup(fileno, "error")
5253

5354
@property
5455
@abstractmethod

proxy/core/work/threadless.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -301,16 +301,18 @@ def _cleanup_inactive(self) -> None:
301301
if self.works[work_id].is_inactive():
302302
inactive_works.append(work_id)
303303
for work_id in inactive_works:
304-
self._cleanup(work_id)
304+
self._cleanup(work_id, "inactive")
305305

306306
# TODO: HttpProtocolHandler.shutdown can call flush which may block
307-
def _cleanup(self, work_id: int) -> None:
307+
def _cleanup(self, work_id: int, reason: str) -> None:
308308
if work_id in self.registered_events_by_work_ids:
309309
assert self.selector
310310
for fileno in self.registered_events_by_work_ids[work_id]:
311311
logger.debug(
312-
'fd#{0} unregistered by work#{1}'.format(
313-
fileno, work_id,
312+
"fd#{0} unregistered by work#{1}, reason: {2}".format(
313+
fileno,
314+
work_id,
315+
reason,
314316
),
315317
)
316318
self.selector.unregister(fileno)
@@ -360,7 +362,7 @@ async def _run_once(self) -> bool:
360362
return False
361363
# Invoke Threadless.handle_events
362364
self.unfinished.update(self._create_tasks(work_by_ids))
363-
# logger.debug('Executing {0} works'.format(len(self.unfinished)))
365+
# logger.debug("Executing {0} works".format(len(self.unfinished)))
364366
# Cleanup finished tasks
365367
for task in await self._wait_for_tasks():
366368
# Checking for result can raise exception e.g.
@@ -374,11 +376,12 @@ async def _run_once(self) -> bool:
374376
teardown = True
375377
finally:
376378
if teardown:
377-
self._cleanup(work_id)
379+
self._cleanup(work_id, "teardown")
378380
# self.cleanup(int(task.get_name()))
379381
# logger.debug(
380-
# 'Done executing works, {0} pending, {1} registered'.format(
381-
# len(self.unfinished), len(self.registered_events_by_work_ids),
382+
# "Done executing works, {0} pending, {1} registered".format(
383+
# len(self.unfinished),
384+
# len(self.registered_events_by_work_ids),
382385
# ),
383386
# )
384387
return False

proxy/http/client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ def client(
3232
conn_close: bool = True,
3333
scheme: bytes = HTTPS_PROTO,
3434
timeout: float = DEFAULT_TIMEOUT,
35-
content_type: bytes = b'application/x-www-form-urlencoded',
36-
verify: bool = True,
35+
content_type: bytes = b"application/x-www-form-urlencoded",
36+
verify: bool = False,
3737
) -> Optional[HttpParser]:
3838
"""HTTP Client"""
3939
request = build_http_request(

proxy/http/handler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def handle_data(self, data: memoryview) -> Optional[bool]:
190190

191191
async def handle_writables(self, writables: Writables) -> bool:
192192
if self.work.connection.fileno() in writables and self.work.has_buffer():
193-
logger.debug('Client is write ready, flushing...')
193+
# logger.debug('Client is write ready, flushing...')
194194
self.last_activity = time.time()
195195
# TODO(abhinavsingh): This hook could just reside within server recv block
196196
# instead of invoking when flushed to client.
@@ -219,7 +219,7 @@ async def handle_writables(self, writables: Writables) -> bool:
219219

220220
async def handle_readables(self, readables: Readables) -> bool:
221221
if self.work.connection.fileno() in readables:
222-
logger.debug('Client is read ready, receiving...')
222+
# logger.debug('Client is read ready, receiving...')
223223
self.last_activity = time.time()
224224
try:
225225
teardown = await super().handle_readables(readables)

proxy/http/server/reverse.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@ def handle_request(self, request: HttpParser) -> None:
8585
random.choice(route[1]),
8686
)
8787
needs_upstream = True
88+
logger.debug(
89+
"Starting connection to upstream {0}".format(self.choice)
90+
)
8891
break
8992
# Dynamic routes
9093
elif isinstance(route, str):
@@ -95,14 +98,23 @@ def handle_request(self, request: HttpParser) -> None:
9598
self.choice = choice
9699
needs_upstream = True
97100
self._upstream_proxy_pass = str(self.choice)
101+
logger.debug(
102+
"Starting connection to upstream {0}".format(choice)
103+
)
98104
elif isinstance(choice, memoryview):
99105
self.client.queue(choice)
100106
self._upstream_proxy_pass = '{0} bytes'.format(len(choice))
107+
logger.debug("Sending raw response to client")
101108
else:
102109
self.upstream = choice
103110
self._upstream_proxy_pass = '{0}:{1}'.format(
104111
*self.upstream.addr,
105112
)
113+
logger.debug(
114+
"Using existing connection to upstream {0}".format(
115+
self.upstream.addr
116+
)
117+
)
106118
break
107119
else:
108120
raise ValueError('Invalid route')

tests/core/test_leakage.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
proxy.py
4+
~~~~~~~~
5+
⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
6+
Network monitoring, controls & Application development, testing, debugging.
7+
8+
:copyright: (c) 2013-present by Abhinav Singh and contributors.
9+
:license: BSD, see LICENSE for more details.
10+
"""
11+
import time
12+
13+
import unittest
14+
15+
from proxy.core.connection.leak import Leakage
16+
17+
18+
class TestTcpConnectionLeakage(unittest.TestCase):
    """Tests for the leaky-bucket byte rate limiter (``Leakage``)."""

    def test_initial_consume_no_tokens(self):
        # The bucket starts FULL, so an oversized first request is clipped
        # to the rate: 100 of the requested 150 bytes are allowed.
        # (Previous comment claimed 0 bytes, contradicting the assertion.)
        rate = 100  # bytes per second
        bucket = Leakage(rate)
        self.assertEqual(bucket.consume(150), 100)

    def test_consume_with_refill(self):
        # After waiting, the bucket remains at capacity; a request within
        # the rate is fully allowed.
        rate = 100  # bytes per second
        bucket = Leakage(rate)
        time.sleep(1)  # Wait for a second to allow refill
        self.assertEqual(bucket.consume(50), 50)  # 50 bytes should be available

    def test_consume_above_leak_rate(self):
        # Attempting to consume more than the rate yields at most `rate` bytes.
        rate = 100  # bytes per second
        bucket = Leakage(rate)
        time.sleep(1)  # Wait for a second to allow refill
        self.assertEqual(bucket.consume(150), 100)  # Only 100 bytes should be allowed

    def test_repeated_consume_with_partial_refill(self):
        # Repeated consumption with a partial refill in between.
        rate = 100  # bytes per second
        bucket = Leakage(rate)

        time.sleep(1)  # Bucket is (still) at capacity
        bucket.consume(80)  # Consume 80 bytes, should leave 20
        time.sleep(0.5)  # Wait half a second to refill by ~50 bytes

        self.assertEqual(bucket.consume(50), 50)  # 50 bytes should be available now

    def test_negative_token_guard(self):
        # Ensure tokens never go negative after over-consumption.
        rate = 100  # bytes per second
        bucket = Leakage(rate)
        time.sleep(1)  # Allow tokens to accumulate
        bucket.consume(150)  # Consume all available tokens
        self.assertEqual(bucket.consume(10), 0)  # Should return 0 as no tokens are left
        self.assertEqual(bucket.tokens, 0)  # Tokens should not be negative

tests/http/parser/test_http_parser.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -908,3 +908,17 @@ def test_cannot_parse_sip_protocol(self) -> None:
908908
b'\r\n',
909909
),
910910
)
911+
912+
def test_byte_by_byte(self) -> None:
913+
response = HttpParser(httpParserTypes.RESPONSE_PARSER)
914+
request = [
915+
# pylint: disable=line-too-long
916+
b'HTTP/1.1 200 OK\r\naccess-control-allow-credentials: true\r\naccess-control-allow-origin: *\r\ncontent-type: application/json; charset=utf-8\r\ndate: Thu, 14 Nov 2024 10:24:11 GMT\r\ncontent-length: 550\r\nserver: Fly/a40a59d0 (2024-11-12)\r\nvia: 1.1 fly.io\r\nfly-request-id: 01JCN37CEK4TB4DRWZDFFQYSD9-bom\r\n\r\n{\n "args": {},\n "headers": {\n "Accept": [\n "*/*"\n ],\n "Host": [\n "httpbingo.org"\n ],\n "User-Agent": [\n "curl/8.6.0"\n ],\n "Via": [\n "1.1 fly.io"\n ],\n "X-Forwarded-For": [\n "183.82.162.68, 66.241.125.232"\n ],\n "X-Forwarded-Port": [\n "443"\n ],\n "X-Forwarded-Proto": [\n "https"\n ],\n "X-Forwarded-Ssl',
917+
b'": [\n "on"\n ],\n "X-Request-Start": [\n "t=1731579851219982"\n ]\n },\n "method": "GET",\n "origin": "183.82.162.68",\n "url": "https://httpbingo.org/get"\n}\n',
918+
]
919+
response.parse(memoryview(request[0]))
920+
self.assertEqual(response.state, httpParserStates.RCVING_BODY)
921+
self.assertEqual(response.code, b"200")
922+
for byte in (bytes([b]) for b in request[1]):
923+
response.parse(memoryview(byte))
924+
self.assertEqual(response.state, httpParserStates.COMPLETE)

0 commit comments

Comments
 (0)