Skip to content

Commit b355a55

Browse files
committed
Merge remote-tracking branch 'origin/main'
2 parents d0e42f1 + ed5591c commit b355a55

File tree

7 files changed

+145
-8
lines changed

7 files changed

+145
-8
lines changed

shared-bindings/audiomp3/MP3Decoder.c

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,12 @@
7272
//| decoder.file = stream
7373
//|
7474
//| If the stream is played with ``loop = True``, the loop will start at the beginning.
75+
//|
76+
//| It is possible to stream an mp3 from a socket, including a secure socket.
77+
//| The MP3Decoder may change the timeout and non-blocking status of the socket.
78+
//| Using a larger decode buffer with a stream can be helpful to avoid data underruns.
79+
//| An ``adafruit_requests`` request must be made with ``headers={"Connection": "close"}`` so
80+
//| that the socket closes when the stream ends.
7581
//| """
7682
//| ...
7783

shared-bindings/ssl/SSLSocket.c

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,11 @@
1010
#include <string.h>
1111

1212
#include "shared/runtime/context_manager_helpers.h"
13-
#include "py/objtuple.h"
13+
#include "py/mperrno.h"
1414
#include "py/objlist.h"
15+
#include "py/objtuple.h"
1516
#include "py/runtime.h"
16-
#include "py/mperrno.h"
17+
#include "py/stream.h"
1718

1819
#include "shared/netutils/netutils.h"
1920

