Skip to content

Commit d50079c

Browse files
authored
Merge pull request #5716 from timhawes/ssl_improvements
SSL improvements
2 parents 0292622 + 54e87d3 commit d50079c

File tree

5 files changed

+111
-9
lines changed

5 files changed

+111
-9
lines changed

ports/espressif/common-hal/ssl/SSLContext.c

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929

3030
#include "bindings/espidf/__init__.h"
3131

32+
#include "components/mbedtls/esp_crt_bundle/include/esp_crt_bundle.h"
33+
3234
#include "py/runtime.h"
3335

3436
void common_hal_ssl_sslcontext_construct(ssl_sslcontext_obj_t *self) {
@@ -47,6 +49,11 @@ ssl_sslsocket_obj_t *common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t
4749
sock->ssl_context = self;
4850
sock->sock = socket;
4951

52+
// Create a copy of the ESP-TLS config object and store the server hostname
53+
// Note that ESP-TLS will use common_name for both SNI and verification
54+
memcpy(&sock->ssl_config, &self->ssl_config, sizeof(self->ssl_config));
55+
sock->ssl_config.common_name = server_hostname;
56+
5057
esp_tls_t *tls_handle = esp_tls_init();
5158
if (tls_handle == NULL) {
5259
mp_raise_espidf_MemoryError();
@@ -55,6 +62,28 @@ ssl_sslsocket_obj_t *common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t
5562

5663
// TODO: do something with the original socket? Don't call a close on the internal LWIP.
5764

58-
// Should we store server hostname on the socket in case connect is called with an ip?
5965
return sock;
6066
}
67+
68+
void common_hal_ssl_sslcontext_load_verify_locations(ssl_sslcontext_obj_t *self,
69+
const char *cadata) {
70+
self->ssl_config.crt_bundle_attach = NULL;
71+
self->ssl_config.use_global_ca_store = false;
72+
self->ssl_config.cacert_buf = (const unsigned char *)cadata;
73+
self->ssl_config.cacert_bytes = strlen(cadata) + 1;
74+
}
75+
76+
void common_hal_ssl_sslcontext_set_default_verify_paths(ssl_sslcontext_obj_t *self) {
77+
self->ssl_config.crt_bundle_attach = esp_crt_bundle_attach;
78+
self->ssl_config.use_global_ca_store = true;
79+
self->ssl_config.cacert_buf = NULL;
80+
self->ssl_config.cacert_bytes = 0;
81+
}
82+
83+
bool common_hal_ssl_sslcontext_get_check_hostname(ssl_sslcontext_obj_t *self) {
84+
return !self->ssl_config.skip_common_name;
85+
}
86+
87+
void common_hal_ssl_sslcontext_set_check_hostname(ssl_sslcontext_obj_t *self, bool value) {
88+
self->ssl_config.skip_common_name = !value;
89+
}

ports/espressif/common-hal/ssl/SSLSocket.c

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,7 @@ void common_hal_ssl_sslsocket_close(ssl_sslsocket_obj_t *self) {
5555

5656
void common_hal_ssl_sslsocket_connect(ssl_sslsocket_obj_t *self,
5757
const char *host, size_t hostlen, uint32_t port) {
58-
esp_tls_cfg_t *tls_config = NULL;
59-
tls_config = &self->ssl_context->ssl_config;
60-
int result = esp_tls_conn_new_sync(host, hostlen, port, tls_config, self->tls);
58+
int result = esp_tls_conn_new_sync(host, hostlen, port, &self->ssl_config, self->tls);
6159
self->sock->connected = result >= 0;
6260
if (result < 0) {
6361
int esp_tls_code;

ports/espressif/common-hal/ssl/SSLSocket.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ typedef struct {
3939
socketpool_socket_obj_t *sock;
4040
esp_tls_t *tls;
4141
ssl_sslcontext_obj_t *ssl_context;
42+
esp_tls_cfg_t ssl_config;
4243
} ssl_sslsocket_obj_t;
4344

4445
#endif // MICROPY_INCLUDED_ESPRESSIF_COMMON_HAL_SSL_SSLSOCKET_H

shared-bindings/ssl/SSLContext.c

Lines changed: 71 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
#include "py/objtuple.h"
3131
#include "py/objlist.h"
32+
#include "py/objproperty.h"
3233
#include "py/runtime.h"
3334
#include "py/mperrno.h"
3435

@@ -51,10 +52,69 @@ STATIC mp_obj_t ssl_sslcontext_make_new(const mp_obj_type_t *type, size_t n_args
5152
return MP_OBJ_FROM_PTR(s);
5253
}
5354

54-
//| def wrap_socket(sock: socketpool.Socket, *, server_side: bool = False, server_hostname: Optional[str] = None) -> ssl.SSLSocket:
55-
//| """Wraps the socket into a socket-compatible class that handles SSL negotiation.
56-
//| The socket must be of type SOCK_STREAM."""
57-
//| ...
55+
//| def load_verify_locations(self, cadata: Optional[str] = None) -> None:
56+
//| """Load a set of certification authority (CA) certificates used to validate
57+
//| other peers' certificates."""
58+
//|
59+
60+
STATIC mp_obj_t ssl_sslcontext_load_verify_locations(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
61+
enum { ARG_cadata };
62+
static const mp_arg_t allowed_args[] = {
63+
{ MP_QSTR_cadata, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_obj = mp_const_none} },
64+
};
65+
ssl_sslcontext_obj_t *self = MP_OBJ_TO_PTR(pos_args[0]);
66+
67+
mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
68+
mp_arg_parse_all(n_args - 1, pos_args + 1, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);
69+
70+
const char *cadata = mp_obj_str_get_str(args[ARG_cadata].u_obj);
71+
72+
common_hal_ssl_sslcontext_load_verify_locations(self, cadata);
73+
return mp_const_none;
74+
}
75+
STATIC MP_DEFINE_CONST_FUN_OBJ_KW(ssl_sslcontext_load_verify_locations_obj, 1, ssl_sslcontext_load_verify_locations);
76+
77+
//| def set_default_verify_paths(self) -> None:
78+
//| """Load a set of default certification authority (CA) certificates."""
79+
//|
80+
81+
STATIC mp_obj_t ssl_sslcontext_set_default_verify_paths(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
82+
ssl_sslcontext_obj_t *self = MP_OBJ_TO_PTR(pos_args[0]);
83+
84+
common_hal_ssl_sslcontext_set_default_verify_paths(self);
85+
return mp_const_none;
86+
}
87+
STATIC MP_DEFINE_CONST_FUN_OBJ_KW(ssl_sslcontext_set_default_verify_paths_obj, 1, ssl_sslcontext_set_default_verify_paths);
88+
89+
//| check_hostname: bool
90+
//| """Whether to match the peer certificate's hostname."""
91+
//|
92+
93+
STATIC mp_obj_t ssl_sslcontext_get_check_hostname(mp_obj_t self_in) {
94+
ssl_sslcontext_obj_t *self = MP_OBJ_TO_PTR(self_in);
95+
96+
return mp_obj_new_bool(common_hal_ssl_sslcontext_get_check_hostname(self));
97+
}
98+
STATIC MP_DEFINE_CONST_FUN_OBJ_1(ssl_sslcontext_get_check_hostname_obj, ssl_sslcontext_get_check_hostname);
99+
100+
STATIC mp_obj_t ssl_sslcontext_set_check_hostname(mp_obj_t self_in, mp_obj_t value) {
101+
ssl_sslcontext_obj_t *self = MP_OBJ_TO_PTR(self_in);
102+
103+
common_hal_ssl_sslcontext_set_check_hostname(self, mp_obj_is_true(value));
104+
return mp_const_none;
105+
}
106+
STATIC MP_DEFINE_CONST_FUN_OBJ_2(ssl_sslcontext_set_check_hostname_obj, ssl_sslcontext_set_check_hostname);
107+
108+
const mp_obj_property_t ssl_sslcontext_check_hostname_obj = {
109+
.base.type = &mp_type_property,
110+
.proxy = {(mp_obj_t)&ssl_sslcontext_get_check_hostname_obj,
111+
(mp_obj_t)&ssl_sslcontext_set_check_hostname_obj,
112+
MP_ROM_NONE},
113+
};
114+
115+
//| def wrap_socket(self, sock: socketpool.Socket, *, server_side: bool = False, server_hostname: Optional[str] = None) -> ssl.SSLSocket:
116+
//| """Wraps the socket into a socket-compatible class that handles SSL negotiation.
117+
//| The socket must be of type SOCK_STREAM."""
58118
//|
59119

60120
STATIC mp_obj_t ssl_sslcontext_wrap_socket(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
@@ -69,7 +129,10 @@ STATIC mp_obj_t ssl_sslcontext_wrap_socket(size_t n_args, const mp_obj_t *pos_ar
69129
mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
70130
mp_arg_parse_all(n_args - 1, pos_args + 1, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);
71131

72-
const char *server_hostname = mp_obj_str_get_str(args[ARG_server_hostname].u_obj);
132+
const char *server_hostname = NULL;
133+
if (args[ARG_server_hostname].u_obj != mp_const_none) {
134+
server_hostname = mp_obj_str_get_str(args[ARG_server_hostname].u_obj);
135+
}
73136
bool server_side = args[ARG_server_side].u_bool;
74137
if (server_side && server_hostname != NULL) {
75138
mp_raise_ValueError(translate("Server side context cannot have hostname"));
@@ -83,6 +146,9 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_KW(ssl_sslcontext_wrap_socket_obj, 1, ssl_sslcont
83146

84147
STATIC const mp_rom_map_elem_t ssl_sslcontext_locals_dict_table[] = {
85148
{ MP_ROM_QSTR(MP_QSTR_wrap_socket), MP_ROM_PTR(&ssl_sslcontext_wrap_socket_obj) },
149+
{ MP_ROM_QSTR(MP_QSTR_load_verify_locations), MP_ROM_PTR(&ssl_sslcontext_load_verify_locations_obj) },
150+
{ MP_ROM_QSTR(MP_QSTR_set_default_verify_paths), MP_ROM_PTR(&ssl_sslcontext_set_default_verify_paths_obj) },
151+
{ MP_ROM_QSTR(MP_QSTR_check_hostname), MP_ROM_PTR(&ssl_sslcontext_check_hostname_obj) },
86152
};
87153

88154
STATIC MP_DEFINE_CONST_DICT(ssl_sslcontext_locals_dict, ssl_sslcontext_locals_dict_table);

shared-bindings/ssl/SSLContext.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,12 @@ void common_hal_ssl_sslcontext_construct(ssl_sslcontext_obj_t *self);
3939
ssl_sslsocket_obj_t *common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t *self,
4040
socketpool_socket_obj_t *sock, bool server_side, const char *server_hostname);
4141

42+
void common_hal_ssl_sslcontext_load_verify_locations(ssl_sslcontext_obj_t *self,
43+
const char *cadata);
44+
45+
void common_hal_ssl_sslcontext_set_default_verify_paths(ssl_sslcontext_obj_t *self);
46+
47+
bool common_hal_ssl_sslcontext_get_check_hostname(ssl_sslcontext_obj_t *self);
48+
void common_hal_ssl_sslcontext_set_check_hostname(ssl_sslcontext_obj_t *self, bool value);
49+
4250
#endif // MICROPY_INCLUDED_SHARED_BINDINGS_SSL_SSLCONTEXT_H

0 commit comments

Comments
 (0)