Skip to content

Commit 085f278

Browse files
committed
Fix race conditions in the SSE and WS code
1 parent b29465d commit 085f278

File tree

4 files changed

+57
-20
lines changed

4 files changed

+57
-20
lines changed

src/AsyncEventSource.cpp

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -192,23 +192,24 @@ AsyncEventSourceClient::AsyncEventSourceClient(AsyncWebServerRequest *request, A
192192

193193
AsyncEventSourceClient::~AsyncEventSourceClient() {
194194
#ifdef ESP32
195+
// Protect message queue access (size checks and modifications) which is not thread-safe.
195196
std::lock_guard<std::recursive_mutex> lock(_lockmq);
196197
#endif
197198
_messageQueue.clear();
198199
close();
199200
}
200201

201202
bool AsyncEventSourceClient::_queueMessage(const char *message, size_t len) {
203+
#ifdef ESP32
204+
// Protect message queue access (size checks and modifications) which is not thread-safe.
205+
std::lock_guard<std::recursive_mutex> lock(_lockmq);
206+
#endif
207+
202208
if (_messageQueue.size() >= SSE_MAX_QUEUED_MESSAGES) {
203209
async_ws_log_w("Event message queue overflow: discard message");
204210
return false;
205211
}
206212

207-
#ifdef ESP32
208-
// length() is not thread-safe, thus acquiring the lock before this call..
209-
std::lock_guard<std::recursive_mutex> lock(_lockmq);
210-
#endif
211-
212213
if (_client) {
213214
_messageQueue.emplace_back(message, len);
214215
} else {
@@ -230,16 +231,16 @@ bool AsyncEventSourceClient::_queueMessage(const char *message, size_t len) {
230231
}
231232

232233
bool AsyncEventSourceClient::_queueMessage(AsyncEvent_SharedData_t &&msg) {
234+
#ifdef ESP32
235+
// Protect message queue access (size checks and modifications) which is not thread-safe.
236+
std::lock_guard<std::recursive_mutex> lock(_lockmq);
237+
#endif
238+
233239
if (_messageQueue.size() >= SSE_MAX_QUEUED_MESSAGES) {
234240
async_ws_log_w("Event message queue overflow: discard message");
235241
return false;
236242
}
237243

238-
#ifdef ESP32
239-
// length() is not thread-safe, thus acquiring the lock before this call..
240-
std::lock_guard<std::recursive_mutex> lock(_lockmq);
241-
#endif
242-
243244
if (_client) {
244245
_messageQueue.emplace_back(std::move(msg));
245246
} else {
@@ -261,7 +262,7 @@ bool AsyncEventSourceClient::_queueMessage(AsyncEvent_SharedData_t &&msg) {
261262

262263
void AsyncEventSourceClient::_onAck(size_t len __attribute__((unused)), uint32_t time __attribute__((unused))) {
263264
#ifdef ESP32
264-
// Same here, acquiring the lock early
265+
// Protect message queue access (size checks and modifications) which is not thread-safe.
265266
std::lock_guard<std::recursive_mutex> lock(_lockmq);
266267
#endif
267268

@@ -288,11 +289,11 @@ void AsyncEventSourceClient::_onAck(size_t len __attribute__((unused)), uint32_t
288289
}
289290

290291
void AsyncEventSourceClient::_onPoll() {
291-
if (_messageQueue.size()) {
292292
#ifdef ESP32
293-
// Same here, acquiring the lock early
294-
std::lock_guard<std::recursive_mutex> lock(_lockmq);
293+
// Protect message queue access (size checks and modifications) which is not thread-safe.
294+
std::lock_guard<std::recursive_mutex> lock(_lockmq);
295295
#endif
296+
if (_messageQueue.size()) {
296297
_runQueue();
297298
}
298299
}
@@ -379,12 +380,12 @@ void AsyncEventSource::_addClient(AsyncEventSourceClient *client) {
379380
}
380381

381382
void AsyncEventSource::_handleDisconnect(AsyncEventSourceClient *client) {
382-
if (_disconnectcb) {
383-
_disconnectcb(client);
384-
}
385383
#ifdef ESP32
386384
std::lock_guard<std::recursive_mutex> lock(_client_queue_lock);
387385
#endif
386+
if (_disconnectcb) {
387+
_disconnectcb(client);
388+
}
388389
for (auto i = _clients.begin(); i != _clients.end(); ++i) {
389390
if (i->get() == client) {
390391
_clients.erase(i);

src/AsyncEventSource.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,9 @@ class AsyncEventSourceClient {
205205
return _lastId;
206206
}
207207
size_t packetsWaiting() const {
208+
#ifdef ESP32
209+
std::lock_guard<std::recursive_mutex> lock(_lockmq);
210+
#endif
208211
return _messageQueue.size();
209212
};
210213

src/AsyncWebSocket.cpp

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,12 +309,12 @@ void AsyncWebSocketClient::_clearQueue() {
309309
void AsyncWebSocketClient::_onAck(size_t len, uint32_t time) {
310310
_lastMessageTime = millis();
311311

312-
async_ws_log_v("[%s][%" PRIu32 "] START ACK(%u, %" PRIu32 ") Q:%u", _server->url(), _clientId, len, time, _messageQueue.size());
313-
314312
#ifdef ESP32
315313
std::unique_lock<std::recursive_mutex> lock(_lock);
316314
#endif
317315

316+
async_ws_log_v("[%s][%" PRIu32 "] START ACK(%u, %" PRIu32 ") Q:%u", _server->url(), _clientId, len, time, _messageQueue.size());
317+
318318
if (!_controlQueue.empty()) {
319319
auto &head = _controlQueue.front();
320320
if (head.finished()) {
@@ -988,6 +988,9 @@ void AsyncWebSocket::_handleEvent(AsyncWebSocketClient *client, AwsEventType typ
988988
}
989989

990990
AsyncWebSocketClient *AsyncWebSocket::_newClient(AsyncWebServerRequest *request) {
991+
#ifdef ESP32
992+
std::lock_guard<std::recursive_mutex> lock(_lock);
993+
#endif
991994
_clients.emplace_back(request, this);
992995
// we've just detached AsyncTCP client from AsyncWebServerRequest
993996
_handleEvent(&_clients.back(), WS_EVT_CONNECT, request, NULL, 0);
@@ -997,6 +1000,9 @@ AsyncWebSocketClient *AsyncWebSocket::_newClient(AsyncWebServerRequest *request)
9971000
}
9981001

9991002
void AsyncWebSocket::_handleDisconnect(AsyncWebSocketClient *client) {
1003+
#ifdef ESP32
1004+
std::lock_guard<std::recursive_mutex> lock(_lock);
1005+
#endif
10001006
const auto client_id = client->id();
10011007
const auto iter = std::find_if(std::begin(_clients), std::end(_clients), [client_id](const AsyncWebSocketClient &c) {
10021008
return c.id() == client_id;
@@ -1007,12 +1013,18 @@ void AsyncWebSocket::_handleDisconnect(AsyncWebSocketClient *client) {
10071013
}
10081014

10091015
bool AsyncWebSocket::availableForWriteAll() {
1016+
#ifdef ESP32
1017+
std::lock_guard<std::recursive_mutex> lock(_lock);
1018+
#endif
10101019
return std::none_of(std::begin(_clients), std::end(_clients), [](const AsyncWebSocketClient &c) {
10111020
return c.queueIsFull();
10121021
});
10131022
}
10141023

10151024
bool AsyncWebSocket::availableForWrite(uint32_t id) {
1025+
#ifdef ESP32
1026+
std::lock_guard<std::recursive_mutex> lock(_lock);
1027+
#endif
10161028
const auto iter = std::find_if(std::begin(_clients), std::end(_clients), [id](const AsyncWebSocketClient &c) {
10171029
return c.id() == id;
10181030
});
@@ -1023,12 +1035,18 @@ bool AsyncWebSocket::availableForWrite(uint32_t id) {
10231035
}
10241036

10251037
size_t AsyncWebSocket::count() const {
1038+
#ifdef ESP32
1039+
std::lock_guard<std::recursive_mutex> lock(_lock);
1040+
#endif
10261041
return std::count_if(std::begin(_clients), std::end(_clients), [](const AsyncWebSocketClient &c) {
10271042
return c.status() == WS_CONNECTED;
10281043
});
10291044
}
10301045

10311046
AsyncWebSocketClient *AsyncWebSocket::client(uint32_t id) {
1047+
#ifdef ESP32
1048+
std::lock_guard<std::recursive_mutex> lock(_lock);
1049+
#endif
10321050
const auto iter = std::find_if(_clients.begin(), _clients.end(), [id](const AsyncWebSocketClient &c) {
10331051
return c.id() == id && c.status() == WS_CONNECTED;
10341052
});
@@ -1046,6 +1064,9 @@ void AsyncWebSocket::close(uint32_t id, uint16_t code, const char *message) {
10461064
}
10471065

10481066
void AsyncWebSocket::closeAll(uint16_t code, const char *message) {
1067+
#ifdef ESP32
1068+
std::lock_guard<std::recursive_mutex> lock(_lock);
1069+
#endif
10491070
for (auto &c : _clients) {
10501071
if (c.status() == WS_CONNECTED) {
10511072
c.close(code, message);
@@ -1054,6 +1075,9 @@ void AsyncWebSocket::closeAll(uint16_t code, const char *message) {
10541075
}
10551076

10561077
void AsyncWebSocket::cleanupClients(uint16_t maxClients) {
1078+
#ifdef ESP32
1079+
std::lock_guard<std::recursive_mutex> lock(_lock);
1080+
#endif
10571081
const size_t c = count();
10581082
if (c > maxClients) {
10591083
async_ws_log_v("[%s] CLEANUP %" PRIu32 " (%u/%" PRIu16 ")", _url.c_str(), _clients.front().id(), c, maxClients);
@@ -1074,6 +1098,9 @@ bool AsyncWebSocket::ping(uint32_t id, const uint8_t *data, size_t len) {
10741098
}
10751099

10761100
AsyncWebSocket::SendStatus AsyncWebSocket::pingAll(const uint8_t *data, size_t len) {
1101+
#ifdef ESP32
1102+
std::lock_guard<std::recursive_mutex> lock(_lock);
1103+
#endif
10771104
size_t hit = 0;
10781105
size_t miss = 0;
10791106
for (auto &c : _clients) {
@@ -1182,6 +1209,9 @@ AsyncWebSocket::SendStatus AsyncWebSocket::textAll(AsyncWebSocketMessageBuffer *
11821209
}
11831210

11841211
AsyncWebSocket::SendStatus AsyncWebSocket::textAll(AsyncWebSocketSharedBuffer buffer) {
1212+
#ifdef ESP32
1213+
std::lock_guard<std::recursive_mutex> lock(_lock);
1214+
#endif
11851215
size_t hit = 0;
11861216
size_t miss = 0;
11871217
for (auto &c : _clients) {
@@ -1271,6 +1301,9 @@ AsyncWebSocket::SendStatus AsyncWebSocket::binaryAll(AsyncWebSocketMessageBuffer
12711301
return status;
12721302
}
12731303
AsyncWebSocket::SendStatus AsyncWebSocket::binaryAll(AsyncWebSocketSharedBuffer buffer) {
1304+
#ifdef ESP32
1305+
std::lock_guard<std::recursive_mutex> lock(_lock);
1306+
#endif
12741307
size_t hit = 0;
12751308
size_t miss = 0;
12761309
for (auto &c : _clients) {

src/AsyncWebSocket.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ class AsyncWebSocket : public AsyncWebHandler {
373373
AwsHandshakeHandler _handshakeHandler;
374374
bool _enabled;
375375
#ifdef ESP32
376-
mutable std::mutex _lock;
376+
mutable std::recursive_mutex _lock;
377377
#endif
378378

379379
public:

0 commit comments

Comments
 (0)