Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions Lib/test/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1239,6 +1239,25 @@ def getpass(self):
# Make sure the password function isn't called if it isn't needed
ctx.load_cert_chain(CERTFILE, password=getpass_exception)

@threading_helper.requires_working_threading()
def test_load_cert_chain_thread_safety(self):
# gh-134698: _ssl detaches the thread state (and as such,
# releases the GIL and critical sections) around expensive
# OpenSSL calls. Unfortunately, OpenSSL structures aren't
# thread-safe, so executing these calls concurrently led
# to crashes.
ctx = ssl.create_default_context()

def race():
ctx.load_cert_chain(CERTFILE)

threads = [threading.Thread(target=race) for _ in range(8)]
with threading_helper.catch_threading_exception() as cm:
with threading_helper.start_threads(threads):
pass

self.assertIsNone(cm.exc_value)

def test_load_verify_locations(self):
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
ctx.load_verify_locations(CERTFILE)
Expand Down Expand Up @@ -4538,6 +4557,42 @@ def server_callback(identity):
with client_context.wrap_socket(socket.socket()) as s:
s.connect((HOST, server.port))

def test_thread_recv_while_main_thread_sends(self):
# GH-137583: Locking was added to calls to send() and recv() on SSL
# socket objects. This seemed fine at the surface level because those
# calls weren't re-entrant, but recv() calls would implicitly mimick
# holding a lock by blocking until it received data. This means that
# if a thread started to infinitely block until data was received, calls
# to send() would deadlock, because it would wait forever on the lock
# that the recv() call held.
data = b"1" * 1024
event = threading.Event()
def background(sock):
event.set()
received = sock.recv(len(data))
self.assertEqual(received, data)

client_context, server_context, hostname = testing_context()
server = ThreadedEchoServer(context=server_context)
with server:
with client_context.wrap_socket(socket.socket(),
server_hostname=hostname) as sock:
sock.connect((HOST, server.port))
sock.settimeout(1)
sock.setblocking(1)
# Ensure that the server is ready to accept requests
sock.sendall(b"123")
self.assertEqual(sock.recv(3), b"123")
with threading_helper.catch_threading_exception() as cm:
thread = threading.Thread(target=background,
args=(sock,), daemon=True)
thread.start()
event.wait()
sock.sendall(data)
thread.join()
if cm.exc_value is not None:
raise cm.exc_value


@unittest.skipUnless(has_tls_version('TLSv1_3') and ssl.HAS_PHA,
"Test needs TLS 1.3 PHA")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fix a crash when calling methods of :class:`ssl.SSLContext` or
:class:`ssl.SSLSocket` across multiple threads.
100 changes: 58 additions & 42 deletions Modules/_ssl.c
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,14 @@
/* Redefined below for Windows debug builds after important #includes */
#define _PySSL_FIX_ERRNO

#define PySSL_BEGIN_ALLOW_THREADS_S(save) \
do { (save) = PyEval_SaveThread(); } while(0)
#define PySSL_END_ALLOW_THREADS_S(save) \
do { PyEval_RestoreThread(save); _PySSL_FIX_ERRNO; } while(0)
#define PySSL_BEGIN_ALLOW_THREADS { \
#define PySSL_BEGIN_ALLOW_THREADS_S(save, mutex) \
do { (save) = PyEval_SaveThread(); PyMutex_Lock(mutex); } while(0)
#define PySSL_END_ALLOW_THREADS_S(save, mutex) \
do { PyMutex_Unlock(mutex); PyEval_RestoreThread(save); _PySSL_FIX_ERRNO; } while(0)
#define PySSL_BEGIN_ALLOW_THREADS(self) { \
PyThreadState *_save = NULL; \
PySSL_BEGIN_ALLOW_THREADS_S(_save);
#define PySSL_END_ALLOW_THREADS PySSL_END_ALLOW_THREADS_S(_save); }
PySSL_BEGIN_ALLOW_THREADS_S(_save, &self->tstate_mutex);
#define PySSL_END_ALLOW_THREADS(self) PySSL_END_ALLOW_THREADS_S(_save, &self->tstate_mutex); }

#if defined(HAVE_POLL_H)
#include <poll.h>
Expand Down Expand Up @@ -309,6 +309,9 @@ typedef struct {
PyObject *psk_client_callback;
PyObject *psk_server_callback;
#endif
/* Lock to synchronize calls when the thread state is detached.
See also gh-134698. */
PyMutex tstate_mutex;
} PySSLContext;

