22
22
23
23
#include "mbedtls/version.h"
24
24
25
+ #define MP_STREAM_POLL_RDWR (MP_STREAM_POLL_RD | MP_STREAM_POLL_WR)
26
+
25
27
#if defined(MBEDTLS_ERROR_C )
26
28
#include "../../lib/mbedtls_errors/mp_mbedtls_errors.c"
27
29
#endif
@@ -220,6 +222,7 @@ ssl_sslsocket_obj_t *common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t
220
222
o -> base .type = & ssl_sslsocket_type ;
221
223
o -> ssl_context = self ;
222
224
o -> sock_obj = socket ;
225
+ o -> poll_mask = 0 ;
223
226
224
227
mp_load_method (socket , MP_QSTR_accept , o -> accept_args );
225
228
mp_load_method (socket , MP_QSTR_bind , o -> bind_args );
@@ -330,7 +333,8 @@ ssl_sslsocket_obj_t *common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t
330
333
}
331
334
}
332
335
333
- mp_uint_t common_hal_ssl_sslsocket_recv_into (ssl_sslsocket_obj_t * self , uint8_t * buf , uint32_t len ) {
336
+ mp_uint_t common_hal_ssl_sslsocket_recv_into (ssl_sslsocket_obj_t * self , uint8_t * buf , mp_uint_t len ) {
337
+ self -> poll_mask = 0 ;
334
338
int ret = mbedtls_ssl_read (& self -> ssl , buf , len );
335
339
DEBUG_PRINT ("recv_into mbedtls_ssl_read() -> %d\n" , ret );
336
340
if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY ) {
@@ -342,17 +346,24 @@ mp_uint_t common_hal_ssl_sslsocket_recv_into(ssl_sslsocket_obj_t *self, uint8_t
342
346
DEBUG_PRINT ("returning %d\n" , ret );
343
347
return ret ;
344
348
}
349
+ if (ret == MBEDTLS_ERR_SSL_WANT_WRITE ) {
350
+ self -> poll_mask = MP_STREAM_POLL_WR ;
351
+ }
345
352
DEBUG_PRINT ("raising errno [error case] %d\n" , ret );
346
353
mbedtls_raise_error (ret );
347
354
}
348
355
349
- mp_uint_t common_hal_ssl_sslsocket_send (ssl_sslsocket_obj_t * self , const uint8_t * buf , uint32_t len ) {
356
+ mp_uint_t common_hal_ssl_sslsocket_send (ssl_sslsocket_obj_t * self , const uint8_t * buf , mp_uint_t len ) {
357
+ self -> poll_mask = 0 ;
350
358
int ret = mbedtls_ssl_write (& self -> ssl , buf , len );
351
359
DEBUG_PRINT ("send mbedtls_ssl_write() -> %d\n" , ret );
352
360
if (ret >= 0 ) {
353
361
DEBUG_PRINT ("returning %d\n" , ret );
354
362
return ret ;
355
363
}
364
+ if (ret == MBEDTLS_ERR_SSL_WANT_READ ) {
365
+ self -> poll_mask = MP_STREAM_POLL_RD ;
366
+ }
356
367
DEBUG_PRINT ("raising errno [error case] %d\n" , ret );
357
368
mbedtls_raise_error (ret );
358
369
}
@@ -448,3 +459,37 @@ void common_hal_ssl_sslsocket_setsockopt(ssl_sslsocket_obj_t *self, mp_obj_t lev
448
459
void common_hal_ssl_sslsocket_settimeout (ssl_sslsocket_obj_t * self , mp_obj_t timeout_obj ) {
449
460
ssl_socket_settimeout (self , timeout_obj );
450
461
}
462
+
463
+ static bool poll_common (ssl_sslsocket_obj_t * self , uintptr_t arg ) {
464
+ // Take into account that the library might have buffered data already
465
+ int has_pending = 0 ;
466
+ if (arg & MP_STREAM_POLL_RD ) {
467
+ has_pending = mbedtls_ssl_check_pending (& self -> ssl );
468
+ if (has_pending ) {
469
+ // Shortcut if we only need to read and we have buffered data, no need to go to the underlying socket
470
+ return true;
471
+ }
472
+ }
473
+
474
+ // If the library signaled us that it needs reading or writing, only
475
+ // check that direction
476
+ if (self -> poll_mask && (arg & MP_STREAM_POLL_RDWR )) {
477
+ arg = (arg & ~MP_STREAM_POLL_RDWR ) | self -> poll_mask ;
478
+ }
479
+
480
+ // If direction the library needed is available, return a fake
481
+ // result to the caller so that it reenters a read or a write to
482
+ // allow the handshake to progress
483
+ const mp_stream_p_t * stream_p = mp_get_stream_raise (self -> sock_obj , MP_STREAM_OP_IOCTL );
484
+ int errcode ;
485
+ mp_int_t ret = stream_p -> ioctl (self -> sock_obj , MP_STREAM_POLL , arg , & errcode );
486
+ return ret != 0 ;
487
+ }
488
+
489
+ bool common_hal_ssl_sslsocket_readable (ssl_sslsocket_obj_t * self ) {
490
+ return poll_common (self , MP_STREAM_POLL_RD );
491
+ }
492
+
493
+ bool common_hal_ssl_sslsocket_writable (ssl_sslsocket_obj_t * self ) {
494
+ return poll_common (self , MP_STREAM_POLL_WR );
495
+ }
0 commit comments