Skip to content

Commit 1d12eb4

Browse files
author
Guilherme Ferreira
committed
Merge branch 'feat/add-ws-get-response-headers' into 'master'
feat(tcp_transport): Add ws get HTTP response headers Closes IDFGH-14252 See merge request espressif/esp-idf!38212
2 parents dadcc7b + 1d71a9e commit 1d12eb4

File tree

3 files changed

+135
-52
lines changed

3 files changed

+135
-52
lines changed

components/tcp_transport/host_test/main/test_websocket_transport.cpp

Lines changed: 70 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,6 @@ int mock_poll_read_callback(esp_transport_handle_t t, int timeout_ms, int num_ca
113113

114114
int mock_valid_read_callback(esp_transport_handle_t transport, char *buffer, int len, int timeout_ms, int num_call)
115115
{
116-
if (num_call) {
117-
return 0;
118-
}
119116
std::string websocket_response = make_response();
120117
std::memcpy(buffer, websocket_response.data(), websocket_response.size());
121118
return websocket_response.size();
@@ -160,6 +157,21 @@ TEST_CASE("WebSocket Transport Connection", "[success]")
160157
unique_transport websocket_transport{esp_transport_ws_init(parent_handle.get()), esp_transport_destroy};
161158
REQUIRE(websocket_transport);
162159

160+
// Allocate buffer for response header
161+
constexpr size_t response_header_len = 1024;
162+
std::vector<char> response_header_buffer(response_header_len);
163+
esp_transport_ws_config_t ws_config = {
164+
.ws_path = "/",
165+
.sub_protocol = nullptr,
166+
.user_agent = nullptr,
167+
.headers = nullptr,
168+
.auth = nullptr,
169+
.response_headers = response_header_buffer.data(),
170+
.response_headers_len = response_header_len,
171+
.propagate_control_frames = false
172+
};
173+
REQUIRE(esp_transport_ws_set_config(websocket_transport.get(), &ws_config) == ESP_OK);
174+
163175
fmt::print("Attempting to connect to WebSocket\n");
164176
esp_crypto_sha1_ExpectAnyArgsAndReturn(0);
165177
esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0);
@@ -176,6 +188,11 @@ TEST_CASE("WebSocket Transport Connection", "[success]")
176188
mock_destroy_ExpectAnyArgsAndReturn(ESP_OK);
177189

178190
REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) == 0);
191+
192+
// Verify the response header was stored correctly
193+
std::string expected_header = make_response();
194+
REQUIRE(std::string(response_header_buffer.data()) == expected_header);
195+
179196
char buffer[WS_BUFFER_SIZE];
180197
int read_len = 0;
181198
read_len = esp_transport_read(websocket_transport.get(), &buffer[read_len], WS_BUFFER_SIZE - read_len, timeout);
@@ -196,6 +213,14 @@ TEST_CASE("WebSocket Transport Connection", "[success]")
196213

197214
REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) == 0);
198215

216+
// Verify the response header was stored correctly
217+
std::string expected_header = "HTTP/1.1 101 Switching Protocols\r\n"
218+
"Upgrade: websocket\r\n"
219+
"Connection: Upgrade\r\n"
220+
"Sec-WebSocket-Accept:\r\n"
221+
"\r\n";
222+
REQUIRE(std::string(response_header_buffer.data()) == expected_header);
223+
199224
char buffer[WS_BUFFER_SIZE];
200225
int read_len = 0;
201226
int partial_read;
@@ -208,6 +233,25 @@ TEST_CASE("WebSocket Transport Connection", "[success]")
208233
std::string response(buffer, read_len);
209234
REQUIRE(response == "Test");
210235
}
236+
237+
SECTION("Happy flow with smaller response header") {
238+
// Set the response header length to 10
239+
ws_config.response_headers_len = 10;
240+
REQUIRE(esp_transport_ws_set_config(websocket_transport.get(), &ws_config) == ESP_OK);
241+
242+
// Set the callback function for mock_read
243+
mock_read_Stub(mock_valid_read_callback);
244+
mock_poll_read_Stub(mock_poll_read_callback);
245+
esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0);
246+
mock_destroy_ExpectAnyArgsAndReturn(ESP_OK);
247+
248+
REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) == 0);
249+
250+
// Verify the response header was stored correctly. it must contain only ten bytes and be null terminated
251+
std::string expected_header = "HTTP/1.1 1\0";
252+
253+
REQUIRE(std::string(response_header_buffer.data()) == expected_header);
254+
}
211255
}
212256

