Skip to content

Commit 263c067

Browse files
Define work lifecycle events for pool (#918)
* Define work lifecycle events for pool * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Use isinstance * Use mocker fixture to pass CI on 3.6 and 3.7 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f6214a4 commit 263c067

File tree

4 files changed

+109
-51
lines changed

4 files changed

+109
-51
lines changed

proxy/core/connection/pool.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@
1313
reusability
1414
"""
1515
import logging
16+
import selectors
1617

17-
from typing import Set, Dict, Tuple
18+
from typing import TYPE_CHECKING, Set, Dict, Tuple
1819

1920
from ...common.flag import flags
2021
from ...common.types import Readables, Writables
@@ -66,30 +67,36 @@ class UpstreamConnectionPool(Work[TcpServerConnection]):
6667

6768
def __init__(self) -> None:
6869
# Pools of connection per upstream server
70+
self.connections: Dict[int, TcpServerConnection] = {}
6971
self.pools: Dict[Tuple[str, int], Set[TcpServerConnection]] = {}
7072

71-
def acquire(self, host: str, port: int) -> Tuple[bool, TcpServerConnection]:
73+
def add(self, addr: Tuple[str, int]) -> TcpServerConnection:
74+
# Create new connection
75+
new_conn = TcpServerConnection(addr[0], addr[1])
76+
new_conn.connect()
77+
if addr not in self.pools:
78+
self.pools[addr] = set()
79+
self.pools[addr].add(new_conn)
80+
self.connections[new_conn.connection.fileno()] = new_conn
81+
return new_conn
82+
83+
def acquire(self, addr: Tuple[str, int]) -> Tuple[bool, TcpServerConnection]:
7284
"""Returns a connection for use with the server."""
73-
addr = (host, port)
7485
# Return a reusable connection if available
7586
if addr in self.pools:
7687
for old_conn in self.pools[addr]:
7788
if old_conn.is_reusable():
7889
old_conn.mark_inuse()
7990
logger.debug(
8091
'Reusing connection#{2} for upstream {0}:{1}'.format(
81-
host, port, id(old_conn),
92+
addr[0], addr[1], id(old_conn),
8293
),
8394
)
8495
return False, old_conn
85-
# Create new connection
86-
new_conn = TcpServerConnection(*addr)
87-
if addr not in self.pools:
88-
self.pools[addr] = set()
89-
self.pools[addr].add(new_conn)
96+
new_conn = self.add(addr)
9097
logger.debug(
9198
'Created new connection#{2} for upstream {0}:{1}'.format(
92-
host, port, id(new_conn),
99+
addr[0], addr[1], id(new_conn),
93100
),
94101
)
95102
return True, new_conn
@@ -118,7 +125,17 @@ def release(self, conn: TcpServerConnection) -> None:
118125
conn.reset()
119126

120127
async def get_events(self) -> Dict[int, int]:
121-
return await super().get_events()
122-
123-
async def handle_events(self, readables: Readables, writables: Writables) -> bool:
124-
return await super().handle_events(readables, writables)
128+
events = {}
129+
for connections in self.pools.values():
130+
for conn in connections:
131+
events[conn.connection.fileno()] = selectors.EVENT_READ
132+
return events
133+
134+
async def handle_events(self, readables: Readables, _writables: Writables) -> bool:
135+
for r in readables:
136+
if TYPE_CHECKING:
137+
assert isinstance(r, int)
138+
conn = self.connections[r]
139+
self.pools[conn.addr].remove(conn)
140+
del self.connections[r]
141+
return False

proxy/core/connection/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class TcpServerConnection(TcpConnection):
2525
def __init__(self, host: str, port: int) -> None:
2626
super().__init__(tcpConnectionTypes.SERVER)
2727
self._conn: Optional[Union[ssl.SSLSocket, socket.socket]] = None
28-
self.addr: Tuple[str, int] = (host, int(port))
28+
self.addr: Tuple[str, int] = (host, port)
2929
self.closed = True
3030

3131
@property

proxy/http/proxy/server.py

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -586,29 +586,6 @@ def handle_pipeline_response(self, raw: memoryview) -> None:
586586
def connect_upstream(self) -> None:
587587
host, port = self.request.host, self.request.port
588588
if host and port:
589-
if self.flags.enable_conn_pool:
590-
assert self.upstream_conn_pool
591-
with self.lock:
592-
created, self.upstream = self.upstream_conn_pool.acquire(
593-
text_(host), port,
594-
)
595-
else:
596-
created, self.upstream = True, TcpServerConnection(
597-
text_(host), port,
598-
)
599-
if not created:
600-
# NOTE: Acquired connection might be in an unusable state.
601-
#
602-
# This can only be confirmed by reading from connection.
603-
# For stale connections, we will receive None, indicating
604-
# to drop the connection.
605-
#
606-
# If that happen, we must acquire a fresh connection.
607-
logger.info(
608-
'Reusing connection to upstream %s:%d' %
609-
(text_(host), port),
610-
)
611-
return
612589
try:
613590
logger.debug(
614591
'Connecting to upstream %s:%d' %
@@ -622,14 +599,37 @@ def connect_upstream(self) -> None:
622599
)
623600
if upstream_ip or source_addr:
624601
break
625-
# Connect with overridden upstream IP and source address
626-
# if any of the plugin returned a non-null value.
627-
self.upstream.connect(
628-
addr=None if not upstream_ip else (
629-
upstream_ip, port,
630-
), source_address=source_addr,
631-
)
632-
self.upstream.connection.setblocking(False)
602+
if self.flags.enable_conn_pool:
603+
assert self.upstream_conn_pool
604+
with self.lock:
605+
created, self.upstream = self.upstream_conn_pool.acquire(
606+
(text_(host), port),
607+
)
608+
else:
609+
created, self.upstream = True, TcpServerConnection(
610+
text_(host), port,
611+
)
612+
# Connect with overridden upstream IP and source address
613+
# if any of the plugin returned a non-null value.
614+
self.upstream.connect(
615+
addr=None if not upstream_ip else (
616+
upstream_ip, port,
617+
), source_address=source_addr,
618+
)
619+
self.upstream.connection.setblocking(False)
620+
if not created:
621+
# NOTE: Acquired connection might be in an unusable state.
622+
#
623+
# This can only be confirmed by reading from connection.
624+
# For stale connections, we will receive None, indicating
625+
# to drop the connection.
626+
#
627+
# If that happen, we must acquire a fresh connection.
628+
logger.info(
629+
'Reusing connection to upstream %s:%d' %
630+
(text_(host), port),
631+
)
632+
return
633633
logger.debug(
634634
'Connected to upstream %s:%s' %
635635
(text_(host), port),
@@ -640,7 +640,7 @@ def connect_upstream(self) -> None:
640640
text_(host), port, str(e),
641641
),
642642
)
643-
if self.flags.enable_conn_pool:
643+
if self.flags.enable_conn_pool and self.upstream:
644644
assert self.upstream_conn_pool
645645
with self.lock:
646646
self.upstream_conn_pool.release(self.upstream)

tests/core/test_conn_pool.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,12 @@
88
:copyright: (c) 2013-present by Abhinav Singh and contributors.
99
:license: BSD, see LICENSE for more details.
1010
"""
11+
import pytest
1112
import unittest
13+
import selectors
1214

1315
from unittest import mock
16+
from pytest_mock import MockerFixture
1417

1518
from proxy.core.connection import UpstreamConnectionPool
1619

@@ -28,7 +31,7 @@ def test_acquire_and_release_and_reacquire(self, mock_tcp_server_connection: moc
2831
]
2932
mock_conn.closed = False
3033
# Acquire
31-
created, conn = pool.acquire(*addr)
34+
created, conn = pool.acquire(addr)
3235
self.assertTrue(created)
3336
mock_tcp_server_connection.assert_called_once_with(*addr)
3437
self.assertEqual(conn, mock_conn)
@@ -39,7 +42,7 @@ def test_acquire_and_release_and_reacquire(self, mock_tcp_server_connection: moc
3942
self.assertEqual(len(pool.pools[addr]), 1)
4043
self.assertTrue(conn in pool.pools[addr])
4144
# Reacquire
42-
created, conn = pool.acquire(*addr)
45+
created, conn = pool.acquire(addr)
4346
self.assertFalse(created)
4447
mock_conn.reset.assert_called_once()
4548
self.assertEqual(conn, mock_conn)
@@ -57,7 +60,7 @@ def test_closed_connections_are_removed_on_release(
5760
mock_conn.closed = True
5861
mock_conn.addr = addr
5962
# Acquire
60-
created, conn = pool.acquire(*addr)
63+
created, conn = pool.acquire(addr)
6164
self.assertTrue(created)
6265
mock_tcp_server_connection.assert_called_once_with(*addr)
6366
self.assertEqual(conn, mock_conn)
@@ -67,7 +70,45 @@ def test_closed_connections_are_removed_on_release(
6770
pool.release(conn)
6871
self.assertEqual(len(pool.pools[addr]), 0)
6972
# Acquire
70-
created, conn = pool.acquire(*addr)
73+
created, conn = pool.acquire(addr)
7174
self.assertTrue(created)
7275
self.assertEqual(mock_tcp_server_connection.call_count, 2)
7376
mock_conn.is_reusable.assert_not_called()
77+
78+
79+
class TestConnectionPoolAsync:
80+
81+
@pytest.mark.asyncio # type: ignore[misc]
82+
async def test_get_events(self, mocker: MockerFixture) -> None:
83+
mock_tcp_server_connection = mocker.patch(
84+
'proxy.core.connection.pool.TcpServerConnection',
85+
)
86+
pool = UpstreamConnectionPool()
87+
addr = ('localhost', 1234)
88+
mock_conn = mock_tcp_server_connection.return_value
89+
pool.add(addr)
90+
mock_tcp_server_connection.assert_called_once_with(*addr)
91+
mock_conn.connect.assert_called_once()
92+
events = await pool.get_events()
93+
print(events)
94+
assert events == {
95+
mock_conn.connection.fileno.return_value: selectors.EVENT_READ,
96+
}
97+
assert pool.pools[addr].pop() == mock_conn
98+
assert len(pool.pools[addr]) == 0
99+
assert pool.connections[mock_conn.connection.fileno.return_value] == mock_conn
100+
101+
@pytest.mark.asyncio # type: ignore[misc]
102+
async def test_handle_events(self, mocker: MockerFixture) -> None:
103+
mock_tcp_server_connection = mocker.patch(
104+
'proxy.core.connection.pool.TcpServerConnection',
105+
)
106+
pool = UpstreamConnectionPool()
107+
mock_conn = mock_tcp_server_connection.return_value
108+
addr = mock_conn.addr
109+
pool.add(addr)
110+
assert len(pool.pools[addr]) == 1
111+
assert len(pool.connections) == 1
112+
await pool.handle_events([mock_conn.connection.fileno.return_value], [])
113+
assert len(pool.pools[addr]) == 0
114+
assert len(pool.connections) == 0

0 commit comments

Comments
 (0)