Skip to content

Commit 1fbe064

Browse files
davidbenalex
authored andcommitted
Make test_ssl pass in an IPv6-only environment (#827)
* Make test_ssl pass in an IPv6-only environment * Review comments * Update tests/test_ssl.py Co-Authored-By: davidben <[email protected]> * Wrap long line with parens.
1 parent 4d57590 commit 1fbe064

File tree

1 file changed

+38
-20
lines changed

1 file changed

+38
-20
lines changed

tests/test_ssl.py

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@
1010
import uuid
1111

1212
from gc import collect, get_referrers
13-
from errno import ECONNREFUSED, EINPROGRESS, EWOULDBLOCK, EPIPE, ESHUTDOWN
13+
from errno import (
14+
EAFNOSUPPORT, ECONNREFUSED, EINPROGRESS, EWOULDBLOCK, EPIPE, ESHUTDOWN)
1415
from sys import platform, getfilesystemencoding
15-
from socket import MSG_PEEK, SHUT_RDWR, error, socket
16+
from socket import AF_INET, AF_INET6, MSG_PEEK, SHUT_RDWR, error, socket
1617
from os import makedirs
1718
from os.path import join
1819
from weakref import ref
@@ -101,6 +102,23 @@
101102
skip_if_py3 = pytest.mark.skipif(PY3, reason="Python 2 only")
102103

103104

105+
def socket_any_family():
106+
try:
107+
return socket(AF_INET)
108+
except error as e:
109+
if e.errno == EAFNOSUPPORT:
110+
return socket(AF_INET6)
111+
raise
112+
113+
114+
def loopback_address(socket):
115+
if socket.family == AF_INET:
116+
return "127.0.0.1"
117+
else:
118+
assert socket.family == AF_INET6
119+
return "::1"
120+
121+
104122
def join_bytes_or_unicode(prefix, suffix):
105123
"""
106124
Join two path components of either ``bytes`` or ``unicode``.
@@ -127,12 +145,12 @@ def socket_pair():
127145
Establish and return a pair of network sockets connected to each other.
128146
"""
129147
# Connect a pair of sockets
130-
port = socket()
148+
port = socket_any_family()
131149
port.bind(('', 0))
132150
port.listen(1)
133-
client = socket()
151+
client = socket(port.family)
134152
client.setblocking(False)
135-
client.connect_ex(("127.0.0.1", port.getsockname()[1]))
153+
client.connect_ex((loopback_address(port), port.getsockname()[1]))
136154
client.setblocking(True)
137155
server = port.accept()[0]
138156

@@ -1209,7 +1227,7 @@ def test_set_default_verify_paths(self):
12091227
VERIFY_PEER,
12101228
lambda conn, cert, errno, depth, preverify_ok: preverify_ok)
12111229

1212-
client = socket()
1230+
client = socket_any_family()
12131231
client.connect(("encrypted.google.com", 443))
12141232
clientSSL = Connection(context, client)
12151233
clientSSL.set_connect_state()
@@ -2237,7 +2255,7 @@ def test_connect_wrong_args(self):
22372255
`Connection.connect` raises `TypeError` if called with a non-address
22382256
argument.
22392257
"""
2240-
connection = Connection(Context(TLSv1_METHOD), socket())
2258+
connection = Connection(Context(TLSv1_METHOD), socket_any_family())
22412259
with pytest.raises(TypeError):
22422260
connection.connect(None)
22432261

@@ -2246,13 +2264,13 @@ def test_connect_refused(self):
22462264
`Connection.connect` raises `socket.error` if the underlying socket
22472265
connect method raises it.
22482266
"""
2249-
client = socket()
2267+
client = socket_any_family()
22502268
context = Context(TLSv1_METHOD)
22512269
clientSSL = Connection(context, client)
22522270
# pytest.raises here doesn't work because of a bug in py.test on Python
22532271
# 2.6: https://github.com/pytest-dev/pytest/issues/988
22542272
try:
2255-
clientSSL.connect(("127.0.0.1", 1))
2273+
clientSSL.connect((loopback_address(client), 1))
22562274
except error as e:
22572275
exc = e
22582276
assert exc.args[0] == ECONNREFUSED
@@ -2261,12 +2279,12 @@ def test_connect(self):
22612279
"""
22622280
`Connection.connect` establishes a connection to the specified address.
22632281
"""
2264-
port = socket()
2282+
port = socket_any_family()
22652283
port.bind(('', 0))
22662284
port.listen(3)
22672285

