22
22
23
23
#include "mbedtls/version.h"
24
24
25
+ #define MP_STREAM_POLL_RDWR (MP_STREAM_POLL_RD | MP_STREAM_POLL_WR)
26
+
25
27
#if defined(MBEDTLS_ERROR_C )
26
28
#include "../../lib/mbedtls_errors/mp_mbedtls_errors.c"
27
29
#endif
@@ -220,6 +222,7 @@ ssl_sslsocket_obj_t *common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t
220
222
o -> base .type = & ssl_sslsocket_type ;
221
223
o -> ssl_context = self ;
222
224
o -> sock_obj = socket ;
225
+ o -> poll_mask = 0 ;
223
226
224
227
mp_load_method (socket , MP_QSTR_accept , o -> accept_args );
225
228
mp_load_method (socket , MP_QSTR_bind , o -> bind_args );
@@ -331,6 +334,7 @@ ssl_sslsocket_obj_t *common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t
331
334
}
332
335
333
336
mp_uint_t common_hal_ssl_sslsocket_recv_into (ssl_sslsocket_obj_t * self , uint8_t * buf , uint32_t len ) {
337
+ self -> poll_mask = 0 ;
334
338
int ret = mbedtls_ssl_read (& self -> ssl , buf , len );
335
339
DEBUG_PRINT ("recv_into mbedtls_ssl_read() -> %d\n" , ret );
336
340
if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY ) {
@@ -342,17 +346,24 @@ mp_uint_t common_hal_ssl_sslsocket_recv_into(ssl_sslsocket_obj_t *self, uint8_t
342
346
DEBUG_PRINT ("returning %d\n" , ret );
343
347
return ret ;
344
348
}
349
+ if (ret == MBEDTLS_ERR_SSL_WANT_WRITE ) {
350
+ self -> poll_mask = MP_STREAM_POLL_WR ;
351
+ }
345
352
DEBUG_PRINT ("raising errno [error case] %d\n" , ret );
346
353
mbedtls_raise_error (ret );
347
354
}
348
355
349
356
mp_uint_t common_hal_ssl_sslsocket_send (ssl_sslsocket_obj_t * self , const uint8_t * buf , uint32_t len ) {
357
+ self -> poll_mask = 0 ;
350
358
int ret = mbedtls_ssl_write (& self -> ssl , buf , len );
351
359
DEBUG_PRINT ("send mbedtls_ssl_write() -> %d\n" , ret );
352
360
if (ret >= 0 ) {
353
361
DEBUG_PRINT ("returning %d\n" , ret );
354
362
return ret ;
355
363
}
364
+ if (ret == MBEDTLS_ERR_SSL_WANT_READ ) {
365
+ self -> poll_mask = MP_STREAM_POLL_RD ;
366
+ }
356
367
DEBUG_PRINT ("raising errno [error case] %d\n" , ret );
357
368
mbedtls_raise_error (ret );
358
369
}
@@ -449,10 +460,29 @@ void common_hal_ssl_sslsocket_settimeout(ssl_sslsocket_obj_t *self, mp_obj_t tim
449
460
ssl_socket_settimeout (self , timeout_obj );
450
461
}
451
462
452
- static bool poll_common (ssl_sslsocket_obj_t * self , int mode ) {
463
+ static bool poll_common (ssl_sslsocket_obj_t * self , uintptr_t arg ) {
464
+ // Take into account that the library might have buffered data already
465
+ int has_pending = 0 ;
466
+ if (arg & MP_STREAM_POLL_RD ) {
467
+ has_pending = mbedtls_ssl_check_pending (& self -> ssl );
468
+ if (has_pending ) {
469
+ // Shortcut if we only need to read and we have buffered data, no need to go to the underlying socket
470
+ return true;
471
+ }
472
+ }
473
+
474
+ // If the library signaled us that it needs reading or writing, only
475
+ // check that direction
476
+ if (self -> poll_mask && (arg & MP_STREAM_POLL_RDWR )) {
477
+ arg = (arg & ~MP_STREAM_POLL_RDWR ) | self -> poll_mask ;
478
+ }
479
+
480
+ // If direction the library needed is available, return a fake
481
+ // result to the caller so that it reenters a read or a write to
482
+ // allow the handshake to progress
453
483
const mp_stream_p_t * stream_p = mp_get_stream_raise (self -> sock_obj , MP_STREAM_OP_IOCTL );
454
484
int errcode ;
455
- mp_int_t ret = stream_p -> ioctl (self -> sock_obj , MP_STREAM_POLL , mode , & errcode );
485
+ mp_int_t ret = stream_p -> ioctl (self -> sock_obj , MP_STREAM_POLL , arg , & errcode );
456
486
return ret != 0 ;
457
487
}
458
488
0 commit comments