Skip to content

Commit 49a6120

Browse files
committed
take micropython tricks for selectability of ssl sockets
1 parent 7969638 commit 49a6120

File tree

2 files changed

+33
-2
lines changed

2 files changed

+33
-2
lines changed

shared-module/ssl/SSLSocket.c

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222

2323
#include "mbedtls/version.h"
2424

25+
#define MP_STREAM_POLL_RDWR (MP_STREAM_POLL_RD | MP_STREAM_POLL_WR)
26+
2527
#if defined(MBEDTLS_ERROR_C)
2628
#include "../../lib/mbedtls_errors/mp_mbedtls_errors.c"
2729
#endif
@@ -220,6 +222,7 @@ ssl_sslsocket_obj_t *common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t
220222
o->base.type = &ssl_sslsocket_type;
221223
o->ssl_context = self;
222224
o->sock_obj = socket;
225+
o->poll_mask = 0;
223226

224227
mp_load_method(socket, MP_QSTR_accept, o->accept_args);
225228
mp_load_method(socket, MP_QSTR_bind, o->bind_args);
@@ -331,6 +334,7 @@ ssl_sslsocket_obj_t *common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t
331334
}
332335

333336
mp_uint_t common_hal_ssl_sslsocket_recv_into(ssl_sslsocket_obj_t *self, uint8_t *buf, uint32_t len) {
337+
self->poll_mask = 0;
334338
int ret = mbedtls_ssl_read(&self->ssl, buf, len);
335339
DEBUG_PRINT("recv_into mbedtls_ssl_read() -> %d\n", ret);
336340
if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
@@ -342,17 +346,24 @@ mp_uint_t common_hal_ssl_sslsocket_recv_into(ssl_sslsocket_obj_t *self, uint8_t
342346
DEBUG_PRINT("returning %d\n", ret);
343347
return ret;
344348
}
349+
if (ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
350+
self->poll_mask = MP_STREAM_POLL_WR;
351+
}
345352
DEBUG_PRINT("raising errno [error case] %d\n", ret);
346353
mbedtls_raise_error(ret);
347354
}
348355

349356
mp_uint_t common_hal_ssl_sslsocket_send(ssl_sslsocket_obj_t *self, const uint8_t *buf, uint32_t len) {
357+
self->poll_mask = 0;
350358
int ret = mbedtls_ssl_write(&self->ssl, buf, len);
351359
DEBUG_PRINT("send mbedtls_ssl_write() -> %d\n", ret);
352360
if (ret >= 0) {
353361
DEBUG_PRINT("returning %d\n", ret);
354362
return ret;
355363
}
364+
if (ret == MBEDTLS_ERR_SSL_WANT_READ) {
365+
self->poll_mask = MP_STREAM_POLL_RD;
366+
}
356367
DEBUG_PRINT("raising errno [error case] %d\n", ret);
357368
mbedtls_raise_error(ret);
358369
}
@@ -449,10 +460,29 @@ void common_hal_ssl_sslsocket_settimeout(ssl_sslsocket_obj_t *self, mp_obj_t tim
449460
ssl_socket_settimeout(self, timeout_obj);
450461
}
451462

452-
static bool poll_common(ssl_sslsocket_obj_t *self, int mode) {
463+
static bool poll_common(ssl_sslsocket_obj_t *self, uintptr_t arg) {
464+
// Take into account that the library might have buffered data already
465+
int has_pending = 0;
466+
if (arg & MP_STREAM_POLL_RD) {
467+
has_pending = mbedtls_ssl_check_pending(&self->ssl);
468+
if (has_pending) {
469+
// Shortcut if we only need to read and we have buffered data, no need to go to the underlying socket
470+
return true;
471+
}
472+
}
473+
474+
// If the library signaled us that it needs reading or writing, only
475+
// check that direction
476+
if (self->poll_mask && (arg & MP_STREAM_POLL_RDWR)) {
477+
arg = (arg & ~MP_STREAM_POLL_RDWR) | self->poll_mask;
478+
}
479+
480+
// If direction the library needed is available, return a fake
481+
// result to the caller so that it reenters a read or a write to
482+
// allow the handshake to progress
453483
const mp_stream_p_t *stream_p = mp_get_stream_raise(self->sock_obj, MP_STREAM_OP_IOCTL);
454484
int errcode;
455-
mp_int_t ret = stream_p->ioctl(self->sock_obj, MP_STREAM_POLL, mode, &errcode);
485+
mp_int_t ret = stream_p->ioctl(self->sock_obj, MP_STREAM_POLL, arg, &errcode);
456486
return ret != 0;
457487
}
458488

shared-module/ssl/SSLSocket.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ typedef struct ssl_sslsocket_obj {
2929
mbedtls_x509_crt cacert;
3030
mbedtls_x509_crt cert;
3131
mbedtls_pk_context pkey;
32+
uintptr_t poll_mask;
3233
bool closed;
3334
mp_obj_t accept_args[2];
3435
mp_obj_t bind_args[3];

0 commit comments

Comments
 (0)