213257
TEST_CASE("WebSocket Transport Connection", "[failure]")
@@ -225,6 +269,21 @@ TEST_CASE("WebSocket Transport Connection", "[failure]")
225269
unique_transport websocket_transport{esp_transport_ws_init(parent_handle.get()), esp_transport_destroy};
226270
REQUIRE(websocket_transport);
227271

272+
// Allocate buffer for response header
273+
constexpr size_t response_header_len = 1024;
274+
std::vector<char> response_header_buffer(response_header_len);
275+
esp_transport_ws_config_t ws_config = {
276+
.ws_path = "/",
277+
.sub_protocol = nullptr,
278+
.user_agent = nullptr,
279+
.headers = nullptr,
280+
.auth = nullptr,
281+
.response_headers = response_header_buffer.data(),
282+
.response_headers_len = response_header_len,
283+
.propagate_control_frames = false
284+
};
285+
REQUIRE(esp_transport_ws_set_config(websocket_transport.get(), &ws_config) == ESP_OK);
286+
228287
fmt::print("Attempting to connect to WebSocket\n");
229288
esp_crypto_sha1_ExpectAnyArgsAndReturn(0);
230289
esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0);
@@ -244,6 +303,9 @@ TEST_CASE("WebSocket Transport Connection", "[failure]")
244303

245304
// check that the connect() function fails
246305
REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) != 0);
306+
307+
// Verify the response header is empty
308+
REQUIRE(std::string(response_header_buffer.data()) == "");
247309
}
248310

249311
SECTION("ws connect fails (invalid response)") {
@@ -259,6 +321,9 @@ TEST_CASE("WebSocket Transport Connection", "[failure]")
259321

260322
// check that the connect() function fails
261323
REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) != 0);
324+
325+
// Verify the response header is empty
326+
REQUIRE(std::string(response_header_buffer.data()) == "");
262327
}
263328

264329
SECTION("ws connect fails (big response)") {
@@ -272,46 +337,8 @@ TEST_CASE("WebSocket Transport Connection", "[failure]")
272337

273338
// check that the connect() function fails
274339
REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) != 0);
275-
}
276-
277-
SECTION("ws connect receives redirection response") {
278-
// Set the callback function for mock_read
279-
mock_read_Stub( [](esp_transport_handle_t h, char *buf, int len, int tout, int n) {
280-
char response[WS_BUFFER_SIZE];
281-
int response_length = snprintf(response, WS_BUFFER_SIZE,
282-
"HTTP/1.1 301 Moved Permanently\r\n"
283-
"Location: ws://newhost:8080\r\n"
284-
"\r\n");
285-
std::memcpy(buf, response, response_length);
286-
return response_length;
287-
});
288-
mock_poll_read_Stub(mock_poll_read_callback);
289-
esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0);
290-
mock_destroy_ExpectAnyArgsAndReturn(ESP_OK);
291340

292-
// check that the connect() function returns redir status
293-
REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) == 301);
294-
// Assert the expected HTTP status code
295-
REQUIRE((esp_transport_ws_get_upgrade_request_status(websocket_transport.get())) == 301);
296-
}
297-
298-
SECTION("ws connect receives redirection response without location uri") {
299-
// Set the callback function for mock_read
300-
mock_read_Stub( [](esp_transport_handle_t h, char *buf, int len, int tout, int n) {
301-
char response[WS_BUFFER_SIZE];
302-
int response_length = snprintf(response, WS_BUFFER_SIZE,
303-
"HTTP/1.1 301 Moved Permanently\r\n"
304-
"\r\n");
305-
std::memcpy(buf, response, response_length);
306-
return response_length;
307-
});
308-
mock_poll_read_Stub(mock_poll_read_callback);
309-
esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0);
310-
mock_destroy_ExpectAnyArgsAndReturn(ESP_OK);
311-
312-
// check that the connect() function fails
313-
REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) == -1);
314-
// Assert the expected HTTP status code
315-
REQUIRE((esp_transport_ws_get_upgrade_request_status(websocket_transport.get())) == 301);
341+
// Verify the response header is empty
342+
REQUIRE(std::string(response_header_buffer.data()) == "");
316343
}
317344
}

