Skip to content
This repository was archived by the owner on Jan 20, 2025. It is now read-only.

Commit 110a997

Browse files
author
Me No Dev
committed
Initial SSL implementation
1 parent 5987225 commit 110a997

File tree

6 files changed

+581
-17
lines changed

6 files changed

+581
-17
lines changed

src/ESPAsyncTCP.cpp

Lines changed: 86 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ extern "C"{
2828
#include "lwip/inet.h"
2929
#include "lwip/dns.h"
3030
}
31+
#include "lwipr_compat.h"
3132

3233
/*
3334
Async TCP Client
@@ -47,6 +48,8 @@ AsyncClient::AsyncClient(tcp_pcb* pcb):
4748
, _timeout_cb(0)
4849
, _timeout_cb_arg(0)
4950
, _pcb_busy(false)
51+
, _pcb_secure(false)
52+
, _handshake_done(true)
5053
, _pcb_sent_at(0)
5154
, _close_pcb(false)
5255
, _ack_pcb(true)
@@ -73,7 +76,7 @@ AsyncClient::~AsyncClient(){
7376
_close();
7477
}
7578

76-
bool AsyncClient::connect(IPAddress ip, uint16_t port){
79+
bool AsyncClient::connect(IPAddress ip, uint16_t port, bool secure){
7780

7881
if (_pcb) //already connected
7982
return false;
@@ -88,18 +91,22 @@ bool AsyncClient::connect(IPAddress ip, uint16_t port){
8891
if (!pcb) //could not allocate pcb
8992
return false;
9093

94+
_pcb_secure = secure;
95+
_handshake_done = !secure;
9196
tcp_arg(pcb, this);
9297
tcp_err(pcb, &_s_error);
9398
tcp_connect(pcb, &addr, port,(tcp_connected_fn)&_s_connected);
9499
return true;
95100
}
96101

97-
bool AsyncClient::connect(const char* host, uint16_t port) {
102+
bool AsyncClient::connect(const char* host, uint16_t port, bool secure) {
98103
ip_addr_t addr;
99104
err_t err = dns_gethostbyname(host, &addr, (dns_found_callback)&_s_dns_found, this);
100105
if(err == ERR_OK) {
101-
return connect(IPAddress(addr.addr), port);
106+
return connect(IPAddress(addr.addr), port, secure);
102107
} else if(err == ERR_INPROGRESS) {
108+
_pcb_secure = secure;
109+
_handshake_done = !secure;
103110
_connect_port = port;
104111
return true;
105112
}
@@ -165,6 +172,13 @@ size_t AsyncClient::write(const char* data, size_t size) {
165172
return 0;
166173
if(!canSend())
167174
return 0;
175+
if(_pcb_secure){
176+
int sent = axl_write(_pcb, (uint8_t*)data, size);
177+
if(sent >= 0)
178+
return sent;
179+
//ssl error
180+
return 0;
181+
}
168182
size_t room = tcp_sndbuf(_pcb);
169183
size_t will_send = (room < size) ? room : size;
170184
int8_t err = tcp_write(_pcb, data, will_send, 0);
@@ -188,6 +202,13 @@ size_t AsyncClient::add(const char* data, size_t size) {
188202
size_t room = tcp_sndbuf(_pcb);
189203
if(!room)
190204
return 0;
205+
if(_pcb_secure){
206+
int sent = axl_write(_pcb, (uint8_t*)data, size);
207+
if(sent >= 0)
208+
return sent;
209+
//ssl error
210+
return 0;
211+
}
191212
size_t will_send = (room < size) ? room : size;
192213
int8_t err = tcp_write(_pcb, data, will_send, 0);
193214
if(err != ERR_OK)
@@ -196,6 +217,8 @@ size_t AsyncClient::add(const char* data, size_t size) {
196217
}
197218

198219
bool AsyncClient::send(){
220+
if(_pcb_secure)
221+
return true;
199222
if(!canSend())
200223
return false;
201224
if(tcp_output(_pcb) == ERR_OK){
@@ -225,6 +248,8 @@ int8_t AsyncClient::_close(){
225248
tcp_recv(_pcb, NULL);
226249
tcp_err(_pcb, NULL);
227250
tcp_poll(_pcb, NULL, 0);
251+
if(_pcb_secure)
252+
axl_free(_pcb);
228253
err = tcp_close(_pcb);
229254
if(err != ERR_OK) {
230255
err = abort();
@@ -239,13 +264,23 @@ int8_t AsyncClient::_close(){
239264
int8_t AsyncClient::_connected(void* pcb, int8_t err){
240265
_pcb = reinterpret_cast<tcp_pcb*>(pcb);
241266
if(_pcb){
267+
if(_pcb_secure){
268+
axl_tcp_t * axl = axl_new(_pcb);
269+
if(axl == NULL){
270+
return _close();
271+
}
272+
axl_arg(_pcb, this);
273+
axl_data(_pcb, &_s_data);
274+
axl_handshake(_pcb, &_s_handshake);
275+
axl_err(_pcb, &_s_ssl_error);
276+
}
242277
tcp_setprio(_pcb, TCP_PRIO_MIN);
243278
tcp_recv(_pcb, &_s_recv);
244279
tcp_sent(_pcb, &_s_sent);
245280
tcp_poll(_pcb, &_s_poll, 1);
246281
_pcb_busy = false;
247282
}
248-
if(_connect_cb)
283+
if(!_pcb_secure && _connect_cb)
249284
_connect_cb(_connect_cb_arg, this);
250285
return ERR_OK;
251286
}
@@ -257,6 +292,8 @@ void AsyncClient::_error(int8_t err) {
257292
tcp_recv(_pcb, NULL);
258293
tcp_err(_pcb, NULL);
259294
tcp_poll(_pcb, NULL, 0);
295+
if(_pcb_secure)
296+
axl_free(_pcb);
260297
_pcb = NULL;
261298
}
262299
if(_error_cb)
@@ -265,6 +302,11 @@ void AsyncClient::_error(int8_t err) {
265302
_discard_cb(_discard_cb_arg, this);
266303
}
267304

305+
void AsyncClient::_ssl_error(int8_t err){
306+
if(_error_cb)
307+
_error_cb(_error_cb_arg, this, err+64);
308+
}
309+
268310
int8_t AsyncClient::_sent(tcp_pcb* pcb, uint16_t len) {
269311
_rx_last_packet = millis();
270312
_pcb_busy = false;
@@ -279,7 +321,16 @@ int8_t AsyncClient::_recv(tcp_pcb* pcb, pbuf* pb, int8_t err) {
279321
}
280322

281323
_rx_last_packet = millis();
282-
//use callback (onData defined)
324+
if(_pcb_secure){
325+
int read_bytes = axl_read(pcb, pb);
326+
if(read_bytes < 0){
327+
if (read_bytes != SSL_CLOSE_NOTIFY && read_bytes != SSL_ERROR_CONN_LOST) {
328+
_close();
329+
}
330+
return read_bytes;
331+
}
332+
return ERR_OK;
333+
}
283334
while(pb != NULL){
284335
//we should not ack before we assimilate the data
285336
_ack_pcb = true;
@@ -326,7 +377,7 @@ int8_t AsyncClient::_poll(tcp_pcb* pcb){
326377

327378
void AsyncClient::_dns_found(ip_addr_t *ipaddr){
328379
if(ipaddr){
329-
connect(IPAddress(ipaddr->addr), _connect_port);
380+
connect(IPAddress(ipaddr->addr), _connect_port, _pcb_secure);
330381
} else {
331382
if(_error_cb)
332383
_error_cb(_error_cb_arg, this, -55);
@@ -361,6 +412,24 @@ int8_t AsyncClient::_s_connected(void* arg, void* tpcb, int8_t err){
361412
return reinterpret_cast<AsyncClient*>(arg)->_connected(tpcb, err);
362413
}
363414

415+
void AsyncClient::_s_data(void *arg, struct tcp_pcb *tcp, uint8_t * data, size_t len){
416+
AsyncClient *c = reinterpret_cast<AsyncClient*>(arg);
417+
if(c->_recv_cb)
418+
c->_recv_cb(c->_recv_cb_arg, c, data, len);
419+
}
420+
421+
void AsyncClient::_s_handshake(void *arg, struct tcp_pcb *tcp, SSL *ssl){
422+
AsyncClient *c = reinterpret_cast<AsyncClient*>(arg);
423+
c->_handshake_done = true;
424+
if(c->_connect_cb)
425+
c->_connect_cb(c->_connect_cb_arg, c);
426+
}
427+
428+
void AsyncClient::_s_ssl_error(void *arg, struct tcp_pcb *tcp, int8_t err){
429+
reinterpret_cast<AsyncClient*>(arg)->_ssl_error(err);
430+
}
431+
432+
364433
// Operators
365434

366435
AsyncClient & AsyncClient::operator+=(const AsyncClient &other) {
@@ -447,6 +516,16 @@ uint16_t AsyncClient::localPort() {
447516
return getLocalPort();
448517
}
449518

519+
SSL * AsyncClient::getSSL(){
520+
if(_pcb && _pcb_secure){
521+
axl_tcp_t* axl = axl_get(_pcb);
522+
if(axl){
523+
return axl->ssl;
524+
}
525+
}
526+
return NULL;
527+
}
528+
450529
uint8_t AsyncClient::state() {
451530
if(!_pcb)
452531
return 0;
@@ -456,7 +535,7 @@ uint8_t AsyncClient::state() {
456535
bool AsyncClient::connected(){
457536
if (!_pcb)
458537
return false;
459-
return _pcb->state == 4;
538+
return _pcb->state == 4 && _handshake_done;
460539
}
461540

462541
bool AsyncClient::connecting(){

src/ESPAsyncTCP.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ typedef std::function<void(void*, AsyncClient*, uint32_t time)> AcTimeoutHandler
4444
struct tcp_pcb;
4545
struct pbuf;
4646
struct ip_addr;
47+
struct SSL_;
48+
typedef struct SSL_ SSL;
4749

4850
class AsyncClient {
4951
protected:
@@ -64,6 +66,8 @@ class AsyncClient {
6466
AcConnectHandler _poll_cb;
6567
void* _poll_cb_arg;
6668
bool _pcb_busy;
69+
bool _pcb_secure;
70+
bool _handshake_done;
6771
uint32_t _pcb_sent_at;
6872
bool _close_pcb;
6973
bool _ack_pcb;
@@ -76,6 +80,7 @@ class AsyncClient {
7680
int8_t _close();
7781
int8_t _connected(void* pcb, int8_t err);
7882
void _error(int8_t err);
83+
void _ssl_error(int8_t err);
7984
int8_t _poll(tcp_pcb* pcb);
8085
int8_t _sent(tcp_pcb* pcb, uint16_t len);
8186
int8_t _recv(tcp_pcb* pcb, pbuf* pb, int8_t err);
@@ -86,6 +91,9 @@ class AsyncClient {
8691
static int8_t _s_sent(void *arg, struct tcp_pcb *tpcb, uint16_t len);
8792
static int8_t _s_connected(void* arg, void* tpcb, int8_t err);
8893
static void _s_dns_found(const char *name, struct ip_addr *ipaddr, void *arg);
94+
static void _s_data(void *arg, struct tcp_pcb *tcp, uint8_t * data, size_t len);
95+
static void _s_handshake(void *arg, struct tcp_pcb *tcp, SSL *ssl);
96+
static void _s_ssl_error(void *arg, struct tcp_pcb *tcp, int8_t err);
8997

9098
public:
9199
AsyncClient* prev;
@@ -103,8 +111,8 @@ class AsyncClient {
103111
return !(*this == other);
104112
}
105113

106-
bool connect(IPAddress ip, uint16_t port);
107-
bool connect(const char* host, uint16_t port);
114+
bool connect(IPAddress ip, uint16_t port, bool secure=false);
115+
bool connect(const char* host, uint16_t port, bool secure=false);
108116
void close(bool now = false);
109117
void stop();
110118
int8_t abort();
@@ -117,6 +125,7 @@ class AsyncClient {
117125
size_t ack(size_t len); //ack data that you have not acked using the method below
118126
void ackLater(){ _ack_pcb = false; } //will not ack the current packet. Call from onData
119127

128+
SSL *getSSL();
120129

121130
size_t write(const char* data);
122131
size_t write(const char* data, size_t size); //only when canSend() == true

src/SyncClient.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,26 +53,26 @@ SyncClient::~SyncClient(){
5353
}
5454
}
5555

56-
int SyncClient::connect(IPAddress ip, uint16_t port){
56+
int SyncClient::connect(IPAddress ip, uint16_t port, bool secure){
5757
if(_client != NULL && connected())
5858
return 0;
5959
_client = new AsyncClient();
6060
_client->onConnect([](void *obj, AsyncClient *c){ ((SyncClient*)(obj))->_onConnect(c); }, this);
61-
if(_client->connect(ip, port)){
62-
while(_client->state() < 4)
61+
if(_client->connect(ip, port, secure)){
62+
while(_client != NULL && !_client->connected() && !_client->disconnecting())
6363
delay(1);
6464
return connected();
6565
}
6666
return 0;
6767
}
6868

69-
int SyncClient::connect(const char *host, uint16_t port){
69+
int SyncClient::connect(const char *host, uint16_t port, bool secure){
7070
if(_client != NULL && connected())
7171
return 0;
7272
_client = new AsyncClient();
7373
_client->onConnect([](void *obj, AsyncClient *c){ ((SyncClient*)(obj))->_onConnect(c); }, this);
74-
if(_client->connect(host, port)){
75-
while(_client->state() < 4)
74+
if(_client->connect(host, port, secure)){
75+
while(_client != NULL && !_client->connected() && !_client->disconnecting())
7676
delay(1);
7777
return connected();
7878
}

src/SyncClient.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,14 @@ class SyncClient: public Client {
4747
operator bool(){ return connected(); }
4848
SyncClient & operator=(const SyncClient &other);
4949

50-
int connect(IPAddress ip, uint16_t port);
51-
int connect(const char *host, uint16_t port);
50+
int connect(IPAddress ip, uint16_t port, bool secure);
51+
int connect(const char *host, uint16_t port, bool secure);
52+
int connect(IPAddress ip, uint16_t port){
53+
return connect(ip, port, false);
54+
}
55+
int connect(const char *host, uint16_t port){
56+
return connect(host, port, false);
57+
}
5258
void setTimeout(uint32_t seconds);
5359

5460
uint8_t status();

0 commit comments

Comments
 (0)