Skip to content

Commit 4073c0a

Browse files
committed
Merge branch 'master' into oss-next for the twisted cloud work
2 parents 76fdd52 + 9d7dddd commit 4073c0a

File tree

9 files changed

+294
-215
lines changed

9 files changed

+294
-215
lines changed

cassandra/cluster.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,12 @@
3838

3939
import weakref
4040
from weakref import WeakValueDictionary
41+
42+
try:
43+
from cassandra.io.twistedreactor import TwistedConnection
44+
except ImportError:
45+
TwistedConnection = None
46+
4147
try:
4248
from weakref import WeakSet
4349
except ImportError:
@@ -1090,13 +1096,18 @@ def __init__(self,
10901096
10911097
Any of the mutable Cluster attributes may be set as keyword arguments to the constructor.
10921098
"""
1099+
if connection_class is not None:
1100+
self.connection_class = connection_class
10931101

10941102
if cloud is not None:
10951103
if contact_points is not _NOT_SET or endpoint_factory or ssl_context or ssl_options:
10961104
raise ValueError("contact_points, endpoint_factory, ssl_context, and ssl_options "
10971105
"cannot be specified with a cloud configuration")
10981106

1099-
cloud_config = dscloud.get_cloud_config(cloud)
1107+
cloud_config = dscloud.get_cloud_config(
1108+
cloud,
1109+
create_pyopenssl_context=self.connection_class is TwistedConnection
1110+
)
11001111

11011112
ssl_context = cloud_config.ssl_context
11021113
ssl_options = {'check_hostname': True}
@@ -1188,9 +1199,6 @@ def __init__(self,
11881199
raise TypeError("address_translator should not be a class, it should be an instance of that class")
11891200
self.address_translator = address_translator
11901201

1191-
if connection_class is not None:
1192-
self.connection_class = connection_class
1193-
11941202
if timestamp_generator is not None:
11951203
if not callable(timestamp_generator):
11961204
raise ValueError("timestamp_generator must be callable")

cassandra/datastax/cloud/__init__.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,34 +75,37 @@ def from_dict(cls, d):
7575
return c
7676

7777

78-
def get_cloud_config(cloud_config):
78+
def get_cloud_config(cloud_config, create_pyopenssl_context=False):
7979
if not _HAS_SSL:
8080
raise DriverException("A Python installation with SSL is required to connect to a cloud cluster.")
8181

8282
if 'secure_connect_bundle' not in cloud_config:
8383
raise ValueError("The cloud config doesn't have a secure_connect_bundle specified.")
8484

8585
try:
86-
config = read_cloud_config_from_zip(cloud_config)
86+
config = read_cloud_config_from_zip(cloud_config, create_pyopenssl_context)
8787
except BadZipFile:
8888
raise ValueError("Unable to open the zip file for the cloud config. Check your secure connect bundle.")
8989

90-
return read_metadata_info(config, cloud_config)
90+
config = read_metadata_info(config, cloud_config)
91+
if create_pyopenssl_context:
92+
config.ssl_context = config.pyopenssl_context
93+
return config
9194

9295

93-
def read_cloud_config_from_zip(cloud_config):
96+
def read_cloud_config_from_zip(cloud_config, create_pyopenssl_context):
9497
secure_bundle = cloud_config['secure_connect_bundle']
9598
with ZipFile(secure_bundle) as zipfile:
9699
base_dir = os.path.dirname(secure_bundle)
97100
tmp_dir = tempfile.mkdtemp(dir=base_dir)
98101
try:
99102
zipfile.extractall(path=tmp_dir)
100-
return parse_cloud_config(os.path.join(tmp_dir, 'config.json'), cloud_config)
103+
return parse_cloud_config(os.path.join(tmp_dir, 'config.json'), cloud_config, create_pyopenssl_context)
101104
finally:
102105
shutil.rmtree(tmp_dir)
103106

104107

105-
def parse_cloud_config(path, cloud_config):
108+
def parse_cloud_config(path, cloud_config, create_pyopenssl_context):
106109
with open(path, 'r') as stream:
107110
data = json.load(stream)
108111

@@ -116,7 +119,11 @@ def parse_cloud_config(path, cloud_config):
116119
ca_cert_location = os.path.join(config_dir, 'ca.crt')
117120
cert_location = os.path.join(config_dir, 'cert')
118121
key_location = os.path.join(config_dir, 'key')
122+
# Regardless of if we create a pyopenssl context, we still need the builtin one
123+
# to connect to the metadata service
119124
config.ssl_context = _ssl_context_from_cert(ca_cert_location, cert_location, key_location)
125+
if create_pyopenssl_context:
126+
config.pyopenssl_context = _pyopenssl_context_from_cert(ca_cert_location, cert_location, key_location)
120127

121128
return config
122129

@@ -165,3 +172,17 @@ def _ssl_context_from_cert(ca_cert_location, cert_location, key_location):
165172
ssl_context.load_cert_chain(certfile=cert_location, keyfile=key_location)
166173

167174
return ssl_context
175+
176+
177+
def _pyopenssl_context_from_cert(ca_cert_location, cert_location, key_location):
178+
try:
179+
from OpenSSL import SSL
180+
except ImportError:
181+
return None
182+
ssl_context = SSL.Context(SSL.TLSv1_METHOD)
183+
ssl_context.set_verify(SSL.VERIFY_PEER, callback=lambda _1, _2, _3, _4, ok: ok)
184+
ssl_context.use_certificate_file(cert_location)
185+
ssl_context.use_privatekey_file(key_location)
186+
ssl_context.load_verify_locations(ca_cert_location)
187+
188+
return ssl_context

cassandra/io/twistedreactor.py

Lines changed: 89 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,26 @@
1616
( https://twistedmatrix.com ).
1717
"""
1818
import atexit
19-
from functools import partial
2019
import logging
21-
from threading import Thread, Lock
2220
import time
23-
from twisted.internet import reactor, protocol
21+
from functools import partial
22+
from threading import Thread, Lock
2423
import weakref
2524

26-
from cassandra.connection import Connection, ConnectionShutdown, Timer, TimerManager
25+
from twisted.internet import reactor, protocol
26+
from twisted.internet.endpoints import connectProtocol, TCP4ClientEndpoint, SSL4ClientEndpoint
27+
from twisted.internet.interfaces import IOpenSSLClientConnectionCreator
28+
from twisted.python.failure import Failure
29+
from zope.interface import implementer
2730

31+
from cassandra.connection import Connection, ConnectionShutdown, Timer, TimerManager, ConnectionException
2832

33+
try:
34+
from OpenSSL import SSL
35+
_HAS_SSL = True
36+
except ImportError as e:
37+
_HAS_SSL = False
38+
import_exception = e
2939
log = logging.getLogger(__name__)
3040

3141

@@ -42,8 +52,8 @@ class TwistedConnectionProtocol(protocol.Protocol):
4252
made events.
4353
"""
4454

45-
def __init__(self):
46-
self.connection = None
55+
def __init__(self, connection):
56+
self.connection = connection
4757

4858
def dataReceived(self, data):
4959
"""
@@ -55,62 +65,20 @@ def dataReceived(self, data):
5565
"""
5666
self.connection._iobuf.write(data)
5767
self.connection.handle_read()
68+
5869
def connectionMade(self):
5970
"""
6071
Callback function that is called when a connection has succeeded.
6172
6273
Reaches back to the Connection object and confirms that the connection
6374
is ready.
6475
"""
65-
try:
66-
# Non SSL connection
67-
self.connection = self.transport.connector.factory.conn
68-
except AttributeError:
69-
# SSL connection
70-
self.connection = self.transport.connector.factory.wrappedFactory.conn
71-
7276
self.connection.client_connection_made(self.transport)
7377

7478
def connectionLost(self, reason):
7579
# reason is a Failure instance
76-
self.connection.defunct(reason.value)
77-
78-
79-
class TwistedConnectionClientFactory(protocol.ClientFactory):
80-
81-
def __init__(self, connection):
82-
# ClientFactory does not define __init__() in parent classes
83-
# and does not inherit from object.
84-
self.conn = connection
85-
86-
def buildProtocol(self, addr):
87-
"""
88-
Twisted function that defines which kind of protocol to use
89-
in the ClientFactory.
90-
"""
91-
return TwistedConnectionProtocol()
92-
93-
def clientConnectionFailed(self, connector, reason):
94-
"""
95-
Overridden twisted callback which is called when the
96-
connection attempt fails.
97-
"""
98-
log.debug("Connect failed: %s", reason)
99-
self.conn.defunct(reason.value)
100-
101-
def clientConnectionLost(self, connector, reason):
102-
"""
103-
Overridden twisted callback which is called when the
104-
connection goes away (cleanly or otherwise).
105-
106-
It should be safe to call defunct() here instead of just close, because
107-
we can assume that if the connection was closed cleanly, there are no
108-
requests to error out. If this assumption turns out to be false, we
109-
can call close() instead of defunct() when "reason" is an appropriate
110-
type.
111-
"""
11280
log.debug("Connect lost: %s", reason)
113-
self.conn.defunct(reason.value)
81+
self.connection.defunct(reason.value)
11482

11583

11684
class TwistedLoop(object):
@@ -166,47 +134,46 @@ def _on_loop_timer(self):
166134
self._schedule_timeout(self._timers.next_timeout)
167135

168136

169-
try:
170-
from twisted.internet import ssl
171-
import OpenSSL.crypto
172-
from OpenSSL.crypto import load_certificate, FILETYPE_PEM
173-
174-
class _SSLContextFactory(ssl.ClientContextFactory):
175-
def __init__(self, ssl_options, check_hostname, host):
176-
self.ssl_options = ssl_options
177-
self.check_hostname = check_hostname
178-
self.host = host
179-
180-
def getContext(self):
181-
# This version has to be OpenSSL.SSL.DESIRED_VERSION
182-
# instead of ssl.DESIRED_VERSION as in other loops
183-
self.method = self.ssl_options["ssl_version"]
184-
context = ssl.ClientContextFactory.getContext(self)
137+
@implementer(IOpenSSLClientConnectionCreator)
138+
class _SSLCreator(object):
139+
def __init__(self, endpoint, ssl_context, ssl_options, check_hostname, timeout):
140+
self.endpoint = endpoint
141+
self.ssl_options = ssl_options
142+
self.check_hostname = check_hostname
143+
self.timeout = timeout
144+
145+
if ssl_context:
146+
self.context = ssl_context
147+
else:
148+
self.context = SSL.Context(SSL.TLSv1_METHOD)
185149
if "certfile" in self.ssl_options:
186-
context.use_certificate_file(self.ssl_options["certfile"])
150+
self.context.use_certificate_file(self.ssl_options["certfile"])
187151
if "keyfile" in self.ssl_options:
188-
context.use_privatekey_file(self.ssl_options["keyfile"])
152+
self.context.use_privatekey_file(self.ssl_options["keyfile"])
189153
if "ca_certs" in self.ssl_options:
190-
x509 = load_certificate(FILETYPE_PEM, open(self.ssl_options["ca_certs"]).read())
191-
store = context.get_cert_store()
192-
store.add_cert(x509)
154+
self.context.load_verify_locations(self.ssl_options["ca_certs"])
193155
if "cert_reqs" in self.ssl_options:
194-
# This expects OpenSSL.SSL.VERIFY_NONE/OpenSSL.SSL.VERIFY_PEER
195-
# or OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT
196-
context.set_verify(self.ssl_options["cert_reqs"],
197-
callback=self.verify_callback)
198-
return context
199-
200-
def verify_callback(self, connection, x509, errnum, errdepth, ok):
201-
if ok:
202-
if self.check_hostname and self.host != x509.get_subject().commonName:
203-
return False
204-
return ok
156+
self.context.set_verify(
157+
self.ssl_options["cert_reqs"],
158+
callback=self.verify_callback
159+
)
160+
self.context.set_info_callback(self.info_callback)
205161

206-
_HAS_SSL = True
162+
def verify_callback(self, connection, x509, errnum, errdepth, ok):
163+
return ok
207164

208-
except ImportError as e:
209-
_HAS_SSL = False
165+
def info_callback(self, connection, where, ret):
166+
if where & SSL.SSL_CB_HANDSHAKE_DONE:
167+
if self.check_hostname and self.endpoint.address != connection.get_peer_certificate().get_subject().commonName:
168+
transport = connection.get_app_data()
169+
transport.failVerification(Failure(ConnectionException("Hostname verification failed", self.endpoint)))
170+
171+
def clientConnectionForTLS(self, tlsProtocol):
172+
connection = SSL.Connection(self.context, None)
173+
connection.set_app_data(tlsProtocol)
174+
if self.ssl_options and "server_hostname" in self.ssl_options:
175+
connection.set_tlsext_host_name(self.ssl_options['server_hostname'].encode('ascii'))
176+
return connection
210177

211178

212179
class TwistedConnection(Connection):
@@ -246,29 +213,48 @@ def __init__(self, *args, **kwargs):
246213
reactor.callFromThread(self.add_connection)
247214
self._loop.maybe_start()
248215

249-
def add_connection(self):
250-
"""
251-
Convenience function to connect and store the resulting
252-
connector.
253-
"""
254-
if self.ssl_options:
255-
216+
def _check_pyopenssl(self):
217+
if self.ssl_context or self.ssl_options:
256218
if not _HAS_SSL:
257219
raise ImportError(
258-
str(e) +
220+
str(import_exception) +
259221
', pyOpenSSL must be installed to enable SSL support with the Twisted event loop'
260222
)
261223

262-
self.connector = reactor.connectSSL(
263-
host=self.endpoint.address, port=self.port,
264-
factory=TwistedConnectionClientFactory(self),
265-
contextFactory=_SSLContextFactory(self.ssl_options, self._check_hostname, self.endpoint.address),
266-
timeout=self.connect_timeout)
224+
def add_connection(self):
225+
"""
226+
Convenience function to connect and store the resulting
227+
connector.
228+
"""
229+
host, port = self.endpoint.resolve()
230+
if self.ssl_context or self.ssl_options:
231+
# Can't use optionsForClientTLS here because it *forces* hostname verification.
232+
# Cool they enforce strong security, but we have to be able to turn it off
233+
self._check_pyopenssl()
234+
235+
ssl_connection_creator = _SSLCreator(
236+
self.endpoint,
237+
self.ssl_context if self.ssl_context else None,
238+
self.ssl_options,
239+
self._check_hostname,
240+
self.connect_timeout,
241+
)
242+
243+
endpoint = SSL4ClientEndpoint(
244+
reactor,
245+
host,
246+
port,
247+
sslContextFactory=ssl_connection_creator,
248+
timeout=self.connect_timeout,
249+
)
267250
else:
268-
self.connector = reactor.connectTCP(
269-
host=self.endpoint.address, port=self.port,
270-
factory=TwistedConnectionClientFactory(self),
271-
timeout=self.connect_timeout)
251+
endpoint = TCP4ClientEndpoint(
252+
reactor,
253+
host,
254+
port,
255+
timeout=self.connect_timeout
256+
)
257+
connectProtocol(endpoint, TwistedConnectionProtocol(self))
272258

273259
def client_connection_made(self, transport):
274260
"""
@@ -290,7 +276,7 @@ def close(self):
290276
self.is_closed = True
291277

292278
log.debug("Closing connection (%s) to %s", id(self), self.endpoint)
293-
reactor.callFromThread(self.connector.disconnect)
279+
reactor.callFromThread(self.transport.connector.disconnect)
294280
log.debug("Closed socket to %s", self.endpoint)
295281

296282
if not self.is_defunct:

docs/cloud.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,5 +34,5 @@ Limitations
3434

3535
Event loops
3636
^^^^^^^^^^^
37-
Twisted and Evenlet aren't supported yet. These event loops are still using the old way to configure
37+
Evenlet isn't supported yet. Eventlet still uses the old way to configure
3838
SSL (ssl_options), which is not compatible with the secure connect bundle provided by Apollo.

0 commit comments

Comments
 (0)