components/tcp_transport/include/esp_transport_ws.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
extern "C" {
1515
#endif
1616

17+
// Features supported
18+
#define ESP_TRANSPORT_WS_STORE_RESPONSE_HEADERS 1
1719

1820
typedef enum ws_transport_opcodes {
1921
WS_TRANSPORT_OPCODES_CONT = 0x00,
@@ -36,6 +38,8 @@ typedef struct {
3638
const char *user_agent; /*!< WS user agent */
3739
const char *headers; /*!< WS additional headers */
3840
const char *auth; /*!< HTTP authorization header */
41+
char *response_headers; /*!< The buffer to copy the http response header */
42+
size_t response_headers_len; /*!< The length of the http response header */
3943
bool propagate_control_frames; /*!< If true, control frames are passed to the reader
4044
* If false, only user frames are propagated, control frames are handled
4145
* automatically during read operations
@@ -107,6 +111,19 @@ esp_err_t esp_transport_ws_set_headers(esp_transport_handle_t t, const char *hea
107111
*/
108112
esp_err_t esp_transport_ws_set_auth(esp_transport_handle_t t, const char *auth);
109113

114+
/**
115+
* @brief Set the buffer to copy the http response header
116+
*
117+
* @param[in] t The transport handle
118+
* @param[in] response_header The buffer to copy the http response header
119+
* @param[in] response_header_len The length of the http response header
120+
*
121+
* @return
122+
* - ESP_OK
123+
* - ESP_FAIL
124+
*/
125+
esp_err_t esp_transport_ws_set_response_headers(esp_transport_handle_t t, char *response_header, size_t response_header_len);
126+
110127
/**
111128
* @brief Set websocket transport parameters
112129
*

components/tcp_transport/transport_ws.c

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ typedef struct {
7171
ws_transport_frame_state_t frame_state;
7272
esp_transport_handle_t parent;
7373
char *redir_host;
74+
char *response_header;
75+
size_t response_header_len;
7476
} transport_ws_t;
7577

7678
/**
@@ -305,14 +307,24 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int
305307
} while (NULL == strstr(ws->buffer, delimiter) && header_len < WS_BUFFER_SIZE - 1);
306308

307309
if (header_len >= WS_BUFFER_SIZE - 1) {
308-
ESP_LOGE(TAG, "Header size exceeded buffer size");
310+
ESP_LOGE(TAG, "Header size exceeded buffer size (need=%d, max=%d)", header_len + 1, WS_BUFFER_SIZE);
309311
return -1;
310312
}
311313

314+
if(ws->response_header) {
315+
if(ws->response_header_len < header_len) {
316+
ESP_LOGW(TAG, "Received header length exceedes the allocated buffer size (need=%d, allocated=%d), truncating to allocated size", header_len, ws->response_header_len);
317+
header_len = ws->response_header_len;
318+
}
319+
// Copy response header to the static array
320+
strncpy(ws->response_header, ws->buffer, header_len);
321+
ws->response_header[header_len] = '\0';
322+
}
323+
312324
char* delim_ptr = strstr(ws->buffer, delimiter);
313325

314326
ws->http_status_code = get_http_status_code(ws->buffer);
315-
if (ws->http_status_code == -1) {
327+
if (ws->http_status_code == -1) {
316328
ESP_LOGE(TAG, "HTTP upgrade failed");
317329
return -1;
318330
} else if (WS_HTTP_TEMPORARY_REDIRECT(ws->http_status_code) || WS_HTTP_PERMANENT_REDIRECT(ws->http_status_code)) {
@@ -605,7 +617,7 @@ static int ws_handle_control_frame_internal(esp_transport_handle_t t, int timeou
605617

606618
if (payload_len > WS_TRANSPORT_MAX_CONTROL_FRAME_BUFFER_LEN) {
607619
ESP_LOGE(TAG, "Not enough room for reading control frames (need=%d, max_allowed=%d)",
608-
ws->frame_state.payload_len, WS_TRANSPORT_MAX_CONTROL_FRAME_BUFFER_LEN);
620+
ws->frame_state.payload_len, WS_TRANSPORT_MAX_CONTROL_FRAME_BUFFER_LEN);
609621
return -1;
610622
}
611623

@@ -625,7 +637,7 @@ static int ws_handle_control_frame_internal(esp_transport_handle_t t, int timeou
625637
int actual_len = ws_read_payload(t, control_frame_buffer, control_frame_buffer_len, timeout_ms);
626638
if (actual_len != payload_len) {
627639
ESP_LOGE(TAG, "Control frame (opcode=%d) payload read failed (payload_len=%d, read_len=%d)",
628-
ws->frame_state.opcode, payload_len, actual_len);
640+
ws->frame_state.opcode, payload_len, actual_len);
629641
ret = -1;
630642
goto free_payload_buffer;
631643
}
@@ -751,8 +763,8 @@ static int ws_get_socket(esp_transport_handle_t t)
751763
esp_transport_handle_t esp_transport_ws_init(esp_transport_handle_t parent_handle)
752764
{
753765
if (parent_handle == NULL) {
754-
ESP_LOGE(TAG, "Invalid parent ptotocol");
755-
return NULL;
766+
ESP_LOGE(TAG, "Invalid parent ptotocol");
767+
return NULL;
756768
}
757769
esp_transport_handle_t t = esp_transport_init();
758770
if (t == NULL) {
@@ -870,6 +882,28 @@ esp_err_t esp_transport_ws_set_auth(esp_transport_handle_t t, const char *auth)
870882
return ESP_OK;
871883
}
872884

885+
esp_err_t esp_transport_ws_set_response_headers(esp_transport_handle_t t, char *response_header, size_t response_header_len)
886+
{
887+
if (t == NULL) {
888+
return ESP_ERR_INVALID_ARG;
889+
}
890+
891+
if (response_header != NULL && response_header_len == 0) {
892+
ESP_LOGE(TAG, "Invalid response header length");
893+
return ESP_ERR_INVALID_ARG;
894+
}
895+
896+
transport_ws_t *ws = esp_transport_get_context_data(t);
897+
898+
if (ws == NULL) {
899+
return ESP_ERR_INVALID_ARG;
900+
}
901+
902+
ws->response_header = response_header;
903+
ws->response_header_len = response_header_len;
904+
return ESP_OK;
905+
}
906+
873907
esp_err_t esp_transport_ws_set_config(esp_transport_handle_t t, const esp_transport_ws_config_t *config)
874908
{
875909
if (t == NULL) {
@@ -897,15 +931,20 @@ esp_err_t esp_transport_ws_set_config(esp_transport_handle_t t, const esp_transp
897931
err = esp_transport_ws_set_auth(t, config->auth);
898932
ESP_TRANSPORT_ERR_OK_CHECK(TAG, err, return err;)
899933
}
934+
if(config->response_headers) {
935+
err = esp_transport_ws_set_response_headers(t, config->response_headers, config->response_headers_len);
936+
ESP_TRANSPORT_ERR_OK_CHECK(TAG, err, return err;)
937+
}
938+
900939
ws->propagate_control_frames = config->propagate_control_frames;
901940

902941
return err;
903942
}
904943

905944
bool esp_transport_ws_get_fin_flag(esp_transport_handle_t t)
906945
{
907-
transport_ws_t *ws = esp_transport_get_context_data(t);
908-
return ws->frame_state.fin;
946+
transport_ws_t *ws = esp_transport_get_context_data(t);
947+
return ws->frame_state.fin;
909948
}
910949

911950
int esp_transport_ws_get_upgrade_request_status(esp_transport_handle_t t)
@@ -969,7 +1008,7 @@ static int esp_transport_ws_handle_control_frames(esp_transport_handle_t t, char
9691008
if (ws->frame_state.opcode == WS_OPCODE_PING) {
9701009
// handle PING frames internally: just send a PONG with the same payload
9711010
actual_len = _ws_write(t, WS_OPCODE_PONG | WS_FIN, WS_MASK, buffer,
972-
payload_len, timeout_ms);
1011+
payload_len, timeout_ms);
9731012
if (actual_len != payload_len) {
9741013
ESP_LOGE(TAG, "PONG send failed (payload_len=%d, written_len=%d)", payload_len, actual_len);
9751014
return -1;

0 commit comments

Comments
 (0)