Skip to content

Commit cded2c1

Browse files
committed
[PYTHON-1163] Twisted cloud support
1 parent 14ba2d1 commit cded2c1

File tree

5 files changed

+55
-26
lines changed

5 files changed

+55
-26
lines changed

cassandra/cluster.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@
3636

3737
import weakref
3838
from weakref import WeakValueDictionary
39+
40+
try:
41+
from cassandra.io.twistedreactor import TwistedConnection
42+
except ImportError:
43+
TwistedConnection = None
44+
3945
try:
4046
from weakref import WeakSet
4147
except ImportError:
@@ -906,13 +912,18 @@ def __init__(self,
906912
907913
Any of the mutable Cluster attributes may be set as keyword arguments to the constructor.
908914
"""
915+
if connection_class is not None:
916+
self.connection_class = connection_class
909917

910918
if cloud is not None:
911919
if contact_points is not _NOT_SET or endpoint_factory or ssl_context or ssl_options:
912920
raise ValueError("contact_points, endpoint_factory, ssl_context, and ssl_options "
913921
"cannot be specified with a cloud configuration")
914922

915-
cloud_config = dscloud.get_cloud_config(cloud)
923+
cloud_config = dscloud.get_cloud_config(
924+
cloud,
925+
create_pyopenssl_context=self.connection_class is TwistedConnection
926+
)
916927

917928
ssl_context = cloud_config.ssl_context
918929
ssl_options = {'check_hostname': True}
@@ -994,9 +1005,6 @@ def __init__(self,
9941005
raise TypeError("address_translator should not be a class, it should be an instance of that class")
9951006
self.address_translator = address_translator
9961007

997-
if connection_class is not None:
998-
self.connection_class = connection_class
999-
10001008
if timestamp_generator is not None:
10011009
if not callable(timestamp_generator):
10021010
raise ValueError("timestamp_generator must be callable")

cassandra/datastax/cloud/__init__.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,34 +75,37 @@ def from_dict(cls, d):
7575
return c
7676

7777

78-
def get_cloud_config(cloud_config):
78+
def get_cloud_config(cloud_config, create_pyopenssl_context=False):
7979
if not _HAS_SSL:
8080
raise DriverException("A Python installation with SSL is required to connect to a cloud cluster.")
8181

8282
if 'secure_connect_bundle' not in cloud_config:
8383
raise ValueError("The cloud config doesn't have a secure_connect_bundle specified.")
8484

8585
try:
86-
config = read_cloud_config_from_zip(cloud_config)
86+
config = read_cloud_config_from_zip(cloud_config, create_pyopenssl_context)
8787
except BadZipFile:
8888
raise ValueError("Unable to open the zip file for the cloud config. Check your secure connect bundle.")
8989

90-
return read_metadata_info(config, cloud_config)
90+
config = read_metadata_info(config, cloud_config)
91+
if create_pyopenssl_context:
92+
config.ssl_context = config.pyopenssl_context
93+
return config
9194

9295

93-
def read_cloud_config_from_zip(cloud_config):
96+
def read_cloud_config_from_zip(cloud_config, create_pyopenssl_context):
9497
secure_bundle = cloud_config['secure_connect_bundle']
9598
with ZipFile(secure_bundle) as zipfile:
9699
base_dir = os.path.dirname(secure_bundle)
97100
tmp_dir = tempfile.mkdtemp(dir=base_dir)
98101
try:
99102
zipfile.extractall(path=tmp_dir)
100-
return parse_cloud_config(os.path.join(tmp_dir, 'config.json'), cloud_config)
103+
return parse_cloud_config(os.path.join(tmp_dir, 'config.json'), cloud_config, create_pyopenssl_context)
101104
finally:
102105
shutil.rmtree(tmp_dir)
103106

104107

105-
def parse_cloud_config(path, cloud_config):
108+
def parse_cloud_config(path, cloud_config, create_pyopenssl_context):
106109
with open(path, 'r') as stream:
107110
data = json.load(stream)
108111

@@ -116,7 +119,11 @@ def parse_cloud_config(path, cloud_config):
116119
ca_cert_location = os.path.join(config_dir, 'ca.crt')
117120
cert_location = os.path.join(config_dir, 'cert')
118121
key_location = os.path.join(config_dir, 'key')
122+
# Regardless of if we create a pyopenssl context, we still need the builtin one
123+
# to connect to the metadata service
119124
config.ssl_context = _ssl_context_from_cert(ca_cert_location, cert_location, key_location)
125+
if create_pyopenssl_context:
126+
config.pyopenssl_context = _pyopenssl_context_from_cert(ca_cert_location, cert_location, key_location)
120127

121128
return config
122129

@@ -165,3 +172,17 @@ def _ssl_context_from_cert(ca_cert_location, cert_location, key_location):
165172
ssl_context.load_cert_chain(certfile=cert_location, keyfile=key_location)
166173

167174
return ssl_context
175+
176+
177+
def _pyopenssl_context_from_cert(ca_cert_location, cert_location, key_location):
178+
try:
179+
from OpenSSL import SSL
180+
except ImportError:
181+
return None
182+
ssl_context = SSL.Context(SSL.TLSv1_METHOD)
183+
ssl_context.set_verify(SSL.VERIFY_PEER, callback=lambda _1, _2, _3, _4, ok: ok)
184+
ssl_context.use_certificate_file(cert_location)
185+
ssl_context.use_privatekey_file(key_location)
186+
ssl_context.load_verify_locations(ca_cert_location)
187+
188+
return ssl_context

cassandra/io/twistedreactor.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from twisted.python.failure import Failure
2929
from zope.interface import implementer
3030

31-
from cassandra.connection import Connection, ConnectionShutdown, Timer, TimerManager
31+
from cassandra.connection import Connection, ConnectionShutdown, Timer, TimerManager, ConnectionException
3232

3333
try:
3434
from OpenSSL import SSL
@@ -77,6 +77,7 @@ def connectionMade(self):
7777

7878
def connectionLost(self, reason):
7979
# reason is a Failure instance
80+
log.debug("Connect lost: %s", reason)
8081
self.connection.defunct(reason.value)
8182

8283

@@ -134,9 +135,9 @@ def _on_loop_timer(self):
134135

135136

136137
@implementer(IOpenSSLClientConnectionCreator)
137-
class SSLCreator(object):
138-
def __init__(self, host, ssl_context, ssl_options, check_hostname, timeout):
139-
self.host = host
138+
class _SSLCreator(object):
139+
def __init__(self, endpoint, ssl_context, ssl_options, check_hostname, timeout):
140+
self.endpoint = endpoint
140141
self.ssl_options = ssl_options
141142
self.check_hostname = check_hostname
142143
self.timeout = timeout
@@ -163,9 +164,9 @@ def verify_callback(self, connection, x509, errnum, errdepth, ok):
163164

164165
def info_callback(self, connection, where, ret):
165166
if where & SSL.SSL_CB_HANDSHAKE_DONE:
166-
if self.check_hostname and self.host != connection.get_peer_certificate().get_subject().commonName:
167+
if self.check_hostname and self.endpoint.address != connection.get_peer_certificate().get_subject().commonName:
167168
transport = connection.get_app_data()
168-
transport.failVerification(Failure(Exception("Hostname verification failed")))
169+
transport.failVerification(Failure(ConnectionException("Hostname verification failed", self.endpoint)))
169170

170171
def clientConnectionForTLS(self, tlsProtocol):
171172
connection = SSL.Connection(self.context, None)
@@ -213,7 +214,7 @@ def __init__(self, *args, **kwargs):
213214
self._loop.maybe_start()
214215

215216
def _check_pyopenssl(self):
216-
if self.ssl_options:
217+
if self.ssl_context or self.ssl_options:
217218
if not _HAS_SSL:
218219
raise ImportError(
219220
str(import_exception) +
@@ -231,29 +232,29 @@ def add_connection(self):
231232
# Cool they enforce strong security, but we have to be able to turn it off
232233
self._check_pyopenssl()
233234

234-
ssl_options = SSLCreator(
235-
self.endpoint.address,
235+
ssl_connection_creator = _SSLCreator(
236+
self.endpoint,
236237
self.ssl_context if self.ssl_context else None,
237238
self.ssl_options,
238239
self._check_hostname,
239240
self.connect_timeout,
240241
)
241242

242-
point = SSL4ClientEndpoint(
243+
endpoint = SSL4ClientEndpoint(
243244
reactor,
244245
host,
245246
port,
246-
sslContextFactory=ssl_options,
247+
sslContextFactory=ssl_connection_creator,
247248
timeout=self.connect_timeout,
248249
)
249250
else:
250-
point = TCP4ClientEndpoint(
251+
endpoint = TCP4ClientEndpoint(
251252
reactor,
252253
host,
253254
port,
254255
timeout=self.connect_timeout
255256
)
256-
connectProtocol(point, TwistedConnectionProtocol(self))
257+
connectProtocol(endpoint, TwistedConnectionProtocol(self))
257258

258259
def client_connection_made(self, transport):
259260
"""

test-requirements.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,11 @@ unittest2
77
pytz
88
sure
99
pure-sasl
10-
twisted; python_version >= '3.5'
10+
twisted[tls]; python_version >= '3.5'
1111
twisted[tls]==19.2.1; python_version < '3.5'
1212
gevent>=1.0
1313
eventlet
1414
cython>=0.20,<0.30
1515
packaging
1616
futurist; python_version >= '3.7'
1717
asynctest; python_version > '3.4'
18-
pyopenssl

tests/integration/advanced/cloud/test_cloud.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def test_verify_hostname(self):
107107
with patch('cassandra.datastax.cloud.parse_metadata_info', wraps=self._bad_hostname_metadata):
108108
with self.assertRaises(NoHostAvailable) as e:
109109
self.connect(self.creds)
110-
self.assertIn("hostname", str(e.exception))
110+
self.assertIn("hostname", str(e.exception).lower())
111111

112112
def test_error_when_bundle_doesnt_exist(self):
113113
try:

0 commit comments

Comments
 (0)