Skip to content
This repository was archived by the owner on Jan 13, 2021. It is now read-only.

Commit 4a1ebd6

Browse files
Use a custom exception.
Hoisted global _context declaration. Added test cases.
1 parent a9dd05d commit 4a1ebd6

File tree

3 files changed

+56
-4
lines changed

3 files changed

+56
-4
lines changed

hyper/common/exceptions.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,10 @@ def __init__(self, negotiated, sock):
6464
super(HTTPUpgrade, self).__init__()
6565
self.negotiated = negotiated
6666
self.sock = sock
67+
68+
69+
class MissingCertFile(Exception):
70+
"""
71+
The certificate file could not be found.
72+
"""
73+
pass

hyper/tls.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
Contains the TLS/SSL logic for use in hyper.
77
"""
88
import os.path as path
9-
9+
from common.exceptions import MissingCertFile
1010
from .compat import ignore_missing, ssl
1111

1212

@@ -29,12 +29,13 @@ def wrap_socket(sock, server_hostname, ssl_context=None, force_proto=None):
2929
A vastly simplified SSL wrapping function. We'll probably extend this to
3030
do more things later.
3131
"""
32+
33+
global _context
34+
3235
if ssl_context:
3336
# if an SSLContext is provided then use it instead of default context
3437
_ssl_context = ssl_context
3538
else:
36-
global _context
37-
3839
# create the singleton SSLContext we use
3940
if _context is None: # pragma: no cover
4041
_context = init_context()
@@ -102,7 +103,7 @@ def init_context(cert_path=None, cert=None, cert_password=None):
102103
"ensure the default cert.pem file is included in the " +
103104
"distribution or provide a custom certificate when " +
104105
"creating the connection.")
105-
raise Exception(errMsg)
106+
raise MissingCertFile(errMsg)
106107

107108
context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
108109
context.set_default_verify_paths()

test/test_SSLContext.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
CLIENT_CERT_FILE = os.path.join(TEST_CERTS_DIR, 'client.crt')
1515
CLIENT_KEY_FILE = os.path.join(TEST_CERTS_DIR, 'client.key')
1616
CLIENT_PEM_FILE = os.path.join(TEST_CERTS_DIR, 'nopassword.pem')
17+
MISSING_PEM_FILE = os.path.join(TEST_CERTS_DIR, 'missing.pem')
1718

1819

1920
class TestSSLContext(object):
@@ -60,3 +61,46 @@ def test_client_certificates(self):
6061
cert=(CLIENT_CERT_FILE, CLIENT_KEY_FILE),
6162
cert_password=b'abc123')
6263
hyper.tls.init_context(cert=CLIENT_PEM_FILE)
64+
65+
def test_HTTPConnection_with_missing_certs(self):
66+
# Clear any prevously created global context
67+
hyper.tls._context = None
68+
backup_cert_loc = hyper.tls.cert_loc
69+
hyper.tls.cert_loc = MISSING_PEM_FILE
70+
71+
succeeded = False
72+
threwExpectedException = False
73+
try:
74+
HTTPConnection('http2bin.org', 443)
75+
succeeded = True
76+
except hyper.common.exceptions.MissingCertFile:
77+
threwExpectedException = True
78+
except:
79+
pass
80+
81+
hyper.tls.cert_loc = backup_cert_loc
82+
83+
assert not succeeded
84+
assert threwExpectedException
85+
86+
def test_HTTPConnection_with_missing_certs_and_custom_context(self):
87+
# Clear any prevously created global context
88+
hyper.tls._context = None
89+
backup_cert_loc = hyper.tls.cert_loc
90+
hyper.tls.cert_loc = MISSING_PEM_FILE
91+
92+
context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
93+
context.set_default_verify_paths()
94+
context.verify_mode = ssl.CERT_REQUIRED
95+
context.check_hostname = True
96+
context.set_npn_protocols(['h2', 'h2-15'])
97+
context.options |= ssl.OP_NO_COMPRESSION
98+
99+
conn = HTTPConnection('http2bin.org', 443, ssl_context=context)
100+
101+
hyper.tls.cert_loc = backup_cert_loc
102+
103+
assert conn.ssl_context.check_hostname
104+
assert conn.ssl_context.verify_mode == ssl.CERT_REQUIRED
105+
assert conn.ssl_context.options & ssl.OP_NO_COMPRESSION != 0
106+

0 commit comments

Comments
 (0)