Skip to content

Commit 1455d59

Browse files
committed
Add safety checks around missed allocations for AsyncWebSocketMessageBuffer
1 parent ed538f9 commit 1455d59

File tree

1 file changed

+45
-11
lines changed

1 file changed

+45
-11
lines changed

src/AsyncWebSocket.cpp

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -135,12 +135,21 @@ AsyncWebSocketMessageBuffer::AsyncWebSocketMessageBuffer()
135135
AsyncWebSocketMessageBuffer::AsyncWebSocketMessageBuffer(uint8_t* data, size_t size)
136136
: _buffer(std::make_shared<std::vector<uint8_t>>(size))
137137
{
138-
std::memcpy(_buffer->data(), data, size);
138+
if (_buffer->capacity() < size) {
139+
_buffer.reset();
140+
_buffer = std::make_shared<std::vector<uint8_t>>(0);
141+
} else {
142+
std::memcpy(_buffer->data(), data, size);
143+
}
139144
}
140145

141146
AsyncWebSocketMessageBuffer::AsyncWebSocketMessageBuffer(size_t size)
142147
: _buffer(std::make_shared<std::vector<uint8_t>>(size))
143148
{
149+
if (_buffer->capacity() < size) {
150+
_buffer.reset();
151+
_buffer = std::make_shared<std::vector<uint8_t>>(0);
152+
}
144153
}
145154

146155
AsyncWebSocketMessageBuffer::~AsyncWebSocketMessageBuffer()
@@ -443,6 +452,9 @@ void AsyncWebSocketClient::_queueMessage(std::shared_ptr<std::vector<uint8_t>> b
443452
if (!_client)
444453
return;
445454

455+
if (buffer->size() == 0)
456+
return;
457+
446458
{
447459
AsyncWebLockGuard l(_lock);
448460
if (_messageQueue.size() >= WS_MAX_QUEUED_MESSAGES)
@@ -687,8 +699,10 @@ std::shared_ptr<std::vector<uint8_t>> makeSharedBuffer(const uint8_t *message, s
687699

688700
void AsyncWebSocketClient::text(AsyncWebSocketMessageBuffer * buffer)
689701
{
690-
text(std::move(buffer->_buffer));
691-
delete buffer;
702+
if (buffer) {
703+
text(std::move(buffer->_buffer));
704+
delete buffer;
705+
}
692706
}
693707

694708
void AsyncWebSocketClient::text(std::shared_ptr<std::vector<uint8_t>> buffer)
@@ -739,8 +753,10 @@ void AsyncWebSocketClient::text(const __FlashStringHelper *data)
739753

740754
void AsyncWebSocketClient::binary(AsyncWebSocketMessageBuffer * buffer)
741755
{
742-
binary(std::move(buffer->_buffer));
743-
delete buffer;
756+
if (buffer) {
757+
binary(std::move(buffer->_buffer));
758+
delete buffer;
759+
}
744760
}
745761

746762
void AsyncWebSocketClient::binary(std::shared_ptr<std::vector<uint8_t>> buffer)
@@ -936,8 +952,10 @@ void AsyncWebSocket::text(uint32_t id, const __FlashStringHelper *data)
936952

937953
void AsyncWebSocket::textAll(AsyncWebSocketMessageBuffer * buffer)
938954
{
939-
textAll(std::move(buffer->_buffer));
940-
delete buffer;
955+
if (buffer) {
956+
textAll(std::move(buffer->_buffer));
957+
delete buffer;
958+
}
941959
}
942960

943961
void AsyncWebSocket::textAll(std::shared_ptr<std::vector<uint8_t>> buffer)
@@ -1014,8 +1032,10 @@ void AsyncWebSocket::binary(uint32_t id, const __FlashStringHelper *data, size_t
10141032

10151033
void AsyncWebSocket::binaryAll(AsyncWebSocketMessageBuffer * buffer)
10161034
{
1017-
binaryAll(std::move(buffer->_buffer));
1018-
delete buffer;
1035+
if (buffer) {
1036+
binaryAll(std::move(buffer->_buffer));
1037+
delete buffer;
1038+
}
10191039
}
10201040

10211041
void AsyncWebSocket::binaryAll(std::shared_ptr<std::vector<uint8_t>> buffer)
@@ -1200,12 +1220,26 @@ void AsyncWebSocket::handleRequest(AsyncWebServerRequest *request)
12001220

12011221
AsyncWebSocketMessageBuffer * AsyncWebSocket::makeBuffer(size_t size)
12021222
{
1203-
return new AsyncWebSocketMessageBuffer(size);
1223+
AsyncWebSocketMessageBuffer * buffer = new AsyncWebSocketMessageBuffer(size);
1224+
if (buffer->length() != size)
1225+
{
1226+
delete buffer;
1227+
return nullptr;
1228+
} else {
1229+
return buffer;
1230+
}
12041231
}
12051232

12061233
AsyncWebSocketMessageBuffer * AsyncWebSocket::makeBuffer(uint8_t * data, size_t size)
12071234
{
1208-
return new AsyncWebSocketMessageBuffer(data, size);
1235+
AsyncWebSocketMessageBuffer * buffer = new AsyncWebSocketMessageBuffer(data, size);
1236+
if (buffer->length() != size)
1237+
{
1238+
delete buffer;
1239+
return nullptr;
1240+
} else {
1241+
return buffer;
1242+
}
12091243
}
12101244

12111245
/*

0 commit comments

Comments
 (0)