19
19
import re
20
20
21
21
from OpenSSL import SSL as ossl
22
+ from service_identity .pyopenssl import verify_hostname as _verify
22
23
23
24
CERT_NONE = ossl .VERIFY_NONE
24
25
CERT_REQUIRED = ossl .VERIFY_PEER | ossl .VERIFY_FAIL_IF_NO_PEER_CERT
38
39
for bit in [31 ] + list (range (10 )): # TODO figure out the names of these other flags
39
40
OP_ALL |= 1 << bit
40
41
41
- HAS_NPN = False # TODO
42
+ HAS_NPN = True
42
43
43
44
def _proxy (method ):
44
45
return lambda self , * args , ** kwargs : getattr (self ._conn , method )(* args , ** kwargs )
@@ -51,102 +52,11 @@ class CertificateError(SSLError):
51
52
pass
52
53
53
54
54
- # lifted from the Python 3.4 stdlib
55
- def _dnsname_match (dn , hostname , max_wildcards = 1 ):
55
+ def verify_hostname (ssl_sock , server_hostname ):
56
56
"""
57
- Matching according to RFC 6125, section 6.4.3.
58
-
59
- See http://tools.ietf.org/html/rfc6125#section-6.4.3
60
- """
61
- pats = []
62
- if not dn :
63
- return False
64
-
65
- parts = dn .split (r'.' )
66
- leftmost = parts [0 ]
67
- remainder = parts [1 :]
68
-
69
- wildcards = leftmost .count ('*' )
70
- if wildcards > max_wildcards :
71
- # Issue #17980: avoid denials of service by refusing more
72
- # than one wildcard per fragment. A survery of established
73
- # policy among SSL implementations showed it to be a
74
- # reasonable choice.
75
- raise CertificateError (
76
- "too many wildcards in certificate DNS name: " + repr (dn ))
77
-
78
- # speed up common case w/o wildcards
79
- if not wildcards :
80
- return dn .lower () == hostname .lower ()
81
-
82
- # RFC 6125, section 6.4.3, subitem 1.
83
- # The client SHOULD NOT attempt to match a presented identifier in which
84
- # the wildcard character comprises a label other than the left-most label.
85
- if leftmost == '*' :
86
- # When '*' is a fragment by itself, it matches a non-empty dotless
87
- # fragment.
88
- pats .append ('[^.]+' )
89
- elif leftmost .startswith ('xn--' ) or hostname .startswith ('xn--' ):
90
- # RFC 6125, section 6.4.3, subitem 3.
91
- # The client SHOULD NOT attempt to match a presented identifier
92
- # where the wildcard character is embedded within an A-label or
93
- # U-label of an internationalized domain name.
94
- pats .append (re .escape (leftmost ))
95
- else :
96
- # Otherwise, '*' matches any dotless string, e.g. www*
97
- pats .append (re .escape (leftmost ).replace (r'\*' , '[^.]*' ))
98
-
99
- # add the remaining fragments, ignore any wildcards
100
- for frag in remainder :
101
- pats .append (re .escape (frag ))
102
-
103
- pat = re .compile (r'\A' + r'\.' .join (pats ) + r'\Z' , re .IGNORECASE )
104
- return pat .match (hostname )
105
-
106
-
107
- # lifted from the Python 3.4 stdlib
108
- def match_hostname (cert , hostname ):
109
- """
110
- Verify that ``cert`` (in decoded format as returned by
111
- ``SSLSocket.getpeercert())`` matches the ``hostname``. RFC 2818 and RFC
112
- 6125 rules are followed, but IP addresses are not accepted for ``hostname``.
113
-
114
- ``CertificateError`` is raised on failure. On success, the function returns
115
- nothing.
57
+ A method nearly compatible with the stdlib's match_hostname.
116
58
"""
117
- if not cert :
118
- raise ValueError ("empty or no certificate, match_hostname needs a "
119
- "SSL socket or SSL context with either "
120
- "CERT_OPTIONAL or CERT_REQUIRED" )
121
- dnsnames = []
122
- san = cert .get ('subjectAltName' , ())
123
- for key , value in san :
124
- if key == 'DNS' :
125
- if _dnsname_match (value , hostname ):
126
- return
127
- dnsnames .append (value )
128
- if not dnsnames :
129
- # The subject is only checked when there is no dNSName entry
130
- # in subjectAltName
131
- for sub in cert .get ('subject' , ()):
132
- for key , value in sub :
133
- # XXX according to RFC 2818, the most specific Common Name
134
- # must be used.
135
- if key == 'commonName' :
136
- if _dnsname_match (value , hostname ):
137
- return
138
- dnsnames .append (value )
139
- if len (dnsnames ) > 1 :
140
- raise CertificateError ("hostname %r "
141
- "doesn't match either of %s"
142
- % (hostname , ', ' .join (map (repr , dnsnames ))))
143
- elif len (dnsnames ) == 1 :
144
- raise CertificateError ("hostname %r "
145
- "doesn't match %r"
146
- % (hostname , dnsnames [0 ]))
147
- else :
148
- raise CertificateError ("no appropriate commonName or "
149
- "subjectAltName fields were found" )
59
+ return _verify (ssl_sock ._conn , server_hostname )
150
60
151
61
152
62
class SSLSocket (object ):
@@ -165,6 +75,7 @@ def __init__(self, conn, server_side, do_handshake_on_connect,
165
75
else :
166
76
if server_hostname :
167
77
self ._conn .set_tlsext_host_name (server_hostname .encode ('utf-8' ))
78
+ self ._server_hostname = server_hostname
168
79
self ._conn .set_connect_state () # FIXME does this override do_handshake_on_connect=False?
169
80
170
81
if self .connected and self ._do_handshake_on_connect :
@@ -211,7 +122,7 @@ def connect(self, address):
211
122
def do_handshake (self ):
212
123
self ._safe_ssl_call (False , self ._conn .do_handshake )
213
124
if self ._check_hostname :
214
- match_hostname (self . getpeercert () , self ._conn . get_servername (). decode ( 'utf-8' ) )
125
+ verify_hostname (self , self ._server_hostname )
215
126
216
127
def recv (self , bufsize , flags = None ):
217
128
return self ._safe_ssl_call (self ._suppress_ragged_eofs , self ._conn .recv ,
@@ -232,7 +143,11 @@ def send(self, data, flags=None):
232
143
return self ._safe_ssl_call (False , self ._conn .send , data , flags )
233
144
234
145
def selected_npn_protocol (self ):
235
- raise NotImplementedError ()
146
+ proto = self ._conn .get_next_proto_negotiated ()
147
+ if isinstance (proto , bytes ):
148
+ proto = proto .decode ('ascii' )
149
+
150
+ return proto if proto else None
236
151
237
152
def getpeercert (self ):
238
153
def resolve_alias (alias ):
@@ -276,6 +191,7 @@ def __init__(self, protocol):
276
191
self ._ctx = ossl .Context (protocol )
277
192
self .options = OP_ALL
278
193
self .check_hostname = False
194
+ self .npn_protos = []
279
195
280
196
@property
281
197
def options (self ):
@@ -315,8 +231,20 @@ def load_cert_chain(self, certfile, keyfile=None, password=None):
315
231
self ._ctx .use_privatekey_file (keyfile or certfile )
316
232
317
233
def set_npn_protocols (self , protocols ):
318
- # TODO
319
- raise NotImplementedError ()
234
+ self .protocols = list (map (lambda x :x .encode ('ascii' ), protocols ))
235
+
236
+ def cb (conn , protos ):
237
+ # Detect the overlapping set of protocols.
238
+ overlap = set (protos ) & set (self .protocols )
239
+
240
+ # Select the option that comes last in the list in the overlap.
241
+ for p in self .protocols :
242
+ if p in overlap :
243
+ return p
244
+ else :
245
+ return b''
246
+
247
+ self ._ctx .set_npn_select_callback (cb )
320
248
321
249
def wrap_socket (self , sock , server_side = False , do_handshake_on_connect = True ,
322
250
suppress_ragged_eofs = True , server_hostname = None ):
0 commit comments