Skip to content

Commit 54d5022

Browse files
Rowan Rodrik van der Molenhalfgaar
authored andcommitted
Introduce RAII wrapper for SSL struct pointers
Under Wiebe's direct supervision, I implemented `FmqSsl`, as a wrapper around OpenSSL's `SSL*` pointer. Besides being safer in principle, this immediately allowed us to remove all `SSL*` resource management from the `IOWrapper` destructor (which was all the destructor did).
1 parent b2546f7 commit 54d5022

15 files changed

+143
-71
lines changed

CMakeLists.shared

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,4 +149,5 @@ set(FLASHMQ_IMPLS
149149
${RELPATH}fdmanaged.cpp
150150
${RELPATH}http.cpp
151151
${RELPATH}fmqsockaddr.cpp
152+
${RELPATH}fmqssl.cpp
152153
)

FlashMQTests/sharedsubscriptionstests.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@ void MainTests::testSharedSubscribersUnit()
1515

1616
ThreadGlobals::assignThreadData(t);
1717

18-
std::shared_ptr<Client> c1(new Client(ClientType::Normal, 0, t, nullptr, ConnectionProtocol::Mqtt, false, nullptr, settings, false));
18+
std::shared_ptr<Client> c1(new Client(ClientType::Normal, 0, t, FmqSsl(), ConnectionProtocol::Mqtt, false, nullptr, settings, false));
1919
c1->setClientProperties(ProtocolVersion::Mqtt5, "clientid1", {}, "user1", true, 60);
2020

21-
std::shared_ptr<Client> c2(new Client(ClientType::Normal, 0, t, nullptr, ConnectionProtocol::Mqtt, false, nullptr, settings, false));
21+
std::shared_ptr<Client> c2(new Client(ClientType::Normal, 0, t, FmqSsl(), ConnectionProtocol::Mqtt, false, nullptr, settings, false));
2222
c2->setClientProperties(ProtocolVersion::Mqtt5, "clientid2", {}, "user2", true, 60);
2323

24-
std::shared_ptr<Client> c3(new Client(ClientType::Normal, 0, t, nullptr, ConnectionProtocol::Mqtt, false, nullptr, settings, false));
24+
std::shared_ptr<Client> c3(new Client(ClientType::Normal, 0, t, FmqSsl(), ConnectionProtocol::Mqtt, false, nullptr, settings, false));
2525
c3->setClientProperties(ProtocolVersion::Mqtt5, "clientid3", {}, "user3", true, 60);
2626

2727
std::shared_ptr<Session> ses1 = std::make_shared<Session>(c1->getClientId(), c1->getUsername(), std::optional<std::string>());

FlashMQTests/tst_maintests.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -964,13 +964,13 @@ void MainTests::testSavingSessions()
964964
std::shared_ptr<SubscriptionStore> store(new SubscriptionStore());
965965
std::shared_ptr<ThreadData> t(new ThreadData(0, settings, pluginLoader));
966966

967-
std::shared_ptr<Client> c1(new Client(ClientType::Normal, 0, t, nullptr, ConnectionProtocol::Mqtt, false, nullptr, settings, false));
967+
std::shared_ptr<Client> c1(new Client(ClientType::Normal, 0, t, FmqSsl(), ConnectionProtocol::Mqtt, false, nullptr, settings, false));
968968
c1->setClientProperties(ProtocolVersion::Mqtt5, "c1", {}, "user1", true, 60);
969969
store->registerClientAndKickExistingOne(c1, false, 512, 120);
970970
c1->getSession()->addIncomingQoS2MessageId(2);
971971
c1->getSession()->addIncomingQoS2MessageId(3);
972972

