Skip to content

Commit ac183f5

Browse files
authored
[Core] Invoke selector.close on shutdown (#1055)
[Core] Invoke `selector.close` on shutdown
1 parent 3858f3a commit ac183f5

File tree

6 files changed

+86
-14
lines changed

6 files changed

+86
-14
lines changed

proxy/core/acceptor/acceptor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ def run(self) -> None:
178178
for fileno in self.socks:
179179
self.socks[fileno].close()
180180
self.socks.clear()
181+
self.selector.close()
181182
logger.debug('Acceptor#%d shutdown', self.idd)
182183

183184
def _recv_and_setup_socks(self) -> None:
@@ -207,7 +208,8 @@ def _start_local(self) -> None:
207208
self._lthread.start()
208209

209210
def _stop_local(self) -> None:
210-
if self._lthread is not None and self._local_work_queue is not None:
211+
if self._lthread is not None and \
212+
self._local_work_queue is not None:
211213
self._local_work_queue.put(False)
212214
self._lthread.join()
213215

proxy/core/work/threadless.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,7 @@ def run(self) -> None:
419419
if wqfileno is not None:
420420
self.selector.unregister(wqfileno)
421421
self.close_work_queue()
422+
self.selector.close()
422423
assert self.loop is not None
423424
self.loop.run_until_complete(self.loop.shutdown_asyncgens())
424425
self.loop.close()

proxy/http/handler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,8 @@ def run(self) -> None:
351351
)
352352
finally:
353353
self.shutdown()
354+
if self.selector:
355+
self.selector.close()
354356
loop.close()
355357

356358
async def _run_once(self) -> bool:

proxy/http/parser/parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def response(cls: Type[T], raw: bytes) -> T:
114114
def header(self, key: bytes) -> bytes:
115115
"""Convenient method to return original header value from internal data structure."""
116116
if self.headers is None or key.lower() not in self.headers:
117-
raise KeyError('%s not found in headers', text_(key))
117+
raise KeyError('%s not found in headers' % text_(key))
118118
return self.headers[key.lower()][1]
119119

120120
def has_header(self, key: bytes) -> bool:

proxy/http/websocket/client.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727

2828

2929
class WebsocketClient(TcpConnection):
30+
"""Websocket client connection.
31+
32+
TODO: Make me compatible with the work framework."""
3033

