|
15 | 15 |
|
16 | 16 | import logging
|
17 | 17 | import socket
|
| 18 | +import uuid |
18 | 19 |
|
19 | 20 | from unittest.mock import patch, Mock
|
20 | 21 |
|
21 | 22 | from cassandra import ConsistencyLevel, DriverException, Timeout, Unavailable, RequestExecutionException, ReadTimeout, WriteTimeout, CoordinationFailure, ReadFailure, WriteFailure, FunctionFailure, AlreadyExists,\
|
22 | 23 | InvalidRequest, Unauthorized, AuthenticationFailed, OperationTimedOut, UnsupportedOperation, RequestValidationException, ConfigurationException, ProtocolVersion
|
23 | 24 | from cassandra.cluster import _Scheduler, Session, Cluster, default_lbp_factory, \
|
24 | 25 | ExecutionProfile, _ConfigMode, EXEC_PROFILE_DEFAULT
|
| 26 | +from cassandra.connection import SniEndPoint, SniEndPointFactory |
25 | 27 | from cassandra.pool import Host
|
26 | 28 | from cassandra.policies import HostDistance, RetryPolicy, RoundRobinPolicy, DowngradingConsistencyRetryPolicy, SimpleConvictionPolicy
|
27 | 29 | from cassandra.query import SimpleStatement, named_tuple_factory, tuple_factory
|
|
31 | 33 |
|
32 | 34 | log = logging.getLogger(__name__)
|
33 | 35 |
|
| 36 | + |
34 | 37 | class ExceptionTypeTest(unittest.TestCase):
|
35 | 38 |
|
36 | 39 | def test_exception_types(self):
|
@@ -85,6 +88,12 @@ def test_exception_types(self):
|
85 | 88 | self.assertTrue(issubclass(UnsupportedOperation, DriverException))
|
86 | 89 |
|
87 | 90 |
|
| 91 | +class MockOrderedPolicy(RoundRobinPolicy): |
| 92 | + all_hosts = set() |
| 93 | + |
| 94 | + def make_query_plan(self, working_keyspace=None, query=None): |
| 95 | + return sorted(self.all_hosts, key=lambda x: x.endpoint.ssl_options['server_hostname']) |
| 96 | + |
88 | 97 | class ClusterTest(unittest.TestCase):
|
89 | 98 |
|
90 | 99 | def test_tuple_for_contact_points(self):
|
@@ -119,6 +128,26 @@ def test_requests_in_flight_threshold(self):
|
119 | 128 | for n in (0, mn, 128):
|
120 | 129 | self.assertRaises(ValueError, c.set_max_requests_per_connection, d, n)
|
121 | 130 |
|
| 131 | + # Validate that at least the default LBP can create a query plan with end points that resolve |
| 132 | + # to different addresses initially. This may not be exactly how things play out in practice |
| 133 | + # (the control connection will muck with this even if nothing else does) but it should be |
| 134 | + # a pretty good approximation. |
| 135 | + def test_query_plan_for_sni_contains_unique_addresses(self): |
| 136 | + node_cnt = 5 |
| 137 | + def _mocked_proxy_dns_resolution(self): |
| 138 | + return [(socket.AF_UNIX, socket.SOCK_STREAM, 0, None, ('127.0.0.%s' % (i,), 9042)) for i in range(node_cnt)] |
| 139 | + |
| 140 | + c = Cluster() |
| 141 | + lbp = c.load_balancing_policy |
| 142 | + lbp.local_dc = "dc1" |
| 143 | + factory = SniEndPointFactory("proxy.foo.bar", 9042) |
| 144 | + for host in (Host(factory.create({"host_id": uuid.uuid4().hex, "dc": "dc1"}), SimpleConvictionPolicy) for _ in range(node_cnt)): |
| 145 | + lbp.on_up(host) |
| 146 | + with patch.object(SniEndPoint, '_resolve_proxy_addresses', _mocked_proxy_dns_resolution): |
| 147 | + addrs = [host.endpoint.resolve() for host in lbp.make_query_plan()] |
| 148 | + # single SNI endpoint should be resolved to multiple unique IP addresses |
| 149 | + self.assertEqual(len(addrs), len(set(addrs))) |
| 150 | + |
122 | 151 |
|
123 | 152 | class SchedulerTest(unittest.TestCase):
|
124 | 153 | # TODO: this suite could be expanded; for now just adding a test covering a ticket
|
|
0 commit comments