16
16
( https://twistedmatrix.com ).
17
17
"""
18
18
import atexit
19
- from functools import partial
20
19
import logging
21
- from threading import Thread , Lock
22
20
import time
23
- from twisted .internet import reactor , protocol
21
+ from functools import partial
22
+ from threading import Thread , Lock
24
23
import weakref
25
24
26
- from cassandra .connection import Connection , ConnectionShutdown , Timer , TimerManager
25
+ from twisted .internet import reactor , protocol
26
+ from twisted .internet .endpoints import connectProtocol , TCP4ClientEndpoint , SSL4ClientEndpoint
27
+ from twisted .internet .interfaces import IOpenSSLClientConnectionCreator
28
+ from twisted .python .failure import Failure
29
+ from zope .interface import implementer
27
30
31
+ from cassandra .connection import Connection , ConnectionShutdown , Timer , TimerManager , ConnectionException
28
32
33
+ try :
34
+ from OpenSSL import SSL
35
+ _HAS_SSL = True
36
+ except ImportError as e :
37
+ _HAS_SSL = False
38
+ import_exception = e
29
39
log = logging .getLogger (__name__ )
30
40
31
41
@@ -42,8 +52,8 @@ class TwistedConnectionProtocol(protocol.Protocol):
42
52
made events.
43
53
"""
44
54
45
- def __init__ (self ):
46
- self .connection = None
55
+ def __init__ (self , connection ):
56
+ self .connection = connection
47
57
48
58
def dataReceived (self , data ):
49
59
"""
@@ -55,62 +65,20 @@ def dataReceived(self, data):
55
65
"""
56
66
self .connection ._iobuf .write (data )
57
67
self .connection .handle_read ()
68
+
58
69
def connectionMade (self ):
59
70
"""
60
71
Callback function that is called when a connection has succeeded.
61
72
62
73
Reaches back to the Connection object and confirms that the connection
63
74
is ready.
64
75
"""
65
- try :
66
- # Non SSL connection
67
- self .connection = self .transport .connector .factory .conn
68
- except AttributeError :
69
- # SSL connection
70
- self .connection = self .transport .connector .factory .wrappedFactory .conn
71
-
72
76
self .connection .client_connection_made (self .transport )
73
77
74
78
def connectionLost (self , reason ):
75
79
# reason is a Failure instance
76
- self .connection .defunct (reason .value )
77
-
78
-
79
- class TwistedConnectionClientFactory (protocol .ClientFactory ):
80
-
81
- def __init__ (self , connection ):
82
- # ClientFactory does not define __init__() in parent classes
83
- # and does not inherit from object.
84
- self .conn = connection
85
-
86
- def buildProtocol (self , addr ):
87
- """
88
- Twisted function that defines which kind of protocol to use
89
- in the ClientFactory.
90
- """
91
- return TwistedConnectionProtocol ()
92
-
93
- def clientConnectionFailed (self , connector , reason ):
94
- """
95
- Overridden twisted callback which is called when the
96
- connection attempt fails.
97
- """
98
- log .debug ("Connect failed: %s" , reason )
99
- self .conn .defunct (reason .value )
100
-
101
- def clientConnectionLost (self , connector , reason ):
102
- """
103
- Overridden twisted callback which is called when the
104
- connection goes away (cleanly or otherwise).
105
-
106
- It should be safe to call defunct() here instead of just close, because
107
- we can assume that if the connection was closed cleanly, there are no
108
- requests to error out. If this assumption turns out to be false, we
109
- can call close() instead of defunct() when "reason" is an appropriate
110
- type.
111
- """
112
80
log .debug ("Connect lost: %s" , reason )
113
- self .conn .defunct (reason .value )
81
+ self .connection .defunct (reason .value )
114
82
115
83
116
84
class TwistedLoop (object ):
@@ -166,47 +134,46 @@ def _on_loop_timer(self):
166
134
self ._schedule_timeout (self ._timers .next_timeout )
167
135
168
136
169
- try :
170
- from twisted .internet import ssl
171
- import OpenSSL .crypto
172
- from OpenSSL .crypto import load_certificate , FILETYPE_PEM
173
-
174
- class _SSLContextFactory (ssl .ClientContextFactory ):
175
- def __init__ (self , ssl_options , check_hostname , host ):
176
- self .ssl_options = ssl_options
177
- self .check_hostname = check_hostname
178
- self .host = host
179
-
180
- def getContext (self ):
181
- # This version has to be OpenSSL.SSL.DESIRED_VERSION
182
- # instead of ssl.DESIRED_VERSION as in other loops
183
- self .method = self .ssl_options ["ssl_version" ]
184
- context = ssl .ClientContextFactory .getContext (self )
137
+ @implementer (IOpenSSLClientConnectionCreator )
138
+ class _SSLCreator (object ):
139
+ def __init__ (self , endpoint , ssl_context , ssl_options , check_hostname , timeout ):
140
+ self .endpoint = endpoint
141
+ self .ssl_options = ssl_options
142
+ self .check_hostname = check_hostname
143
+ self .timeout = timeout
144
+
145
+ if ssl_context :
146
+ self .context = ssl_context
147
+ else :
148
+ self .context = SSL .Context (SSL .TLSv1_METHOD )
185
149
if "certfile" in self .ssl_options :
186
- context .use_certificate_file (self .ssl_options ["certfile" ])
150
+ self . context .use_certificate_file (self .ssl_options ["certfile" ])
187
151
if "keyfile" in self .ssl_options :
188
- context .use_privatekey_file (self .ssl_options ["keyfile" ])
152
+ self . context .use_privatekey_file (self .ssl_options ["keyfile" ])
189
153
if "ca_certs" in self .ssl_options :
190
- x509 = load_certificate (FILETYPE_PEM , open (self .ssl_options ["ca_certs" ]).read ())
191
- store = context .get_cert_store ()
192
- store .add_cert (x509 )
154
+ self .context .load_verify_locations (self .ssl_options ["ca_certs" ])
193
155
if "cert_reqs" in self .ssl_options :
194
- # This expects OpenSSL.SSL.VERIFY_NONE/OpenSSL.SSL.VERIFY_PEER
195
- # or OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT
196
- context .set_verify (self .ssl_options ["cert_reqs" ],
197
- callback = self .verify_callback )
198
- return context
199
-
200
- def verify_callback (self , connection , x509 , errnum , errdepth , ok ):
201
- if ok :
202
- if self .check_hostname and self .host != x509 .get_subject ().commonName :
203
- return False
204
- return ok
156
+ self .context .set_verify (
157
+ self .ssl_options ["cert_reqs" ],
158
+ callback = self .verify_callback
159
+ )
160
+ self .context .set_info_callback (self .info_callback )
205
161
206
- _HAS_SSL = True
162
+ def verify_callback (self , connection , x509 , errnum , errdepth , ok ):
163
+ return ok
207
164
208
- except ImportError as e :
209
- _HAS_SSL = False
165
+ def info_callback (self , connection , where , ret ):
166
+ if where & SSL .SSL_CB_HANDSHAKE_DONE :
167
+ if self .check_hostname and self .endpoint .address != connection .get_peer_certificate ().get_subject ().commonName :
168
+ transport = connection .get_app_data ()
169
+ transport .failVerification (Failure (ConnectionException ("Hostname verification failed" , self .endpoint )))
170
+
171
+ def clientConnectionForTLS (self , tlsProtocol ):
172
+ connection = SSL .Connection (self .context , None )
173
+ connection .set_app_data (tlsProtocol )
174
+ if self .ssl_options and "server_hostname" in self .ssl_options :
175
+ connection .set_tlsext_host_name (self .ssl_options ['server_hostname' ].encode ('ascii' ))
176
+ return connection
210
177
211
178
212
179
class TwistedConnection (Connection ):
@@ -246,29 +213,48 @@ def __init__(self, *args, **kwargs):
246
213
reactor .callFromThread (self .add_connection )
247
214
self ._loop .maybe_start ()
248
215
249
- def add_connection (self ):
250
- """
251
- Convenience function to connect and store the resulting
252
- connector.
253
- """
254
- if self .ssl_options :
255
-
216
+ def _check_pyopenssl (self ):
217
+ if self .ssl_context or self .ssl_options :
256
218
if not _HAS_SSL :
257
219
raise ImportError (
258
- str (e ) +
220
+ str (import_exception ) +
259
221
', pyOpenSSL must be installed to enable SSL support with the Twisted event loop'
260
222
)
261
223
262
- self .connector = reactor .connectSSL (
263
- host = self .endpoint .address , port = self .port ,
264
- factory = TwistedConnectionClientFactory (self ),
265
- contextFactory = _SSLContextFactory (self .ssl_options , self ._check_hostname , self .endpoint .address ),
266
- timeout = self .connect_timeout )
224
+ def add_connection (self ):
225
+ """
226
+ Convenience function to connect and store the resulting
227
+ connector.
228
+ """
229
+ host , port = self .endpoint .resolve ()
230
+ if self .ssl_context or self .ssl_options :
231
+ # Can't use optionsForClientTLS here because it *forces* hostname verification.
232
+ # Cool they enforce strong security, but we have to be able to turn it off
233
+ self ._check_pyopenssl ()
234
+
235
+ ssl_connection_creator = _SSLCreator (
236
+ self .endpoint ,
237
+ self .ssl_context if self .ssl_context else None ,
238
+ self .ssl_options ,
239
+ self ._check_hostname ,
240
+ self .connect_timeout ,
241
+ )
242
+
243
+ endpoint = SSL4ClientEndpoint (
244
+ reactor ,
245
+ host ,
246
+ port ,
247
+ sslContextFactory = ssl_connection_creator ,
248
+ timeout = self .connect_timeout ,
249
+ )
267
250
else :
268
- self .connector = reactor .connectTCP (
269
- host = self .endpoint .address , port = self .port ,
270
- factory = TwistedConnectionClientFactory (self ),
271
- timeout = self .connect_timeout )
251
+ endpoint = TCP4ClientEndpoint (
252
+ reactor ,
253
+ host ,
254
+ port ,
255
+ timeout = self .connect_timeout
256
+ )
257
+ connectProtocol (endpoint , TwistedConnectionProtocol (self ))
272
258
273
259
def client_connection_made (self , transport ):
274
260
"""
@@ -290,7 +276,7 @@ def close(self):
290
276
self .is_closed = True
291
277
292
278
log .debug ("Closing connection (%s) to %s" , id (self ), self .endpoint )
293
- reactor .callFromThread (self .connector .disconnect )
279
+ reactor .callFromThread (self .transport . connector .disconnect )
294
280
log .debug ("Closed socket to %s" , self .endpoint )
295
281
296
282
if not self .is_defunct :
0 commit comments