#define PySSLContext_CAST(op) ((PySSLContext *)(op))
Expand Down Expand Up @@ -889,9 +892,9 @@ newPySSLSocket(PySSLContext *sslctx, PySocketSockObject *sock,
/* Make sure the SSL error state is initialized */
ERR_clear_error();

PySSL_BEGIN_ALLOW_THREADS
PySSL_BEGIN_ALLOW_THREADS(sslctx)
self->ssl = SSL_new(ctx);
PySSL_END_ALLOW_THREADS
PySSL_END_ALLOW_THREADS(sslctx)
if (self->ssl == NULL) {
Py_DECREF(self);
_setSSLError(get_state_ctx(self), NULL, 0, __FILE__, __LINE__);
Expand Down Expand Up @@ -960,12 +963,12 @@ newPySSLSocket(PySSLContext *sslctx, PySocketSockObject *sock,
BIO_set_nbio(SSL_get_wbio(self->ssl), 1);
}

PySSL_BEGIN_ALLOW_THREADS
Py_BEGIN_ALLOW_THREADS;
if (socket_type == PY_SSL_CLIENT)
SSL_set_connect_state(self->ssl);
else
SSL_set_accept_state(self->ssl);
PySSL_END_ALLOW_THREADS
Py_END_ALLOW_THREADS;

self->socket_type = socket_type;
if (sock != NULL) {
Expand Down Expand Up @@ -1034,10 +1037,11 @@ _ssl__SSLSocket_do_handshake_impl(PySSLSocket *self)
/* Actually negotiate SSL connection */
/* XXX If SSL_do_handshake() returns 0, it's also a failure. */
do {
PySSL_BEGIN_ALLOW_THREADS
Py_BEGIN_ALLOW_THREADS
ret = SSL_do_handshake(self->ssl);
err = _PySSL_errno(ret < 1, self->ssl, ret);
PySSL_END_ALLOW_THREADS
Py_END_ALLOW_THREADS;
_PySSL_FIX_ERRNO;
self->err = err;

if (PyErr_CheckSignals())
Expand Down Expand Up @@ -2414,9 +2418,10 @@ PySSL_select(PySocketSockObject *s, int writing, PyTime_t timeout)
ms = (int)_PyTime_AsMilliseconds(timeout, _PyTime_ROUND_CEILING);
assert(ms <= INT_MAX);

PySSL_BEGIN_ALLOW_THREADS
Py_BEGIN_ALLOW_THREADS
rc = poll(&pollfd, 1, (int)ms);
PySSL_END_ALLOW_THREADS
Py_END_ALLOW_THREADS
_PySSL_FIX_ERRNO;
#else
/* Guard against socket too large for select*/
if (!_PyIsSelectable_fd(s->sock_fd))
Expand All @@ -2428,13 +2433,14 @@ PySSL_select(PySocketSockObject *s, int writing, PyTime_t timeout)
FD_SET(s->sock_fd, &fds);

/* Wait until the socket becomes ready */
PySSL_BEGIN_ALLOW_THREADS
Py_BEGIN_ALLOW_THREADS
nfds = Py_SAFE_DOWNCAST(s->sock_fd+1, SOCKET_T, int);
if (writing)
rc = select(nfds, NULL, &fds, NULL, &tv);
else
rc = select(nfds, &fds, NULL, NULL, &tv);
PySSL_END_ALLOW_THREADS
Py_END_ALLOW_THREADS
_PySSL_FIX_ERRNO;
#endif

/* Return SOCKET_TIMED_OUT on timeout, SOCKET_OPERATION_OK otherwise
Expand Down Expand Up @@ -2505,10 +2511,11 @@ _ssl__SSLSocket_write_impl(PySSLSocket *self, Py_buffer *b)
}

do {
PySSL_BEGIN_ALLOW_THREADS
Py_BEGIN_ALLOW_THREADS;
retval = SSL_write_ex(self->ssl, b->buf, (size_t)b->len, &count);
err = _PySSL_errno(retval == 0, self->ssl, retval);
PySSL_END_ALLOW_THREADS
Py_END_ALLOW_THREADS;
_PySSL_FIX_ERRNO;
self->err = err;

if (PyErr_CheckSignals())
Expand Down Expand Up @@ -2566,10 +2573,11 @@ _ssl__SSLSocket_pending_impl(PySSLSocket *self)
int count = 0;
_PySSLError err;

PySSL_BEGIN_ALLOW_THREADS
Py_BEGIN_ALLOW_THREADS;
count = SSL_pending(self->ssl);
err = _PySSL_errno(count < 0, self->ssl, count);
PySSL_END_ALLOW_THREADS
Py_END_ALLOW_THREADS;
_PySSL_FIX_ERRNO;
self->err = err;

if (count < 0)
Expand Down Expand Up @@ -2660,10 +2668,11 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, Py_ssize_t len,
deadline = _PyDeadline_Init(timeout);

do {
PySSL_BEGIN_ALLOW_THREADS
Py_BEGIN_ALLOW_THREADS;
retval = SSL_read_ex(self->ssl, mem, (size_t)len, &count);
err = _PySSL_errno(retval == 0, self->ssl, retval);
PySSL_END_ALLOW_THREADS
Py_END_ALLOW_THREADS;
_PySSL_FIX_ERRNO;
self->err = err;

if (PyErr_CheckSignals())
Expand Down Expand Up @@ -2762,7 +2771,7 @@ _ssl__SSLSocket_shutdown_impl(PySSLSocket *self)
}

while (1) {
PySSL_BEGIN_ALLOW_THREADS
Py_BEGIN_ALLOW_THREADS;
/* Disable read-ahead so that unwrap can work correctly.
* Otherwise OpenSSL might read in too much data,
* eating clear text data that happens to be
Expand All @@ -2775,7 +2784,8 @@ _ssl__SSLSocket_shutdown_impl(PySSLSocket *self)
SSL_set_read_ahead(self->ssl, 0);
ret = SSL_shutdown(self->ssl);
err = _PySSL_errno(ret < 0, self->ssl, ret);
PySSL_END_ALLOW_THREADS
Py_END_ALLOW_THREADS;
_PySSL_FIX_ERRNO;
self->err = err;

/* If err == 1, a secure shutdown with SSL_shutdown() is complete */
Expand Down Expand Up @@ -3167,9 +3177,10 @@ _ssl__SSLContext_impl(PyTypeObject *type, int proto_version)
// no other thread can be touching this object yet.
// (Technically, we can't even lock if we wanted to, as the
// lock hasn't been initialized yet.)
PySSL_BEGIN_ALLOW_THREADS
Py_BEGIN_ALLOW_THREADS
ctx = SSL_CTX_new(method);
PySSL_END_ALLOW_THREADS
Py_END_ALLOW_THREADS
_PySSL_FIX_ERRNO;

if (ctx == NULL) {
_setSSLError(get_ssl_state(module), NULL, 0, __FILE__, __LINE__);
Expand All @@ -3194,6 +3205,7 @@ _ssl__SSLContext_impl(PyTypeObject *type, int proto_version)
self->psk_client_callback = NULL;
self->psk_server_callback = NULL;
#endif
self->tstate_mutex = (PyMutex){0};

/* Don't check host name by default */
if (proto_version == PY_SSL_VERSION_TLS_CLIENT) {
Expand Down Expand Up @@ -3312,9 +3324,10 @@ context_clear(PyObject *op)
Py_CLEAR(self->psk_server_callback);
#endif
if (self->keylog_bio != NULL) {
PySSL_BEGIN_ALLOW_THREADS
Py_BEGIN_ALLOW_THREADS
BIO_free_all(self->keylog_bio);
PySSL_END_ALLOW_THREADS
Py_END_ALLOW_THREADS
_PySSL_FIX_ERRNO;
self->keylog_bio = NULL;
}
return 0;
Expand Down Expand Up @@ -4037,7 +4050,8 @@ _password_callback(char *buf, int size, int rwflag, void *userdata)
_PySSLPasswordInfo *pw_info = (_PySSLPasswordInfo*) userdata;
PyObject *fn_ret = NULL;

PySSL_END_ALLOW_THREADS_S(pw_info->thread_state);
pw_info->thread_state = PyThreadState_Swap(pw_info->thread_state);
_PySSL_FIX_ERRNO;

if (pw_info->error) {
/* already failed previously. OpenSSL 3.0.0-alpha14 invokes the
Expand Down Expand Up @@ -4067,13 +4081,13 @@ _password_callback(char *buf, int size, int rwflag, void *userdata)
goto error;
}

PySSL_BEGIN_ALLOW_THREADS_S(pw_info->thread_state);
pw_info->thread_state = PyThreadState_Swap(pw_info->thread_state);
memcpy(buf, pw_info->password, pw_info->size);
return pw_info->size;

error:
Py_XDECREF(fn_ret);
PySSL_BEGIN_ALLOW_THREADS_S(pw_info->thread_state);
pw_info->thread_state = PyThreadState_Swap(pw_info->thread_state);
pw_info->error = 1;
return -1;
}
Expand Down Expand Up @@ -4126,10 +4140,10 @@ _ssl__SSLContext_load_cert_chain_impl(PySSLContext *self, PyObject *certfile,
SSL_CTX_set_default_passwd_cb(self->ctx, _password_callback);
SSL_CTX_set_default_passwd_cb_userdata(self->ctx, &pw_info);
}
PySSL_BEGIN_ALLOW_THREADS_S(pw_info.thread_state);
PySSL_BEGIN_ALLOW_THREADS_S(pw_info.thread_state, &self->tstate_mutex);
r = SSL_CTX_use_certificate_chain_file(self->ctx,
PyBytes_AS_STRING(certfile_bytes));
PySSL_END_ALLOW_THREADS_S(pw_info.thread_state);
PySSL_END_ALLOW_THREADS_S(pw_info.thread_state, &self->tstate_mutex);
if (r != 1) {
if (pw_info.error) {
ERR_clear_error();
Expand All @@ -4144,11 +4158,11 @@ _ssl__SSLContext_load_cert_chain_impl(PySSLContext *self, PyObject *certfile,
}
goto error;
}
PySSL_BEGIN_ALLOW_THREADS_S(pw_info.thread_state);
PySSL_BEGIN_ALLOW_THREADS_S(pw_info.thread_state, &self->tstate_mutex);
r = SSL_CTX_use_PrivateKey_file(self->ctx,
PyBytes_AS_STRING(keyfile ? keyfile_bytes : certfile_bytes),
SSL_FILETYPE_PEM);
PySSL_END_ALLOW_THREADS_S(pw_info.thread_state);
PySSL_END_ALLOW_THREADS_S(pw_info.thread_state, &self->tstate_mutex);
Py_CLEAR(keyfile_bytes);
Py_CLEAR(certfile_bytes);
if (r != 1) {
Expand All @@ -4165,9 +4179,9 @@ _ssl__SSLContext_load_cert_chain_impl(PySSLContext *self, PyObject *certfile,
}
goto error;
}
PySSL_BEGIN_ALLOW_THREADS_S(pw_info.thread_state);
PySSL_BEGIN_ALLOW_THREADS_S(pw_info.thread_state, &self->tstate_mutex);
r = SSL_CTX_check_private_key(self->ctx);
PySSL_END_ALLOW_THREADS_S(pw_info.thread_state);
PySSL_END_ALLOW_THREADS_S(pw_info.thread_state, &self->tstate_mutex);
if (r != 1) {
_setSSLError(get_state_ctx(self), NULL, 0, __FILE__, __LINE__);
goto error;
Expand Down Expand Up @@ -4384,9 +4398,9 @@ _ssl__SSLContext_load_verify_locations_impl(PySSLContext *self,
cafile_buf = PyBytes_AS_STRING(cafile_bytes);
if (capath)
capath_buf = PyBytes_AS_STRING(capath_bytes);
PySSL_BEGIN_ALLOW_THREADS
PySSL_BEGIN_ALLOW_THREADS(self)
r = SSL_CTX_load_verify_locations(self->ctx, cafile_buf, capath_buf);
PySSL_END_ALLOW_THREADS
PySSL_END_ALLOW_THREADS(self)
if (r != 1) {
if (errno != 0) {
PyErr_SetFromErrno(PyExc_OSError);
Expand Down Expand Up @@ -4438,10 +4452,11 @@ _ssl__SSLContext_load_dh_params_impl(PySSLContext *self, PyObject *filepath)
return NULL;

errno = 0;
PySSL_BEGIN_ALLOW_THREADS
Py_BEGIN_ALLOW_THREADS
dh = PEM_read_DHparams(f, NULL, NULL, NULL);
fclose(f);
PySSL_END_ALLOW_THREADS
Py_END_ALLOW_THREADS
_PySSL_FIX_ERRNO;
if (dh == NULL) {
if (errno != 0) {
PyErr_SetFromErrnoWithFilenameObject(PyExc_OSError, filepath);
Expand Down Expand Up @@ -4593,6 +4608,7 @@ _ssl__SSLContext_set_default_verify_paths_impl(PySSLContext *self)
Py_BEGIN_ALLOW_THREADS
rc = SSL_CTX_set_default_verify_paths(self->ctx);
Py_END_ALLOW_THREADS
_PySSL_FIX_ERRNO;
if (!rc) {
_setSSLError(get_state_ctx(self), NULL, 0, __FILE__, __LINE__);
return NULL;
Expand Down
Loading
Loading