973-
std::shared_ptr<Client> c2(new Client(ClientType::Normal, 0, t, nullptr, ConnectionProtocol::Mqtt, false, nullptr, settings, false));
973+
std::shared_ptr<Client> c2(new Client(ClientType::Normal, 0, t, FmqSsl(), ConnectionProtocol::Mqtt, false, nullptr, settings, false));
974974
c2->setClientProperties(ProtocolVersion::Mqtt5, "c2", {}, "user2", true, 60);
975975
store->registerClientAndKickExistingOne(c2, false, 512, 120);
976976
c2->getSession()->addOutgoingQoS2MessageId(55);
@@ -1122,7 +1122,7 @@ void MainTests::testParsePacketHelper(const std::string &topic, uint8_t from_qos
11221122
std::shared_ptr<PluginLoader> pluginLoader = std::make_shared<PluginLoader>();
11231123
std::shared_ptr<ThreadData> t(new ThreadData(0, settings, pluginLoader));
11241124

1125-
std::shared_ptr<Client> dummyClient(new Client(ClientType::Normal, 0, t, nullptr, ConnectionProtocol::Mqtt, false, nullptr, settings, false));
1125+
std::shared_ptr<Client> dummyClient(new Client(ClientType::Normal, 0, t, FmqSsl(), ConnectionProtocol::Mqtt, false, nullptr, settings, false));
11261126
dummyClient->setClientProperties(ProtocolVersion::Mqtt311, "qostestclient", {}, "user1", true, 60);
11271127
store->registerClientAndKickExistingOne(dummyClient, false, 512, 120);
11281128

@@ -1192,7 +1192,7 @@ void MainTests::testbufferToMqttPacketsFuzz()
11921192

11931193
settings.maxPacketSize = 32768;
11941194

1195-
std::shared_ptr<Client> dummyClient(new Client(ClientType::Normal, 0, t, nullptr, ConnectionProtocol::Mqtt, false, nullptr, settings, false));
1195+
std::shared_ptr<Client> dummyClient(new Client(ClientType::Normal, 0, t, FmqSsl(), ConnectionProtocol::Mqtt, false, nullptr, settings, false));
11961196
dummyClient->setClientProperties(ProtocolVersion::Mqtt311, "dummy", {}, "user1", true, 60);
11971197
store->registerClientAndKickExistingOne(dummyClient, false, 512, 120);
11981198

@@ -3061,7 +3061,7 @@ void MainTests::testTopicMatchingInSubscriptionTreeHelper(const std::string &sub
30613061

30623062
std::shared_ptr<ThreadData> td;
30633063
const Settings *settings = ThreadGlobals::getSettings();
3064-
std::shared_ptr<Client> client = std::make_shared<Client>(ClientType::Normal, 0, td, nullptr, ConnectionProtocol::Mqtt, false, nullptr, *settings, false);
3064+
std::shared_ptr<Client> client = std::make_shared<Client>(ClientType::Normal, 0, td, FmqSsl(), ConnectionProtocol::Mqtt, false, nullptr, *settings, false);
30653065
client->setClientProperties(ProtocolVersion::Mqtt5, "mytestclient", {}, "myusername", true, 60);
30663066
store.registerClientAndKickExistingOne(client);
30673067

FlashMQTests/websockettests.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ void MainTests::testWebsocketPing()
7575
int flags = fcntl(listen_socket, F_GETFL);
7676
check<std::runtime_error>(fcntl(client_socket, F_SETFL, flags | O_NONBLOCK ));
7777

78-
std::shared_ptr<Client> c1(new Client(ClientType::Normal, client_socket, t, nullptr, ConnectionProtocol::WebsocketMqtt, false, nullptr, settings, false));
78+
std::shared_ptr<Client> c1(new Client(ClientType::Normal, client_socket, t, FmqSsl(), ConnectionProtocol::WebsocketMqtt, false, nullptr, settings, false));
7979
std::shared_ptr<Client> client = c1;
8080
t->giveClient(std::move(c1));
8181

@@ -253,7 +253,7 @@ void MainTests::testWebsocketCorruptLengthFrame()
253253
int flags = fcntl(listen_socket, F_GETFL);
254254
check<std::runtime_error>(fcntl(client_socket, F_SETFL, flags | O_NONBLOCK ));
255255

256-
std::shared_ptr<Client> c1(new Client(ClientType::Normal, client_socket, t, nullptr, ConnectionProtocol::WebsocketMqtt, false, nullptr, settings, false));
256+
std::shared_ptr<Client> c1(new Client(ClientType::Normal, client_socket, t, FmqSsl(), ConnectionProtocol::WebsocketMqtt, false, nullptr, settings, false));
257257
std::shared_ptr<Client> client = c1;
258258
t->giveClient(std::move(c1));
259259

@@ -380,7 +380,7 @@ void MainTests::testWebsocketHugePing()
380380
int flags = fcntl(listen_socket, F_GETFL);
381381
check<std::runtime_error>(fcntl(client_socket, F_SETFL, flags | O_NONBLOCK ));
382382

383-
std::shared_ptr<Client> c1(new Client(ClientType::Normal, client_socket, t, nullptr, ConnectionProtocol::WebsocketMqtt, false, nullptr, settings, false));
383+
std::shared_ptr<Client> c1(new Client(ClientType::Normal, client_socket, t, FmqSsl(), ConnectionProtocol::WebsocketMqtt, false, nullptr, settings, false));
384384
std::shared_ptr<Client> client = c1;
385385
t->giveClient(std::move(c1));
386386

@@ -500,7 +500,7 @@ void MainTests::testWebsocketManyBigPingFrames()
500500
int flags = fcntl(listen_socket, F_GETFL);
501501
check<std::runtime_error>(fcntl(client_socket, F_SETFL, flags | O_NONBLOCK ));
502502

503-
std::shared_ptr<Client> c1(new Client(ClientType::Normal, client_socket, t, nullptr, ConnectionProtocol::WebsocketMqtt, false, nullptr, settings, false));
503+
std::shared_ptr<Client> c1(new Client(ClientType::Normal, client_socket, t, FmqSsl(), ConnectionProtocol::WebsocketMqtt, false, nullptr, settings, false));
504504
std::shared_ptr<Client> client = c1;
505505
t->giveClient(std::move(c1));
506506

@@ -646,7 +646,7 @@ void MainTests::testWebsocketClose()
646646
int flags = fcntl(listen_socket, F_GETFL);
647647
check<std::runtime_error>(fcntl(client_socket, F_SETFL, flags | O_NONBLOCK ));
648648

649-
std::shared_ptr<Client> c1(new Client(ClientType::Normal, client_socket, t, nullptr, ConnectionProtocol::WebsocketMqtt, false, nullptr, settings, false));
649+
std::shared_ptr<Client> c1(new Client(ClientType::Normal, client_socket, t, FmqSsl(), ConnectionProtocol::WebsocketMqtt, false, nullptr, settings, false));
650650
std::shared_ptr<Client> client = c1;
651651
t->giveClient(std::move(c1));
652652

client.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,14 @@ Client::WriteBuf::WriteBuf(size_t size) :
6161
* @param fuzzMode
6262
*/
6363
Client::Client(
64-
ClientType type, int fd, std::shared_ptr<ThreadData> threadData, SSL *ssl, ConnectionProtocol connectionProtocol,
64+
ClientType type, int fd, std::shared_ptr<ThreadData> threadData, FmqSsl &&ssl, ConnectionProtocol connectionProtocol,
6565
bool haproxy, const struct sockaddr *addr, const Settings &settings, bool fuzzMode) :
6666
fd(fd),
6767
fuzzMode(fuzzMode),
6868
maxOutgoingPacketSize(settings.maxPacketSize),
6969
maxIncomingPacketSize(settings.maxPacketSize),
7070
maxIncomingTopicAliasValue(settings.maxIncomingTopicAliasValue), // Retaining snapshot of current setting, to not confuse clients when the setting changes.
71-
ioWrapper(ssl, connectionProtocol, settings.clientInitialBufferSize, this),
71+
ioWrapper(std::move(ssl), connectionProtocol, settings.clientInitialBufferSize, this),
7272
readbuf(settings.clientInitialBufferSize),
7373
writebuf(settings.clientInitialBufferSize),
7474
clientType(type),

client.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ See LICENSE for license details.
3232
#include "enums.h"
3333
#include "fdmanaged.h"
3434
#include "mutexowned.h"
35+
#include "fmqssl.h"
3536

3637
#include "publishcopyfactory.h"
3738

@@ -161,7 +162,7 @@ class Client
161162
uint8_t preAuthPacketCounter = 0;
162163

163164
Client(
164-
ClientType type, int fd, std::shared_ptr<ThreadData> threadData, SSL *ssl, ConnectionProtocol connectionProtocol, bool haproxy,
165+
ClientType type, int fd, std::shared_ptr<ThreadData> threadData, FmqSsl &&ssl, ConnectionProtocol connectionProtocol, bool haproxy,
165166
const struct sockaddr *addr, const Settings &settings, bool fuzzMode=false);
166167
Client(const Client &other) = delete;
167168
Client(Client &&other) = delete;

flashmqtestclient.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ void FlashMQTestClient::connectClient(ProtocolVersion protocolVersion, bool clea
127127
const std::string clientid = formatString("testclient_%d", clientCount++);
128128

129129
std::shared_ptr<Client> client = std::make_shared<Client>(
130-
ClientType::Normal, sockfd, testServerWorkerThreadData.getThreadData(), nullptr, ConnectionProtocol::Mqtt, false, reinterpret_cast<struct sockaddr*>(&servaddr), settings);
130+
ClientType::Normal, sockfd, testServerWorkerThreadData.getThreadData(), FmqSsl(), ConnectionProtocol::Mqtt, false, reinterpret_cast<struct sockaddr*>(&servaddr), settings);
131131
this->client_weak = client;
132132
client->setClientProperties(protocolVersion, clientid, {}, "user", false, 60);
133133

fmqssl.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#include <stdexcept>
2+
3+
#include <openssl/ssl.h>
4+
5+
#include "fmqssl.h"
6+
7+
FmqSsl::~FmqSsl()
8+
{
9+
if (d == nullptr) return;
10+
11+
/*
12+
* We write the shutdown when we can, but don't take error conditions into account. If socket buffers are full, because
13+
* clients disappear for instance, the socket is just closed. We don't care.
14+
*
15+
* Truncation attacks seem irrelevant. MQTT is frame based, so either end knows if the transmission is done or not. The
16+
* close_notify is not used in determining whether to use or discard the received data.
17+
*/
18+
SSL_shutdown(d);
19+
20+
SSL_free(d);
21+
d = nullptr;
22+
}
23+
24+
FmqSsl::FmqSsl(const SslCtxManager &ssl_ctx) :
25+
d(SSL_new(ssl_ctx.get()))
26+
{
27+
}
28+
29+
FmqSsl::FmqSsl(FmqSsl &&other) :
30+
d(other.d)
31+
{
32+
other.d = nullptr;
33+
}
34+
35+
FmqSsl &FmqSsl::operator=(FmqSsl &&other)
36+
{
37+
d = other.d;
38+
other.d = nullptr;
39+
return *this;
40+
}
41+
42+
void FmqSsl::set_fd(int fd)
43+
{
44+
if (!d) return;
45+
46+
SSL_set_fd(d, fd);
47+
}

fmqssl.h

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#ifndef FMQSSL_H
2+
#define FMQSSL_H
3+
4+
#include <openssl/ssl.h>
5+
6+
#include "sslctxmanager.h"
7+
8+
class FmqSsl
9+
{
10+
SSL* d = nullptr;
11+
12+
public:
13+
FmqSsl() = default;
14+
15+
FmqSsl(const SslCtxManager &ssl_ctx);
16+
17+
FmqSsl(const FmqSsl &other) = delete;
18+
19+
FmqSsl(FmqSsl &&other);
20+
21+
FmqSsl &operator=(const FmqSsl &other) = delete;
22+
23+
FmqSsl &operator=(FmqSsl &&other);
24+
25+
~FmqSsl();
26+
27+
operator bool() const
28+
{
29+
return d != nullptr;
30+
}
31+
SSL* get() const
32+
{
33+
return d;
34+
}
35+
36+
void set_fd(int fd);
37+
};
38+
39+
#endif // FMQSSL_H

iowrapper.cpp

Lines changed: 18 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -65,34 +65,16 @@ IncompleteWebsocketRead::IncompleteWebsocketRead()
6565
reset();
6666
}
6767

68-
IoWrapper::IoWrapper(SSL *ssl, ConnectionProtocol connectionProtocol, const size_t initialBufferSize, Client *parent) :
68+
IoWrapper::IoWrapper(FmqSsl &&ssl, ConnectionProtocol connectionProtocol, const size_t initialBufferSize, Client *parent) :
6969
parentClient(parent),
70-
ssl(ssl),
70+
ssl(std::move(ssl)),
7171
connectionProtocol(connectionProtocol),
7272
websocketPendingBytes(connectionProtocol == ConnectionProtocol::WebsocketMqtt ? initialBufferSize : 0),
7373
websocketWriteRemainder(connectionProtocol == ConnectionProtocol::WebsocketMqtt ? initialBufferSize : 0)
7474
{
7575

7676
}
7777

78-
IoWrapper::~IoWrapper()
79-
{
80-
if (ssl)
81-
{
82-
/*
83-
* We write the shutdown when we can, but don't take error conditions into account. If socket buffers are full, because
84-
* clients disappear for instance, the socket is just closed. We don't care.
85-
*
86-
* Truncation attacks seem irrelevant. MQTT is frame based, so either end knows if the transmission is done or not. The
87-
* close_notify is not used in determining whether to use or discard the received data.
88-
*/
89-
SSL_shutdown(ssl);
90-
91-
SSL_free(ssl);
92-
ssl = nullptr;
93-
}
94-
}
95-
9678
void IoWrapper::startOrContinueSslHandshake()
9779
{
9880
if (parentClient->isOutgoingConnection())
@@ -105,10 +87,10 @@ void IoWrapper::startOrContinueSslConnect()
10587
{
10688
assert(ssl);
10789
ERR_clear_error();
108-
int connected = SSL_connect(ssl);
90+
int connected = SSL_connect(ssl.get());
10991
if (connected <= 0)
11092
{
111-
int err = SSL_get_error(ssl, connected);
93+
int err = SSL_get_error(ssl.get(), connected);
11294

11395
if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE)
11496
{
@@ -134,10 +116,10 @@ void IoWrapper::startOrContinueSslConnect()
134116
void IoWrapper::startOrContinueSslAccept()
135117
{
136118
ERR_clear_error();
137-
int accepted = SSL_accept(ssl);
119+
int accepted = SSL_accept(ssl.get());
138120
if (accepted <= 0)
139121
{
140-
int err = SSL_get_error(ssl, accepted);
122+
int err = SSL_get_error(ssl.get(), accepted);
141123

142124
if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE)
143125
{
@@ -176,7 +158,7 @@ bool IoWrapper::isSslAccepted() const
176158

177159
bool IoWrapper::isSsl() const
178160
{
179-
return this->ssl != nullptr;
161+
return this->ssl;
180162
}
181163

182164
static int verify_callback(int preverify_ok, X509_STORE_CTX *ctx)
@@ -227,18 +209,18 @@ void IoWrapper::setSslVerify(int mode, const std::string &hostname)
227209
if (!ssl)
228210
return;
229211

230-
SSL_set_hostflags(ssl, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
212+
SSL_set_hostflags(ssl.get(), X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
231213

232214
if (!hostname.empty())
233215
{
234-
if (!SSL_set1_host(ssl, hostname.c_str()))
216+
if (!SSL_set1_host(ssl.get(), hostname.c_str()))
235217
throw std::runtime_error("Failed setting hostname of SSL context.");
236218

237-
if (SSL_set_tlsext_host_name(ssl, hostname.c_str()) != 1)
219+
if (SSL_set_tlsext_host_name(ssl.get(), hostname.c_str()) != 1)
238220
throw std::runtime_error("Failed setting SNI hostname of SSL context.");
239221
}
240222

241-
SSL_set_verify(ssl, mode, verify_callback);
223+
SSL_set_verify(ssl.get(), mode, verify_callback);
242224
}
243225

244226
bool IoWrapper::hasPendingWrite() const
@@ -259,7 +241,7 @@ bool IoWrapper::hasProcessedBufferedBytesToRead() const
259241
bool result = false;
260242

261243
if (ssl)
262-
result |= SSL_pending(ssl) > 0;
244+
result |= SSL_pending(ssl.get()) > 0;
263245

264246
/*
265247
* Note that this is tecnhically not 100% correct. If the only bytes are part of a header, doing a read will actually
@@ -278,13 +260,13 @@ WebsocketState IoWrapper::getWebsocketState() const
278260

279261
X509Manager IoWrapper::getPeerCertificate() const
280262
{
281-
X509Manager result(this->ssl);
263+
X509Manager result(this->ssl.get());
282264
return result;
283265
}
284266

285267
const char *IoWrapper::getSslVersion() const
286268
{
287-
return SSL_get_version(ssl);
269+
return SSL_get_version(ssl.get());
288270
}
289271

290272
bool IoWrapper::needsHaProxyParsing() const
@@ -428,12 +410,12 @@ ssize_t IoWrapper::readOrSslRead(int fd, void *buf, size_t nbytes, IoWrapResult
428410
{
429411
this->sslReadWantsWrite = false;
430412
ERR_clear_error();
431-
ssize_t n = SSL_read(ssl, buf, nbytes);
413+
ssize_t n = SSL_read(ssl.get(), buf, nbytes);
432414

433415
if (n > 0)
434416
return n;
435417

436-
int err = SSL_get_error(ssl, n);
418+
int err = SSL_get_error(ssl.get(), n);
437419
unsigned long error_code = ERR_get_error();
438420

439421
if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE)
@@ -539,11 +521,11 @@ ssize_t IoWrapper::writeOrSslWrite(int fd, const void *buf, size_t nbytes, IoWra
539521
this->incompleteSslWrite.reset();
540522

541523
ERR_clear_error();
542-
n = SSL_write(ssl, buf, nbytes_);
524+
n = SSL_write(ssl.get(), buf, nbytes_);
543525

544526
if (n <= 0)
545527
{
546-
int err = SSL_get_error(ssl, n);
528+
int err = SSL_get_error(ssl.get(), n);
547529
unsigned long error_code = ERR_get_error();
548530
if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE)
549531
{

0 commit comments

Comments
 (0)