3134
def __init__(
3235
self,
@@ -57,10 +60,14 @@ def connection(self) -> TcpOrTlsSocket:
5760
return self.sock
5861

5962
def handshake(self) -> None:
63+
"""Start websocket upgrade & handshake protocol"""
6064
self.upgrade()
6165
self.sock.setblocking(False)
6266

6367
def upgrade(self) -> None:
68+
"""Creates a key and sends websocket handshake packet to upstream.
69+
Receives response from the server and asserts that websocket
70+
accept header is valid in the response."""
6471
key = base64.b64encode(secrets.token_bytes(16))
6572
self.sock.send(
6673
build_websocket_handshake_request(
@@ -74,12 +81,6 @@ def upgrade(self) -> None:
7481
accept = response.header(b'Sec-Websocket-Accept')
7582
assert WebsocketFrame.key_to_accept(key) == accept
7683

77-
def ping(self, data: Optional[bytes] = None) -> None:
78-
pass # pragma: no cover
79-
80-
def pong(self, data: Optional[bytes] = None) -> None:
81-
pass # pragma: no cover
82-
8384
def shutdown(self, _data: Optional[bytes] = None) -> None:
8485
"""Closes connection with the server."""
8586
super().close()
@@ -121,3 +122,4 @@ def run(self) -> None:
121122
except OSError:
122123
pass
123124
self.sock.close()
125+
self.selector.close()

tests/http/websocket/test_websocket_client.py

Lines changed: 71 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,25 +8,28 @@
88
:copyright: (c) 2013-present by Abhinav Singh and contributors.
99
:license: BSD, see LICENSE for more details.
1010
"""
11+
import selectors
12+
1113
import unittest
1214
from unittest import mock
1315

1416
from proxy.common.utils import (
1517
build_websocket_handshake_request, build_websocket_handshake_response,
1618
)
1719
from proxy.http.websocket import WebsocketFrame, WebsocketClient
18-
from proxy.common.constants import DEFAULT_PORT
20+
from proxy.common.constants import DEFAULT_PORT, DEFAULT_BUFFER_SIZE
1921

2022

2123
class TestWebsocketClient(unittest.TestCase):
2224

23-
@mock.patch('proxy.http.websocket.client.socket.gethostbyname')
2425
@mock.patch('base64.b64encode')
26+
@mock.patch('proxy.http.websocket.client.socket.gethostbyname')
2527
@mock.patch('proxy.http.websocket.client.new_socket_connection')
26-
def test_handshake(
27-
self, mock_connect: mock.Mock,
28-
mock_b64encode: mock.Mock,
29-
mock_gethostbyname: mock.Mock,
28+
def test_handshake_success(
29+
self,
30+
mock_connect: mock.Mock,
31+
mock_gethostbyname: mock.Mock,
32+
mock_b64encode: mock.Mock,
3033
) -> None:
3134
key = b'MySecretKey'
3235
mock_b64encode.return_value = key
@@ -35,9 +38,71 @@ def test_handshake(
3538
build_websocket_handshake_response(
3639
WebsocketFrame.key_to_accept(key),
3740
)
41+
mock_connect.assert_not_called()
3842
client = WebsocketClient(b'localhost', DEFAULT_PORT)
43+
mock_connect.assert_called_once()
3944
mock_connect.return_value.send.assert_not_called()
4045
client.handshake()
4146
mock_connect.return_value.send.assert_called_with(
4247
build_websocket_handshake_request(key),
4348
)
49+
mock_connect.return_value.recv.assert_called_once_with(
50+
DEFAULT_BUFFER_SIZE,
51+
)
52+
53+
@mock.patch('base64.b64encode')
54+
@mock.patch('selectors.DefaultSelector')
55+
@mock.patch('proxy.http.websocket.client.new_socket_connection')
56+
def test_send_recv_frames_success(
57+
self,
58+
mock_connect: mock.Mock,
59+
mock_selector: mock.Mock,
60+
mock_b64encode: mock.Mock,
61+
) -> None:
62+
key = b'MySecretKey'
63+
mock_b64encode.return_value = key
64+
mock_connect.return_value.recv.side_effect = [
65+
build_websocket_handshake_response(
66+
WebsocketFrame.key_to_accept(key),
67+
),
68+
WebsocketFrame.text(b'world'),
69+
]
70+
71+
def on_message(frame: WebsocketFrame) -> None:
72+
assert frame.build() == WebsocketFrame.text(b'world')
73+
74+
client = WebsocketClient(
75+
b'localhost', DEFAULT_PORT, on_message=on_message,
76+
)
77+
mock_selector.assert_called_once()
78+
client.handshake()
79+
client.queue(memoryview(WebsocketFrame.text(b'hello')))
80+
mock_connect.return_value.send.assert_called_once()
81+
mock_selector.return_value.select.side_effect = [
82+
[
83+
(mock.Mock(), selectors.EVENT_WRITE),
84+
],
85+
]
86+
client.run_once()
87+
self.assertEqual(mock_connect.return_value.send.call_count, 2)
88+
mock_selector.return_value.select.side_effect = [
89+
[
90+
(mock.Mock(), selectors.EVENT_READ),
91+
],
92+
]
93+
client.run_once()
94+
95+
@mock.patch('selectors.DefaultSelector')
96+
@mock.patch('proxy.http.websocket.client.new_socket_connection')
97+
def test_run(
98+
self,
99+
mock_connect: mock.Mock,
100+
mock_selector: mock.Mock,
101+
) -> None:
102+
mock_selector.return_value.select.side_effect = KeyboardInterrupt
103+
client = WebsocketClient(b'localhost', DEFAULT_PORT)
104+
client.run()
105+
mock_connect.return_value.shutdown.assert_called_once()
106+
mock_connect.return_value.close.assert_called_once()
107+
mock_selector.return_value.unregister.assert_called_once_with(mock_connect.return_value)
108+
mock_selector.return_value.close.assert_called_once()

0 commit comments

Comments
 (0)