16
16
( https://twistedmatrix.com ).
17
17
"""
18
18
import atexit
19
- from functools import partial
20
19
import logging
21
- from threading import Thread , Lock
22
20
import time
23
- from twisted .internet import reactor , protocol
21
+ from functools import partial
22
+ from threading import Thread , Lock
24
23
import weakref
25
24
26
- from cassandra .connection import Connection , ConnectionShutdown , Timer , TimerManager
25
+ from twisted .internet import reactor , protocol
26
+ from twisted .internet .endpoints import connectProtocol , TCP4ClientEndpoint , SSL4ClientEndpoint
27
+ from twisted .internet .interfaces import IOpenSSLClientConnectionCreator
28
+ from twisted .python .failure import Failure
29
+ from zope .interface import implementer
27
30
31
+ from cassandra .connection import Connection , ConnectionShutdown , Timer , TimerManager
28
32
33
+ try :
34
+ from OpenSSL import SSL
35
+ _HAS_SSL = True
36
+ except ImportError as e :
37
+ _HAS_SSL = False
38
+ import_exception = e
29
39
log = logging .getLogger (__name__ )
30
40
31
41
@@ -42,8 +52,8 @@ class TwistedConnectionProtocol(protocol.Protocol):
42
52
made events.
43
53
"""
44
54
45
- def __init__ (self ):
46
- self .connection = None
55
+ def __init__ (self , connection ):
56
+ self .connection = connection
47
57
48
58
def dataReceived (self , data ):
49
59
"""
@@ -55,64 +65,21 @@ def dataReceived(self, data):
55
65
"""
56
66
self .connection ._iobuf .write (data )
57
67
self .connection .handle_read ()
68
+
58
69
def connectionMade (self ):
59
70
"""
60
71
Callback function that is called when a connection has succeeded.
61
72
62
73
Reaches back to the Connection object and confirms that the connection
63
74
is ready.
64
75
"""
65
- try :
66
- # Non SSL connection
67
- self .connection = self .transport .connector .factory .conn
68
- except AttributeError :
69
- # SSL connection
70
- self .connection = self .transport .connector .factory .wrappedFactory .conn
71
-
72
76
self .connection .client_connection_made (self .transport )
73
77
74
78
def connectionLost (self , reason ):
75
79
# reason is a Failure instance
76
80
self .connection .defunct (reason .value )
77
81
78
82
79
- class TwistedConnectionClientFactory (protocol .ClientFactory ):
80
-
81
- def __init__ (self , connection ):
82
- # ClientFactory does not define __init__() in parent classes
83
- # and does not inherit from object.
84
- self .conn = connection
85
-
86
- def buildProtocol (self , addr ):
87
- """
88
- Twisted function that defines which kind of protocol to use
89
- in the ClientFactory.
90
- """
91
- return TwistedConnectionProtocol ()
92
-
93
- def clientConnectionFailed (self , connector , reason ):
94
- """
95
- Overridden twisted callback which is called when the
96
- connection attempt fails.
97
- """
98
- log .debug ("Connect failed: %s" , reason )
99
- self .conn .defunct (reason .value )
100
-
101
- def clientConnectionLost (self , connector , reason ):
102
- """
103
- Overridden twisted callback which is called when the
104
- connection goes away (cleanly or otherwise).
105
-
106
- It should be safe to call defunct() here instead of just close, because
107
- we can assume that if the connection was closed cleanly, there are no
108
- requests to error out. If this assumption turns out to be false, we
109
- can call close() instead of defunct() when "reason" is an appropriate
110
- type.
111
- """
112
- log .debug ("Connect lost: %s" , reason )
113
- self .conn .defunct (reason .value )
114
-
115
-
116
83
class TwistedLoop (object ):
117
84
118
85
_lock = None
@@ -166,47 +133,46 @@ def _on_loop_timer(self):
166
133
self ._schedule_timeout (self ._timers .next_timeout )
167
134
168
135
169
- try :
170
- from twisted .internet import ssl
171
- import OpenSSL .crypto
172
- from OpenSSL .crypto import load_certificate , FILETYPE_PEM
173
-
174
- class _SSLContextFactory (ssl .ClientContextFactory ):
175
- def __init__ (self , ssl_options , check_hostname , host ):
176
- self .ssl_options = ssl_options
177
- self .check_hostname = check_hostname
178
- self .host = host
179
-
180
- def getContext (self ):
181
- # This version has to be OpenSSL.SSL.DESIRED_VERSION
182
- # instead of ssl.DESIRED_VERSION as in other loops
183
- self .method = self .ssl_options ["ssl_version" ]
184
- context = ssl .ClientContextFactory .getContext (self )
136
+ @implementer (IOpenSSLClientConnectionCreator )
137
+ class SSLCreator (object ):
138
+ def __init__ (self , host , ssl_context , ssl_options , check_hostname , timeout ):
139
+ self .host = host
140
+ self .ssl_options = ssl_options
141
+ self .check_hostname = check_hostname
142
+ self .timeout = timeout
143
+
144
+ if ssl_context :
145
+ self .context = ssl_context
146
+ else :
147
+ self .context = SSL .Context (SSL .TLSv1_METHOD )
185
148
if "certfile" in self .ssl_options :
186
- context .use_certificate_file (self .ssl_options ["certfile" ])
149
+ self . context .use_certificate_file (self .ssl_options ["certfile" ])
187
150
if "keyfile" in self .ssl_options :
188
- context .use_privatekey_file (self .ssl_options ["keyfile" ])
151
+ self . context .use_privatekey_file (self .ssl_options ["keyfile" ])
189
152
if "ca_certs" in self .ssl_options :
190
- x509 = load_certificate (FILETYPE_PEM , open (self .ssl_options ["ca_certs" ]).read ())
191
- store = context .get_cert_store ()
192
- store .add_cert (x509 )
153
+ self .context .load_verify_locations (self .ssl_options ["ca_certs" ])
193
154
if "cert_reqs" in self .ssl_options :
194
- # This expects OpenSSL.SSL.VERIFY_NONE/OpenSSL.SSL.VERIFY_PEER
195
- # or OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT
196
- context .set_verify (self .ssl_options ["cert_reqs" ],
197
- callback = self .verify_callback )
198
- return context
199
-
200
- def verify_callback (self , connection , x509 , errnum , errdepth , ok ):
201
- if ok :
202
- if self .check_hostname and self .host != x509 .get_subject ().commonName :
203
- return False
204
- return ok
155
+ self .context .set_verify (
156
+ self .ssl_options ["cert_reqs" ],
157
+ callback = self .verify_callback
158
+ )
159
+ self .context .set_info_callback (self .info_callback )
205
160
206
- _HAS_SSL = True
161
+ def verify_callback (self , connection , x509 , errnum , errdepth , ok ):
162
+ return ok
207
163
208
- except ImportError as e :
209
- _HAS_SSL = False
164
+ def info_callback (self , connection , where , ret ):
165
+ if where & SSL .SSL_CB_HANDSHAKE_DONE :
166
+ if self .check_hostname and self .host != connection .get_peer_certificate ().get_subject ().commonName :
167
+ transport = connection .get_app_data ()
168
+ transport .failVerification (Failure (Exception ("Hostname verification failed" )))
169
+
170
+ def clientConnectionForTLS (self , tlsProtocol ):
171
+ connection = SSL .Connection (self .context , None )
172
+ connection .set_app_data (tlsProtocol )
173
+ if self .ssl_options and "server_hostname" in self .ssl_options :
174
+ connection .set_tlsext_host_name (self .ssl_options ['server_hostname' ].encode ('ascii' ))
175
+ return connection
210
176
211
177
212
178
class TwistedConnection (Connection ):
@@ -246,29 +212,48 @@ def __init__(self, *args, **kwargs):
246
212
reactor .callFromThread (self .add_connection )
247
213
self ._loop .maybe_start ()
248
214
249
- def add_connection (self ):
250
- """
251
- Convenience function to connect and store the resulting
252
- connector.
253
- """
215
+ def _check_pyopenssl (self ):
254
216
if self .ssl_options :
255
-
256
217
if not _HAS_SSL :
257
218
raise ImportError (
258
- str (e ) +
219
+ str (import_exception ) +
259
220
', pyOpenSSL must be installed to enable SSL support with the Twisted event loop'
260
221
)
261
222
262
- self .connector = reactor .connectSSL (
263
- host = self .endpoint .address , port = self .port ,
264
- factory = TwistedConnectionClientFactory (self ),
265
- contextFactory = _SSLContextFactory (self .ssl_options , self ._check_hostname , self .endpoint .address ),
266
- timeout = self .connect_timeout )
223
+ def add_connection (self ):
224
+ """
225
+ Convenience function to connect and store the resulting
226
+ connector.
227
+ """
228
+ host , port = self .endpoint .resolve ()
229
+ if self .ssl_context or self .ssl_options :
230
+ # Can't use optionsForClientTLS here because it *forces* hostname verification.
231
+ # Cool they enforce strong security, but we have to be able to turn it off
232
+ self ._check_pyopenssl ()
233
+
234
+ ssl_options = SSLCreator (
235
+ self .endpoint .address ,
236
+ self .ssl_context if self .ssl_context else None ,
237
+ self .ssl_options ,
238
+ self ._check_hostname ,
239
+ self .connect_timeout ,
240
+ )
241
+
242
+ point = SSL4ClientEndpoint (
243
+ reactor ,
244
+ host ,
245
+ port ,
246
+ sslContextFactory = ssl_options ,
247
+ timeout = self .connect_timeout ,
248
+ )
267
249
else :
268
- self .connector = reactor .connectTCP (
269
- host = self .endpoint .address , port = self .port ,
270
- factory = TwistedConnectionClientFactory (self ),
271
- timeout = self .connect_timeout )
250
+ point = TCP4ClientEndpoint (
251
+ reactor ,
252
+ host ,
253
+ port ,
254
+ timeout = self .connect_timeout
255
+ )
256
+ connectProtocol (point , TwistedConnectionProtocol (self ))
272
257
273
258
def client_connection_made (self , transport ):
274
259
"""
@@ -290,7 +275,7 @@ def close(self):
290
275
self .is_closed = True
291
276
292
277
log .debug ("Closing connection (%s) to %s" , id (self ), self .endpoint )
293
- reactor .callFromThread (self .connector .disconnect )
278
+ reactor .callFromThread (self .transport . connector .disconnect )
294
279
log .debug ("Closed socket to %s" , self .endpoint )
295
280
296
281
if not self .is_defunct :
0 commit comments