2268-
clientSSL = Connection(Context(TLSv1_METHOD), socket())
2269-
clientSSL.connect(('127.0.0.1', port.getsockname()[1]))
2286+
clientSSL = Connection(Context(TLSv1_METHOD), socket(port.family))
2287+
clientSSL.connect((loopback_address(port), port.getsockname()[1]))
22702288
# XXX An assertion? Or something?
22712289

22722290
@pytest.mark.skipif(
@@ -2278,11 +2296,11 @@ def test_connect_ex(self):
22782296
If there is a connection error, `Connection.connect_ex` returns the
22792297
errno instead of raising an exception.
22802298
"""
2281-
port = socket()
2299+
port = socket_any_family()
22822300
port.bind(('', 0))
22832301
port.listen(3)
22842302

2285-
clientSSL = Connection(Context(TLSv1_METHOD), socket())
2303+
clientSSL = Connection(Context(TLSv1_METHOD), socket(port.family))
22862304
clientSSL.setblocking(False)
22872305
result = clientSSL.connect_ex(port.getsockname())
22882306
expected = (EINPROGRESS, EWOULDBLOCK)
@@ -2297,16 +2315,16 @@ def test_accept(self):
22972315
ctx = Context(TLSv1_METHOD)
22982316
ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem))
22992317
ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem))
2300-
port = socket()
2318+
port = socket_any_family()
23012319
portSSL = Connection(ctx, port)
23022320
portSSL.bind(('', 0))
23032321
portSSL.listen(3)
23042322

2305-
clientSSL = Connection(Context(TLSv1_METHOD), socket())
2323+
clientSSL = Connection(Context(TLSv1_METHOD), socket(port.family))
23062324

23072325
# Calling portSSL.getsockname() here to get the server IP address
23082326
# sounds great, but frequently fails on Windows.
2309-
clientSSL.connect(('127.0.0.1', portSSL.getsockname()[1]))
2327+
clientSSL.connect((loopback_address(port), portSSL.getsockname()[1]))
23102328

23112329
serverSSL, address = portSSL.accept()
23122330

@@ -2379,7 +2397,7 @@ def test_set_shutdown(self):
23792397
`Connection.set_shutdown` sets the state of the SSL connection
23802398
shutdown process.
23812399
"""
2382-
connection = Connection(Context(TLSv1_METHOD), socket())
2400+
connection = Connection(Context(TLSv1_METHOD), socket_any_family())
23832401
connection.set_shutdown(RECEIVED_SHUTDOWN)
23842402
assert connection.get_shutdown() == RECEIVED_SHUTDOWN
23852403

@@ -2389,7 +2407,7 @@ def test_set_shutdown_long(self):
23892407
On Python 2 `Connection.set_shutdown` accepts an argument
23902408
of type `long` as well as `int`.
23912409
"""
2392-
connection = Connection(Context(TLSv1_METHOD), socket())
2410+
connection = Connection(Context(TLSv1_METHOD), socket_any_family())
23932411
connection.set_shutdown(long(RECEIVED_SHUTDOWN))
23942412
assert connection.get_shutdown() == RECEIVED_SHUTDOWN
23952413

@@ -3503,7 +3521,7 @@ def test_socket_overrides_memory(self):
35033521
work on `OpenSSL.SSL.Connection`() that use sockets.
35043522
"""
35053523
context = Context(TLSv1_METHOD)
3506-
client = socket()
3524+
client = socket_any_family()
35073525
clientSSL = Connection(context, client)
35083526
with pytest.raises(TypeError):
35093527
clientSSL.bio_read(100)

0 commit comments

Comments
 (0)