3434
3535#include "py/runtime.h"
3636#include "py/stream.h"
37+ #include "py/objstr.h"
3738
3839// mbedtls_time_t
3940#include "mbedtls/platform.h"
4344#include "mbedtls/entropy.h"
4445#include "mbedtls/ctr_drbg.h"
4546#include "mbedtls/debug.h"
47+ #include "mbedtls/error.h"
4648
4749typedef struct _mp_obj_ssl_socket_t {
4850 mp_obj_base_t base ;
@@ -74,8 +76,48 @@ STATIC void mbedtls_debug(void *ctx, int level, const char *file, int line, cons
7476}
7577#endif
7678
79+ STATIC NORETURN void mbedtls_raise_error (int err ) {
80+ // _mbedtls_ssl_send and _mbedtls_ssl_recv (below) turn positive error codes from the
81+ // underlying socket into negative codes to pass them through mbedtls. Here we turn them
82+ // positive again so they get interpreted as the OSError they really are. The
83+ // cut-off of -256 is a bit hacky, sigh.
84+ if (err < 0 && err > -256 ) {
85+ mp_raise_OSError (- err );
86+ }
87+
88+ #if defined(MBEDTLS_ERROR_C )
89+ // Including mbedtls_strerror takes about 1.5KB due to the error strings.
90+ // MBEDTLS_ERROR_C is the define used by mbedtls to conditionally include mbedtls_strerror.
91+ // It is set/unset in the MBEDTLS_CONFIG_FILE which is defined in the Makefile.
92+
93+ // Try to allocate memory for the message
94+ #define ERR_STR_MAX 80 // mbedtls_strerror truncates if it doesn't fit
95+ mp_obj_str_t * o_str = m_new_obj_maybe (mp_obj_str_t );
96+ byte * o_str_buf = m_new_maybe (byte , ERR_STR_MAX );
97+ if (o_str == NULL || o_str_buf == NULL ) {
98+ mp_raise_OSError (err );
99+ }
100+
101+ // print the error message into the allocated buffer
102+ mbedtls_strerror (err , (char * )o_str_buf , ERR_STR_MAX );
103+ size_t len = strlen ((char * )o_str_buf );
104+
105+ // Put the exception object together
106+ o_str -> base .type = & mp_type_str ;
107+ o_str -> data = o_str_buf ;
108+ o_str -> len = len ;
109+ o_str -> hash = qstr_compute_hash (o_str -> data , o_str -> len );
110+ // raise
111+ mp_obj_t args [2 ] = { MP_OBJ_NEW_SMALL_INT (err ), MP_OBJ_FROM_PTR (o_str )};
112+ nlr_raise (mp_obj_exception_make_new (& mp_type_OSError , 2 , 0 , args ));
113+ #else
114+ // mbedtls is compiled without error strings so we simply return the err number
115+ mp_raise_OSError (err ); // err is typically a large negative number
116+ #endif
117+ }
118+
77119STATIC int _mbedtls_ssl_send (void * ctx , const byte * buf , size_t len ) {
78- mp_obj_t sock = * (mp_obj_t * )ctx ;
120+ mp_obj_t sock = * (mp_obj_t * )ctx ;
79121
80122 const mp_stream_p_t * sock_stream = mp_get_stream (sock );
81123 int err ;
@@ -85,14 +127,14 @@ STATIC int _mbedtls_ssl_send(void *ctx, const byte *buf, size_t len) {
85127 if (mp_is_nonblocking_error (err )) {
86128 return MBEDTLS_ERR_SSL_WANT_WRITE ;
87129 }
88- return - err ;
130+ return - err ; // convert an MP_ERRNO to something mbedtls passes through as error
89131 } else {
90132 return out_sz ;
91133 }
92134}
93135
94136STATIC int _mbedtls_ssl_recv (void * ctx , byte * buf , size_t len ) {
95- mp_obj_t sock = * (mp_obj_t * )ctx ;
137+ mp_obj_t sock = * (mp_obj_t * )ctx ;
96138
97139 const mp_stream_p_t * sock_stream = mp_get_stream (sock );
98140 int err ;
@@ -113,11 +155,11 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
113155 // Verify the socket object has the full stream protocol
114156 mp_get_stream_raise (sock , MP_STREAM_OP_READ | MP_STREAM_OP_WRITE | MP_STREAM_OP_IOCTL );
115157
116- #if MICROPY_PY_USSL_FINALISER
158+ #if MICROPY_PY_USSL_FINALISER
117159 mp_obj_ssl_socket_t * o = m_new_obj_with_finaliser (mp_obj_ssl_socket_t );
118- #else
160+ #else
119161 mp_obj_ssl_socket_t * o = m_new_obj (mp_obj_ssl_socket_t );
120- #endif
162+ #endif
121163 o -> base .type = & ussl_socket_type ;
122164 o -> sock = sock ;
123165
@@ -141,9 +183,9 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
141183 }
142184
143185 ret = mbedtls_ssl_config_defaults (& o -> conf ,
144- args -> server_side .u_bool ? MBEDTLS_SSL_IS_SERVER : MBEDTLS_SSL_IS_CLIENT ,
145- MBEDTLS_SSL_TRANSPORT_STREAM ,
146- MBEDTLS_SSL_PRESET_DEFAULT );
186+ args -> server_side .u_bool ? MBEDTLS_SSL_IS_SERVER : MBEDTLS_SSL_IS_CLIENT ,
187+ MBEDTLS_SSL_TRANSPORT_STREAM ,
188+ MBEDTLS_SSL_PRESET_DEFAULT );
147189 if (ret != 0 ) {
148190 goto cleanup ;
149191 }
@@ -171,7 +213,7 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
171213
172214 if (args -> key .u_obj != mp_const_none ) {
173215 size_t key_len ;
174- const byte * key = (const byte * )mp_obj_str_get_data (args -> key .u_obj , & key_len );
216+ const byte * key = (const byte * )mp_obj_str_get_data (args -> key .u_obj , & key_len );
175217 // len should include terminating null
176218 ret = mbedtls_pk_parse_key (& o -> pkey , key , key_len + 1 , NULL , 0 );
177219 if (ret != 0 ) {
@@ -180,7 +222,7 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
180222 }
181223
182224 size_t cert_len ;
183- const byte * cert = (const byte * )mp_obj_str_get_data (args -> cert .u_obj , & cert_len );
225+ const byte * cert = (const byte * )mp_obj_str_get_data (args -> cert .u_obj , & cert_len );
184226 // len should include terminating null
185227 ret = mbedtls_x509_crt_parse (& o -> cert , cert , cert_len + 1 );
186228 if (ret != 0 ) {
@@ -197,7 +239,6 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
197239 if (args -> do_handshake .u_bool ) {
198240 while ((ret = mbedtls_ssl_handshake (& o -> ssl )) != 0 ) {
199241 if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE ) {
200- printf ("mbedtls_ssl_handshake error: -%x\n" , - ret );
201242 goto cleanup ;
202243 }
203244 }
@@ -217,11 +258,11 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
217258 if (ret == MBEDTLS_ERR_SSL_ALLOC_FAILED ) {
218259 mp_raise_OSError (MP_ENOMEM );
219260 } else if (ret == MBEDTLS_ERR_PK_BAD_INPUT_DATA ) {
220- mp_raise_ValueError ("invalid key" );
261+ mp_raise_ValueError (MP_ERROR_TEXT ( "invalid key" ) );
221262 } else if (ret == MBEDTLS_ERR_X509_BAD_INPUT_DATA ) {
222- mp_raise_ValueError ("invalid cert" );
263+ mp_raise_ValueError (MP_ERROR_TEXT ( "invalid cert" ) );
223264 } else {
224- mp_raise_OSError ( MP_EIO );
265+ mbedtls_raise_error ( ret );
225266 }
226267}
227268
@@ -230,7 +271,7 @@ STATIC mp_obj_t mod_ssl_getpeercert(mp_obj_t o_in, mp_obj_t binary_form) {
230271 if (!mp_obj_is_true (binary_form )) {
231272 mp_raise_NotImplementedError (NULL );
232273 }
233- const mbedtls_x509_crt * peer_cert = mbedtls_ssl_get_peer_cert (& o -> ssl );
274+ const mbedtls_x509_crt * peer_cert = mbedtls_ssl_get_peer_cert (& o -> ssl );
234275 if (peer_cert == NULL ) {
235276 return mp_const_none ;
236277 }
@@ -318,9 +359,9 @@ STATIC const mp_rom_map_elem_t ussl_socket_locals_dict_table[] = {
318359 { MP_ROM_QSTR (MP_QSTR_write ), MP_ROM_PTR (& mp_stream_write_obj ) },
319360 { MP_ROM_QSTR (MP_QSTR_setblocking ), MP_ROM_PTR (& socket_setblocking_obj ) },
320361 { MP_ROM_QSTR (MP_QSTR_close ), MP_ROM_PTR (& mp_stream_close_obj ) },
321- #if MICROPY_PY_USSL_FINALISER
362+ #if MICROPY_PY_USSL_FINALISER
322363 { MP_ROM_QSTR (MP_QSTR___del__ ), MP_ROM_PTR (& mp_stream_close_obj ) },
323- #endif
364+ #endif
324365 { MP_ROM_QSTR (MP_QSTR_getpeercert ), MP_ROM_PTR (& mod_ssl_getpeercert_obj ) },
325366};
326367
@@ -340,7 +381,7 @@ STATIC const mp_obj_type_t ussl_socket_type = {
340381 .getiter = NULL ,
341382 .iternext = NULL ,
342383 .protocol = & ussl_socket_stream_p ,
343- .locals_dict = (void * )& ussl_socket_locals_dict ,
384+ .locals_dict = (void * )& ussl_socket_locals_dict ,
344385};
345386
346387STATIC mp_obj_t mod_ssl_wrap_socket (size_t n_args , const mp_obj_t * pos_args , mp_map_t * kw_args ) {
@@ -358,7 +399,7 @@ STATIC mp_obj_t mod_ssl_wrap_socket(size_t n_args, const mp_obj_t *pos_args, mp_
358399
359400 struct ssl_args args ;
360401 mp_arg_parse_all (n_args - 1 , pos_args + 1 , kw_args ,
361- MP_ARRAY_SIZE (allowed_args ), allowed_args , (mp_arg_val_t * )& args );
402+ MP_ARRAY_SIZE (allowed_args ), allowed_args , (mp_arg_val_t * )& args );
362403
363404 return MP_OBJ_FROM_PTR (socket_new (sock , & args ));
364405}
@@ -373,7 +414,7 @@ STATIC MP_DEFINE_CONST_DICT(mp_module_ssl_globals, mp_module_ssl_globals_table);
373414
374415const mp_obj_module_t mp_module_ussl = {
375416 .base = { & mp_type_module },
376- .globals = (mp_obj_dict_t * )& mp_module_ssl_globals ,
417+ .globals = (mp_obj_dict_t * )& mp_module_ssl_globals ,
377418};
378419
379420#endif // MICROPY_PY_USSL
0 commit comments