Skip to content

Commit 14ba2d1

Browse files
committed
[PYTHON-1161] Support Twisted context.
Refactored Twisted code to use newer Endpoint APIs. This allowed us to remove ClientFactory implementation completely. Also used IOpenSSLClientConnectionCreator interface instead of implementing a ContextFactory. This allows us to use our own context and pyOpenSSL connection instead of relying on twisted to create it for us.
1 parent 78f634c commit 14ba2d1

File tree

4 files changed

+218
-199
lines changed

4 files changed

+218
-199
lines changed

cassandra/io/twistedreactor.py

Lines changed: 87 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,26 @@
1616
( https://twistedmatrix.com ).
1717
"""
1818
import atexit
19-
from functools import partial
2019
import logging
21-
from threading import Thread, Lock
2220
import time
23-
from twisted.internet import reactor, protocol
21+
from functools import partial
22+
from threading import Thread, Lock
2423
import weakref
2524

26-
from cassandra.connection import Connection, ConnectionShutdown, Timer, TimerManager
25+
from twisted.internet import reactor, protocol
26+
from twisted.internet.endpoints import connectProtocol, TCP4ClientEndpoint, SSL4ClientEndpoint
27+
from twisted.internet.interfaces import IOpenSSLClientConnectionCreator
28+
from twisted.python.failure import Failure
29+
from zope.interface import implementer
2730

31+
from cassandra.connection import Connection, ConnectionShutdown, Timer, TimerManager
2832

33+
try:
34+
from OpenSSL import SSL
35+
_HAS_SSL = True
36+
except ImportError as e:
37+
_HAS_SSL = False
38+
import_exception = e
2939
log = logging.getLogger(__name__)
3040

3141

@@ -42,8 +52,8 @@ class TwistedConnectionProtocol(protocol.Protocol):
4252
made events.
4353
"""
4454

45-
def __init__(self):
46-
self.connection = None
55+
def __init__(self, connection):
56+
self.connection = connection
4757

4858
def dataReceived(self, data):
4959
"""
@@ -55,64 +65,21 @@ def dataReceived(self, data):
5565
"""
5666
self.connection._iobuf.write(data)
5767
self.connection.handle_read()
68+
5869
def connectionMade(self):
5970
"""
6071
Callback function that is called when a connection has succeeded.
6172
6273
Reaches back to the Connection object and confirms that the connection
6374
is ready.
6475
"""
65-
try:
66-
# Non SSL connection
67-
self.connection = self.transport.connector.factory.conn
68-
except AttributeError:
69-
# SSL connection
70-
self.connection = self.transport.connector.factory.wrappedFactory.conn
71-
7276
self.connection.client_connection_made(self.transport)
7377

7478
def connectionLost(self, reason):
7579
# reason is a Failure instance
7680
self.connection.defunct(reason.value)
7781

7882

79-
class TwistedConnectionClientFactory(protocol.ClientFactory):
80-
81-
def __init__(self, connection):
82-
# ClientFactory does not define __init__() in parent classes
83-
# and does not inherit from object.
84-
self.conn = connection
85-
86-
def buildProtocol(self, addr):
87-
"""
88-
Twisted function that defines which kind of protocol to use
89-
in the ClientFactory.
90-
"""
91-
return TwistedConnectionProtocol()
92-
93-
def clientConnectionFailed(self, connector, reason):
94-
"""
95-
Overridden twisted callback which is called when the
96-
connection attempt fails.
97-
"""
98-
log.debug("Connect failed: %s", reason)
99-
self.conn.defunct(reason.value)
100-
101-
def clientConnectionLost(self, connector, reason):
102-
"""
103-
Overridden twisted callback which is called when the
104-
connection goes away (cleanly or otherwise).
105-
106-
It should be safe to call defunct() here instead of just close, because
107-
we can assume that if the connection was closed cleanly, there are no
108-
requests to error out. If this assumption turns out to be false, we
109-
can call close() instead of defunct() when "reason" is an appropriate
110-
type.
111-
"""
112-
log.debug("Connect lost: %s", reason)
113-
self.conn.defunct(reason.value)
114-
115-
11683
class TwistedLoop(object):
11784

11885
_lock = None
@@ -166,47 +133,46 @@ def _on_loop_timer(self):
166133
self._schedule_timeout(self._timers.next_timeout)
167134

168135

169-
try:
170-
from twisted.internet import ssl
171-
import OpenSSL.crypto
172-
from OpenSSL.crypto import load_certificate, FILETYPE_PEM
173-
174-
class _SSLContextFactory(ssl.ClientContextFactory):
175-
def __init__(self, ssl_options, check_hostname, host):
176-
self.ssl_options = ssl_options
177-
self.check_hostname = check_hostname
178-
self.host = host
179-
180-
def getContext(self):
181-
# This version has to be OpenSSL.SSL.DESIRED_VERSION
182-
# instead of ssl.DESIRED_VERSION as in other loops
183-
self.method = self.ssl_options["ssl_version"]
184-
context = ssl.ClientContextFactory.getContext(self)
136+
@implementer(IOpenSSLClientConnectionCreator)
137+
class SSLCreator(object):
138+
def __init__(self, host, ssl_context, ssl_options, check_hostname, timeout):
139+
self.host = host
140+
self.ssl_options = ssl_options
141+
self.check_hostname = check_hostname
142+
self.timeout = timeout
143+
144+
if ssl_context:
145+
self.context = ssl_context
146+
else:
147+
self.context = SSL.Context(SSL.TLSv1_METHOD)
185148
if "certfile" in self.ssl_options:
186-
context.use_certificate_file(self.ssl_options["certfile"])
149+
self.context.use_certificate_file(self.ssl_options["certfile"])
187150
if "keyfile" in self.ssl_options:
188-
context.use_privatekey_file(self.ssl_options["keyfile"])
151+
self.context.use_privatekey_file(self.ssl_options["keyfile"])
189152
if "ca_certs" in self.ssl_options:
190-
x509 = load_certificate(FILETYPE_PEM, open(self.ssl_options["ca_certs"]).read())
191-
store = context.get_cert_store()
192-
store.add_cert(x509)
153+
self.context.load_verify_locations(self.ssl_options["ca_certs"])
193154
if "cert_reqs" in self.ssl_options:
194-
# This expects OpenSSL.SSL.VERIFY_NONE/OpenSSL.SSL.VERIFY_PEER
195-
# or OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT
196-
context.set_verify(self.ssl_options["cert_reqs"],
197-
callback=self.verify_callback)
198-
return context
199-
200-
def verify_callback(self, connection, x509, errnum, errdepth, ok):
201-
if ok:
202-
if self.check_hostname and self.host != x509.get_subject().commonName:
203-
return False
204-
return ok
155+
self.context.set_verify(
156+
self.ssl_options["cert_reqs"],
157+
callback=self.verify_callback
158+
)
159+
self.context.set_info_callback(self.info_callback)
205160

206-
_HAS_SSL = True
161+
def verify_callback(self, connection, x509, errnum, errdepth, ok):
162+
return ok
207163

208-
except ImportError as e:
209-
_HAS_SSL = False
164+
def info_callback(self, connection, where, ret):
165+
if where & SSL.SSL_CB_HANDSHAKE_DONE:
166+
if self.check_hostname and self.host != connection.get_peer_certificate().get_subject().commonName:
167+
transport = connection.get_app_data()
168+
transport.failVerification(Failure(Exception("Hostname verification failed")))
169+
170+
def clientConnectionForTLS(self, tlsProtocol):
171+
connection = SSL.Connection(self.context, None)
172+
connection.set_app_data(tlsProtocol)
173+
if self.ssl_options and "server_hostname" in self.ssl_options:
174+
connection.set_tlsext_host_name(self.ssl_options['server_hostname'].encode('ascii'))
175+
return connection
210176

211177

212178
class TwistedConnection(Connection):
@@ -246,29 +212,48 @@ def __init__(self, *args, **kwargs):
246212
reactor.callFromThread(self.add_connection)
247213
self._loop.maybe_start()
248214

249-
def add_connection(self):
250-
"""
251-
Convenience function to connect and store the resulting
252-
connector.
253-
"""
215+
def _check_pyopenssl(self):
254216
if self.ssl_options:
255-
256217
if not _HAS_SSL:
257218
raise ImportError(
258-
str(e) +
219+
str(import_exception) +
259220
', pyOpenSSL must be installed to enable SSL support with the Twisted event loop'
260221
)
261222

262-
self.connector = reactor.connectSSL(
263-
host=self.endpoint.address, port=self.port,
264-
factory=TwistedConnectionClientFactory(self),
265-
contextFactory=_SSLContextFactory(self.ssl_options, self._check_hostname, self.endpoint.address),
266-
timeout=self.connect_timeout)
223+
def add_connection(self):
224+
"""
225+
Convenience function to connect and store the resulting
226+
connector.
227+
"""
228+
host, port = self.endpoint.resolve()
229+
if self.ssl_context or self.ssl_options:
230+
# Can't use optionsForClientTLS here because it *forces* hostname verification.
231+
# Cool they enforce strong security, but we have to be able to turn it off
232+
self._check_pyopenssl()
233+
234+
ssl_options = SSLCreator(
235+
self.endpoint.address,
236+
self.ssl_context if self.ssl_context else None,
237+
self.ssl_options,
238+
self._check_hostname,
239+
self.connect_timeout,
240+
)
241+
242+
point = SSL4ClientEndpoint(
243+
reactor,
244+
host,
245+
port,
246+
sslContextFactory=ssl_options,
247+
timeout=self.connect_timeout,
248+
)
267249
else:
268-
self.connector = reactor.connectTCP(
269-
host=self.endpoint.address, port=self.port,
270-
factory=TwistedConnectionClientFactory(self),
271-
timeout=self.connect_timeout)
250+
point = TCP4ClientEndpoint(
251+
reactor,
252+
host,
253+
port,
254+
timeout=self.connect_timeout
255+
)
256+
connectProtocol(point, TwistedConnectionProtocol(self))
272257

273258
def client_connection_made(self, transport):
274259
"""
@@ -290,7 +275,7 @@ def close(self):
290275
self.is_closed = True
291276

292277
log.debug("Closing connection (%s) to %s", id(self), self.endpoint)
293-
reactor.callFromThread(self.connector.disconnect)
278+
reactor.callFromThread(self.transport.connector.disconnect)
294279
log.debug("Closed socket to %s", self.endpoint)
295280

296281
if not self.is_defunct:

test-requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@ cython>=0.20,<0.30
1515
packaging
1616
futurist; python_version >= '3.7'
1717
asynctest; python_version > '3.4'
18+
pyopenssl

0 commit comments

Comments
 (0)