Skip to content

Commit be42010

Browse files
committed
Merge branch 'test/moving_buf' into jan23-2024
2 parents 93bf6b3 + 994338d commit be42010

File tree

1 file changed

+44
-27
lines changed

1 file changed

+44
-27
lines changed

tests/test_ssl.py

Lines changed: 44 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -445,43 +445,59 @@ def create_ssl_nonblocking_connection(
445445
"""
446446
Create a pair of sockets and set up an SSL connection between them.
447447
"""
448-
# Create a private key and a certificate to use for the server
449-
key = PKey()
450-
key.generate_key(TYPE_RSA, 2048)
451-
cert = X509()
452-
cert.set_version(2)
453-
cert.get_subject().C = b"US"
454-
cert.get_subject().ST = b"California"
455-
cert.get_subject().L = b"Palo Alto"
456-
cert.get_subject().O = b"pyOpenSSL"
457-
cert.get_subject().CN = b"localhost"
458-
cert.set_serial_number(1)
459-
cert.gmtime_adj_notBefore(0)
460-
cert.gmtime_adj_notAfter(60 * 60)
461-
cert.set_issuer(cert.get_subject())
462-
cert.set_pubkey(key)
463-
cert.sign(key, "sha1")
464-
465-
# Create a context with the necessary modes
466-
ctx = Context(SSLv23_METHOD)
448+
chain = _create_certificate_chain()
449+
450+
# Extract the server's key and certificate from the chain ---
451+
# The chain is [ (root_key, root_cert),
452+
# (intermediate_key, intermediate_cert), (server_key, server_cert) ]
453+
server_key, server_cert = chain[
454+
2
455+
] # Index 2 gets the last tuple: (skey, scert)
456+
457+
# Set up the server's SSL context ---
458+
server_ctx = Context(SSLv23_METHOD)
459+
server_ctx.use_privatekey(server_key) # Use the server_key from the chain
460+
server_ctx.use_certificate(
461+
server_cert
462+
) # Use the server_cert from the chain
463+
server_ctx.add_extra_chain_cert(
464+
chain[1][1]
465+
) # Add the intermediate cert to the server's extra chain
466+
467+
# Set up client context
468+
client_ctx = Context(SSLv23_METHOD)
467469

468470
# these modes are set by default when ctx is initialized
469471
# clear them so we can run tests with or without them
470-
ctx.clear_mode(
472+
client_ctx.clear_mode(
471473
_lib.SSL_MODE_ENABLE_PARTIAL_WRITE
472474
| _lib.SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER
473475
)
476+
client_ctx.set_mode(mode)
477+
478+
# Get the certificate store from the context
479+
cert_store = client_ctx.get_cert_store()
474480

475-
ctx.set_mode(mode)
476-
ctx.use_privatekey(key)
477-
ctx.use_certificate(cert)
481+
# Assert that cert_store is not None to satisfy mypy
482+
assert cert_store is not None, (
483+
"Expected X509Store, but got None from get_cert_store()"
484+
)
485+
486+
# Add the Root CA certificate to the store
487+
cert_store.add_cert(
488+
chain[0][1]
489+
) # chain[0][1] is the pyOpenSSL X509 object for the root CA
490+
# Enable peer verification so the client actually checks the server's cert
491+
client_ctx.set_verify(
492+
SSL.VERIFY_PEER, lambda conn, cert, errnum, depth, ok: bool(ok)
493+
)
478494

479495
# Create connections with real sockets
480496
client_socket, server_socket = socket_pair()
481497

482498
# Create Connection objects from the sockets
483-
client = Connection(ctx, client_socket)
484-
server = Connection(ctx, server_socket)
499+
client = Connection(client_ctx, client_socket)
500+
server = Connection(server_ctx, server_socket)
485501

486502
# Set the buffers to be very small so we can easily fill them
487503
client_socket.setsockopt(SOL_SOCKET, SO_SNDBUF, 256)
@@ -3109,12 +3125,13 @@ def test_wantWriteError(self) -> None:
31093125
# signal a short write via its return value it seems this doesn't
31103126
# always happen on all platforms (FreeBSD and OS X particular) for the
31113127
# very last bit of available buffer space.
3112-
for msg in [b"x" * 65536, b"x"]:
3128+
for msg in [b"x" * 65536, b"x" * 16, b"x"]:
31133129
for i in range(1024 * 1024 * 64):
31143130
try:
31153131
client_socket.send(msg)
31163132
except OSError as e:
31173133
if e.errno == EWOULDBLOCK:
3134+
time.sleep(0.1)
31183135
break
31193136
raise # pragma: no cover
31203137
else: # pragma: no cover
@@ -3286,7 +3303,7 @@ def _perform_moving_buffer_test(
32863303
except SSL.Error as e:
32873304
reason = get_ssl_error_reason(e)
32883305
if reason == "bad write retry":
3289-
print(f"Got expected SSL error: {e} ({reason}).")
3306+
print(f"Got SSL error: {e} ({reason}).")
32903307
return True # Bad write retry
32913308
else:
32923309
pytest.fail(

0 commit comments

Comments
 (0)