@@ -247,9 +248,69 @@ static const mp_rom_map_elem_t ssl_sslsocket_locals_dict_table[] = {
247248

248249
static MP_DEFINE_CONST_DICT(ssl_sslsocket_locals_dict, ssl_sslsocket_locals_dict_table);
249250

251+
typedef mp_uint_t (*readwrite_func)(ssl_sslsocket_obj_t *, const uint8_t *, mp_uint_t);
252+
253+
static mp_int_t readwrite_common(mp_obj_t self_in, readwrite_func fn, const uint8_t *buf, size_t size, int *errorcode) {
254+
ssl_sslsocket_obj_t *self = MP_OBJ_TO_PTR(self_in);
255+
mp_int_t ret = -EIO;
256+
nlr_buf_t nlr;
257+
if (nlr_push(&nlr) == 0) {
258+
ret = fn(self, buf, size);
259+
nlr_pop();
260+
} else {
261+
mp_obj_t exc = MP_OBJ_FROM_PTR(nlr.ret_val);
262+
if (nlr_push(&nlr) == 0) {
263+
ret = -mp_obj_get_int(mp_load_attr(exc, MP_QSTR_errno));
264+
nlr_pop();
265+
}
266+
}
267+
if (ret < 0) {
268+
*errorcode = -ret;
269+
return MP_STREAM_ERROR;
270+
}
271+
return ret;
272+
}
273+
274+
static mp_uint_t sslsocket_read(mp_obj_t self_in, void *buf, mp_uint_t size, int *errorcode) {
275+
return readwrite_common(self_in, (readwrite_func)common_hal_ssl_sslsocket_recv_into, buf, size, errorcode);
276+
}
277+
278+
static mp_uint_t sslsocket_write(mp_obj_t self_in, const void *buf, mp_uint_t size, int *errorcode) {
279+
return readwrite_common(self_in, common_hal_ssl_sslsocket_send, buf, size, errorcode);
280+
}
281+
282+
static mp_uint_t sslsocket_ioctl(mp_obj_t self_in, mp_uint_t request, mp_uint_t arg, int *errcode) {
283+
ssl_sslsocket_obj_t *self = MP_OBJ_TO_PTR(self_in);
284+
mp_uint_t ret;
285+
if (request == MP_STREAM_POLL) {
286+
mp_uint_t flags = arg;
287+
ret = 0;
288+
if ((flags & MP_STREAM_POLL_RD) && common_hal_ssl_sslsocket_readable(self) > 0) {
289+
ret |= MP_STREAM_POLL_RD;
290+
}
291+
if ((flags & MP_STREAM_POLL_WR) && common_hal_ssl_sslsocket_writable(self)) {
292+
ret |= MP_STREAM_POLL_WR;
293+
}
294+
} else {
295+
*errcode = MP_EINVAL;
296+
ret = MP_STREAM_ERROR;
297+
}
298+
return ret;
299+
}
300+
301+
302+
static const mp_stream_p_t sslsocket_stream_p = {
303+
.read = sslsocket_read,
304+
.write = sslsocket_write,
305+
.ioctl = sslsocket_ioctl,
306+
.is_text = false,
307+
};
308+
309+
250310
MP_DEFINE_CONST_OBJ_TYPE(
251311
ssl_sslsocket_type,
252312
MP_QSTR_SSLSocket,
253313
MP_TYPE_FLAG_NONE,
254-
locals_dict, &ssl_sslsocket_locals_dict
314+
locals_dict, &ssl_sslsocket_locals_dict,
315+
protocol, &sslsocket_stream_p
255316
);

shared-bindings/ssl/SSLSocket.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@ void common_hal_ssl_sslsocket_close(ssl_sslsocket_obj_t *self);
2020
void common_hal_ssl_sslsocket_connect(ssl_sslsocket_obj_t *self, mp_obj_t addr);
2121
bool common_hal_ssl_sslsocket_get_closed(ssl_sslsocket_obj_t *self);
2222
bool common_hal_ssl_sslsocket_get_connected(ssl_sslsocket_obj_t *self);
23+
bool common_hal_ssl_sslsocket_readable(ssl_sslsocket_obj_t *self);
24+
bool common_hal_ssl_sslsocket_writable(ssl_sslsocket_obj_t *self);
2325
void common_hal_ssl_sslsocket_listen(ssl_sslsocket_obj_t *self, int backlog);
24-
mp_uint_t common_hal_ssl_sslsocket_recv_into(ssl_sslsocket_obj_t *self, uint8_t *buf, uint32_t len);
25-
mp_uint_t common_hal_ssl_sslsocket_send(ssl_sslsocket_obj_t *self, const uint8_t *buf, uint32_t len);
26+
mp_uint_t common_hal_ssl_sslsocket_recv_into(ssl_sslsocket_obj_t *self, uint8_t *buf, mp_uint_t len);
27+
mp_uint_t common_hal_ssl_sslsocket_send(ssl_sslsocket_obj_t *self, const uint8_t *buf, mp_uint_t len);
2628
void common_hal_ssl_sslsocket_settimeout(ssl_sslsocket_obj_t *self, mp_obj_t timeout_obj);
2729
void common_hal_ssl_sslsocket_setsockopt(ssl_sslsocket_obj_t *self, mp_obj_t level, mp_obj_t optname, mp_obj_t optval);

shared-module/audiomp3/MP3Decoder.c

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,18 @@ static off_t stream_lseek(void *stream, off_t offset, int whence) {
9595
#define INPUT_BUFFER_CONSUME(i, n) ((i).read_off += (n))
9696
#define INPUT_BUFFER_CLEAR(i) ((i).read_off = (i).write_off = 0)
9797

98+
static void stream_set_blocking(audiomp3_mp3file_obj_t *self, bool block_ok) {
99+
if (!self->settimeout_args[0]) {
100+
return;
101+
}
102+
if (block_ok == self->block_ok) {
103+
return;
104+
}
105+
self->block_ok = block_ok;
106+
self->settimeout_args[2] = block_ok ? mp_const_none : mp_obj_new_int(0);
107+
mp_call_method_n_kw(1, 0, self->settimeout_args);
108+
}
109+
98110
/** Fill the input buffer unconditionally.
99111
*
100112
* Returns true if the input buffer contains any useful data,
@@ -110,6 +122,8 @@ static bool mp3file_update_inbuf_always(audiomp3_mp3file_obj_t *self, bool block
110122
return INPUT_BUFFER_AVAILABLE(self->inbuf) > 0;
111123
}
112124

125+
stream_set_blocking(self, block_ok);
126+
113127
// We didn't previously reach EOF and we have input buffer space available
114128

115129
// Move the unconsumed portion of the buffer to the start
@@ -119,7 +133,7 @@ static bool mp3file_update_inbuf_always(audiomp3_mp3file_obj_t *self, bool block
119133
self->inbuf.read_off = 0;
120134
}
121135

122-
for (size_t to_read; !self->eof && (to_read = INPUT_BUFFER_SPACE(self->inbuf)) > 0 && (block_ok || stream_readable(self->stream));) {
136+
for (size_t to_read; !self->eof && (to_read = INPUT_BUFFER_SPACE(self->inbuf)) > 0;) {
123137
uint8_t *write_ptr = self->inbuf.buf + self->inbuf.write_off;
124138
ssize_t n_read = stream_read(self->stream, write_ptr, to_read);
125139

@@ -328,9 +342,14 @@ void common_hal_audiomp3_mp3file_set_file(audiomp3_mp3file_obj_t *self, mp_obj_t
328342
background_callback_prevent();
329343

330344
self->stream = stream;
345+
mp_load_method_maybe(stream, MP_QSTR_settimeout, self->settimeout_args);
331346

332347
INPUT_BUFFER_CLEAR(self->inbuf);
333348
self->eof = 0;
349+
350+
self->block_ok = false;
351+
stream_set_blocking(self, true);
352+
334353
self->other_channel = -1;
335354
mp3file_update_inbuf_half(self, true);
336355
mp3file_find_sync_word(self, true);
@@ -365,6 +384,7 @@ void common_hal_audiomp3_mp3file_deinit(audiomp3_mp3file_obj_t *self) {
365384
self->pcm_buffer[0] = NULL;
366385
self->pcm_buffer[1] = NULL;
367386
self->stream = mp_const_none;
387+
self->settimeout_args[0] = MP_OBJ_NULL;
368388
self->samples_decoded = 0;
369389
}
370390

shared-module/audiomp3/MP3Decoder.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ typedef struct {
3535
uint8_t buffer_index;
3636
uint8_t channel_count;
3737
bool eof;
38+
bool block_ok;
39+
mp_obj_t settimeout_args[3];
3840

3941
int8_t other_channel;
4042
int8_t other_buffer_index;

shared-module/ssl/SSLSocket.c

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222

2323
#include "mbedtls/version.h"
2424

25+
#define MP_STREAM_POLL_RDWR (MP_STREAM_POLL_RD | MP_STREAM_POLL_WR)
26+
2527
#if defined(MBEDTLS_ERROR_C)
2628
#include "../../lib/mbedtls_errors/mp_mbedtls_errors.c"
2729
#endif
@@ -220,6 +222,7 @@ ssl_sslsocket_obj_t *common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t
220222
o->base.type = &ssl_sslsocket_type;
221223
o->ssl_context = self;
222224
o->sock_obj = socket;
225+
o->poll_mask = 0;
223226

224227
mp_load_method(socket, MP_QSTR_accept, o->accept_args);
225228
mp_load_method(socket, MP_QSTR_bind, o->bind_args);
@@ -330,7 +333,8 @@ ssl_sslsocket_obj_t *common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t
330333
}
331334
}
332335

333-
mp_uint_t common_hal_ssl_sslsocket_recv_into(ssl_sslsocket_obj_t *self, uint8_t *buf, uint32_t len) {
336+
mp_uint_t common_hal_ssl_sslsocket_recv_into(ssl_sslsocket_obj_t *self, uint8_t *buf, mp_uint_t len) {
337+
self->poll_mask = 0;
334338
int ret = mbedtls_ssl_read(&self->ssl, buf, len);
335339
DEBUG_PRINT("recv_into mbedtls_ssl_read() -> %d\n", ret);
336340
if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
@@ -342,17 +346,24 @@ mp_uint_t common_hal_ssl_sslsocket_recv_into(ssl_sslsocket_obj_t *self, uint8_t
342346
DEBUG_PRINT("returning %d\n", ret);
343347
return ret;
344348
}
349+
if (ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
350+
self->poll_mask = MP_STREAM_POLL_WR;
351+
}
345352
DEBUG_PRINT("raising errno [error case] %d\n", ret);
346353
mbedtls_raise_error(ret);
347354
}
348355

349-
mp_uint_t common_hal_ssl_sslsocket_send(ssl_sslsocket_obj_t *self, const uint8_t *buf, uint32_t len) {
356+
mp_uint_t common_hal_ssl_sslsocket_send(ssl_sslsocket_obj_t *self, const uint8_t *buf, mp_uint_t len) {
357+
self->poll_mask = 0;
350358
int ret = mbedtls_ssl_write(&self->ssl, buf, len);
351359
DEBUG_PRINT("send mbedtls_ssl_write() -> %d\n", ret);
352360
if (ret >= 0) {
353361
DEBUG_PRINT("returning %d\n", ret);
354362
return ret;
355363
}
364+
if (ret == MBEDTLS_ERR_SSL_WANT_READ) {
365+
self->poll_mask = MP_STREAM_POLL_RD;
366+
}
356367
DEBUG_PRINT("raising errno [error case] %d\n", ret);
357368
mbedtls_raise_error(ret);
358369
}
@@ -448,3 +459,37 @@ void common_hal_ssl_sslsocket_setsockopt(ssl_sslsocket_obj_t *self, mp_obj_t lev
448459
void common_hal_ssl_sslsocket_settimeout(ssl_sslsocket_obj_t *self, mp_obj_t timeout_obj) {
449460
ssl_socket_settimeout(self, timeout_obj);
450461
}
462+
463+
static bool poll_common(ssl_sslsocket_obj_t *self, uintptr_t arg) {
464+
// Take into account that the library might have buffered data already
465+
int has_pending = 0;
466+
if (arg & MP_STREAM_POLL_RD) {
467+
has_pending = mbedtls_ssl_check_pending(&self->ssl);
468+
if (has_pending) {
469+
// Shortcut if we only need to read and we have buffered data, no need to go to the underlying socket
470+
return true;
471+
}
472+
}
473+
474+
// If the library signaled us that it needs reading or writing, only
475+
// check that direction
476+
if (self->poll_mask && (arg & MP_STREAM_POLL_RDWR)) {
477+
arg = (arg & ~MP_STREAM_POLL_RDWR) | self->poll_mask;
478+
}
479+
480+
// If direction the library needed is available, return a fake
481+
// result to the caller so that it reenters a read or a write to
482+
// allow the handshake to progress
483+
const mp_stream_p_t *stream_p = mp_get_stream_raise(self->sock_obj, MP_STREAM_OP_IOCTL);
484+
int errcode;
485+
mp_int_t ret = stream_p->ioctl(self->sock_obj, MP_STREAM_POLL, arg, &errcode);
486+
return ret != 0;
487+
}
488+
489+
bool common_hal_ssl_sslsocket_readable(ssl_sslsocket_obj_t *self) {
490+
return poll_common(self, MP_STREAM_POLL_RD);
491+
}
492+
493+
bool common_hal_ssl_sslsocket_writable(ssl_sslsocket_obj_t *self) {
494+
return poll_common(self, MP_STREAM_POLL_WR);
495+
}

shared-module/ssl/SSLSocket.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ typedef struct ssl_sslsocket_obj {
2929
mbedtls_x509_crt cacert;
3030
mbedtls_x509_crt cert;
3131
mbedtls_pk_context pkey;
32+
uintptr_t poll_mask;
3233
bool closed;
3334
mp_obj_t accept_args[2];
3435
mp_obj_t bind_args[3];

0 commit comments

Comments
 (0)