Skip to content

Commit 1fd9c82

Browse files
PYTHON-1419 Connection failure to SNI endpoint when first host is unavailable (#1243)
1 parent eebca73 commit 1fd9c82

File tree

3 files changed

+58
-6
lines changed

3 files changed

+58
-6
lines changed

cassandra/connection.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -245,9 +245,9 @@ def create(self, row):
245245
class SniEndPoint(EndPoint):
246246
"""SNI Proxy EndPoint implementation."""
247247

248-
def __init__(self, proxy_address, server_name, port=9042):
248+
def __init__(self, proxy_address, server_name, port=9042, init_index=0):
249249
self._proxy_address = proxy_address
250-
self._index = 0
250+
self._index = init_index
251251
self._resolved_address = None # resolved address
252252
self._port = port
253253
self._server_name = server_name
@@ -267,8 +267,7 @@ def ssl_options(self):
267267

268268
def resolve(self):
269269
try:
270-
resolved_addresses = socket.getaddrinfo(self._proxy_address, self._port,
271-
socket.AF_UNSPEC, socket.SOCK_STREAM)
270+
resolved_addresses = self._resolve_proxy_addresses()
272271
except socket.gaierror:
273272
log.debug('Could not resolve sni proxy hostname "%s" '
274273
'with port %d' % (self._proxy_address, self._port))
@@ -280,6 +279,10 @@ def resolve(self):
280279

281280
return self._resolved_address, self._port
282281

282+
def _resolve_proxy_addresses(self):
283+
return socket.getaddrinfo(self._proxy_address, self._port,
284+
socket.AF_UNSPEC, socket.SOCK_STREAM)
285+
283286
def __eq__(self, other):
284287
return (isinstance(other, SniEndPoint) and
285288
self.address == other.address and self.port == other.port and
@@ -305,16 +308,24 @@ class SniEndPointFactory(EndPointFactory):
305308
def __init__(self, proxy_address, port):
306309
self._proxy_address = proxy_address
307310
self._port = port
311+
# Initial lookup index to prevent all SNI endpoints to be resolved
312+
# into the same starting IP address (which might not be available currently).
313+
# If SNI resolves to 3 IPs, first endpoint will connect to first
314+
# IP address, and subsequent resolutions to next IPs in round-robin
315+
# fusion.
316+
self._init_index = -1
308317

309318
def create(self, row):
310319
host_id = row.get("host_id")
311320
if host_id is None:
312321
raise ValueError("No host_id to create the SniEndPoint")
313322

314-
return SniEndPoint(self._proxy_address, str(host_id), self._port)
323+
self._init_index += 1
324+
return SniEndPoint(self._proxy_address, str(host_id), self._port, self._init_index)
315325

316326
def create_from_sni(self, sni):
317-
return SniEndPoint(self._proxy_address, sni, self._port)
327+
self._init_index += 1
328+
return SniEndPoint(self._proxy_address, sni, self._port, self._init_index)
318329

319330

320331
@total_ordering

tests/unit/test_cluster.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@
1515

1616
import logging
1717
import socket
18+
import uuid
1819

1920
from unittest.mock import patch, Mock
2021

2122
from cassandra import ConsistencyLevel, DriverException, Timeout, Unavailable, RequestExecutionException, ReadTimeout, WriteTimeout, CoordinationFailure, ReadFailure, WriteFailure, FunctionFailure, AlreadyExists,\
2223
InvalidRequest, Unauthorized, AuthenticationFailed, OperationTimedOut, UnsupportedOperation, RequestValidationException, ConfigurationException, ProtocolVersion
2324
from cassandra.cluster import _Scheduler, Session, Cluster, default_lbp_factory, \
2425
ExecutionProfile, _ConfigMode, EXEC_PROFILE_DEFAULT
26+
from cassandra.connection import SniEndPoint, SniEndPointFactory
2527
from cassandra.pool import Host
2628
from cassandra.policies import HostDistance, RetryPolicy, RoundRobinPolicy, DowngradingConsistencyRetryPolicy, SimpleConvictionPolicy
2729
from cassandra.query import SimpleStatement, named_tuple_factory, tuple_factory
@@ -31,6 +33,7 @@
3133

3234
log = logging.getLogger(__name__)
3335

36+
3437
class ExceptionTypeTest(unittest.TestCase):
3538

3639
def test_exception_types(self):
@@ -85,6 +88,12 @@ def test_exception_types(self):
8588
self.assertTrue(issubclass(UnsupportedOperation, DriverException))
8689

8790

91+
class MockOrderedPolicy(RoundRobinPolicy):
92+
all_hosts = set()
93+
94+
def make_query_plan(self, working_keyspace=None, query=None):
95+
return sorted(self.all_hosts, key=lambda x: x.endpoint.ssl_options['server_hostname'])
96+
8897
class ClusterTest(unittest.TestCase):
8998

9099
def test_tuple_for_contact_points(self):
@@ -119,6 +128,26 @@ def test_requests_in_flight_threshold(self):
119128
for n in (0, mn, 128):
120129
self.assertRaises(ValueError, c.set_max_requests_per_connection, d, n)
121130

131+
# Validate that at least the default LBP can create a query plan with end points that resolve
132+
# to different addresses initially. This may not be exactly how things play out in practice
133+
# (the control connection will muck with this even if nothing else does) but it should be
134+
# a pretty good approximation.
135+
def test_query_plan_for_sni_contains_unique_addresses(self):
136+
node_cnt = 5
137+
def _mocked_proxy_dns_resolution(self):
138+
return [(socket.AF_UNIX, socket.SOCK_STREAM, 0, None, ('127.0.0.%s' % (i,), 9042)) for i in range(node_cnt)]
139+
140+
c = Cluster()
141+
lbp = c.load_balancing_policy
142+
lbp.local_dc = "dc1"
143+
factory = SniEndPointFactory("proxy.foo.bar", 9042)
144+
for host in (Host(factory.create({"host_id": uuid.uuid4().hex, "dc": "dc1"}), SimpleConvictionPolicy) for _ in range(node_cnt)):
145+
lbp.on_up(host)
146+
with patch.object(SniEndPoint, '_resolve_proxy_addresses', _mocked_proxy_dns_resolution):
147+
addrs = [host.endpoint.resolve() for host in lbp.make_query_plan()]
148+
# single SNI endpoint should be resolved to multiple unique IP addresses
149+
self.assertEqual(len(addrs), len(set(addrs)))
150+
122151

123152
class SchedulerTest(unittest.TestCase):
124153
# TODO: this suite could be expanded; for now just adding a test covering a ticket

tests/unit/test_endpoints.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,15 @@ def test_endpoint_resolve(self):
6565
for i in range(10):
6666
(address, _) = endpoint.resolve()
6767
self.assertEqual(address, next(it))
68+
69+
def test_sni_resolution_start_index(self):
70+
factory = SniEndPointFactory("proxy.datastax.com", 9999)
71+
initial_index = factory._init_index
72+
73+
endpoint1 = factory.create_from_sni('sni1')
74+
self.assertEqual(factory._init_index, initial_index + 1)
75+
self.assertEqual(endpoint1._index, factory._init_index)
76+
77+
endpoint2 = factory.create_from_sni('sni2')
78+
self.assertEqual(factory._init_index, initial_index + 2)
79+
self.assertEqual(endpoint2._index, factory._init_index)

0 commit comments

Comments
 (0)