Skip to content

Commit 2205a22

Browse files
author
glmfe
committed
feat(tcp_transport): Add websocket HTTP redirect
- Add and expose URI parser from HTTP when received a 301 status
1 parent 3c60a00 commit 2205a22

File tree

3 files changed

+176
-51
lines changed

3 files changed

+176
-51
lines changed

components/tcp_transport/host_test/main/test_websocket_transport.cpp

Lines changed: 129 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: 2024 Espressif Systems (Shanghai) CO LTD
2+
* SPDX-FileCopyrightText: 2024-2025 Espressif Systems (Shanghai) CO LTD
33
*
44
* SPDX-License-Identifier: Apache-2.0
55
*/
@@ -145,9 +145,8 @@ int mock_valid_poll_read_fragmented_callback(esp_transport_handle_t t, int timeo
145145

146146
}
147147

148-
void test_ws_connect(bool expect_valid_connection,
149-
CMOCK_mock_read_CALLBACK read_callback,
150-
CMOCK_mock_poll_read_CALLBACK poll_read_callback=mock_poll_read_callback) {
148+
TEST_CASE("WebSocket Transport Connection", "[success]")
149+
{
151150
constexpr static auto timeout = 50;
152151
constexpr static auto port = 8080;
153152
constexpr static auto host = "localhost";
@@ -161,26 +160,39 @@ void test_ws_connect(bool expect_valid_connection,
161160
unique_transport websocket_transport{esp_transport_ws_init(parent_handle.get()), esp_transport_destroy};
162161
REQUIRE(websocket_transport);
163162

164-
SECTION("Successful connection and read data") {
165-
fmt::print("Attempting to connect to WebSocket\n");
166-
esp_crypto_sha1_ExpectAnyArgsAndReturn(0);
167-
esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0);
163+
fmt::print("Attempting to connect to WebSocket\n");
164+
esp_crypto_sha1_ExpectAnyArgsAndReturn(0);
165+
esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0);
166+
167+
// Set the callback function for mock_write
168+
mock_write_Stub(mock_write_callback);
169+
mock_connect_ExpectAndReturn(parent_handle.get(), host, port, timeout, ESP_OK);
168170

169-
// Set the callback function for mock_write
170-
mock_write_Stub(mock_write_callback);
171-
mock_connect_ExpectAndReturn(parent_handle.get(), host, port, timeout, ESP_OK);
171+
SECTION("Happy flow") {
172172
// Set the callback function for mock_read
173-
mock_read_Stub(read_callback);
174-
mock_poll_read_Stub(poll_read_callback);
173+
mock_read_Stub(mock_valid_read_callback);
174+
mock_poll_read_Stub(mock_poll_read_callback);
175175
esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0);
176176
mock_destroy_ExpectAnyArgsAndReturn(ESP_OK);
177177

178-
if (!expect_valid_connection) {
179-
// for invalid connections we only check that the connect() function fails
180-
REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) != 0);
181-
// and we're done here
182-
return;
183-
}
178+
REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) == 0);
179+
char buffer[WS_BUFFER_SIZE];
180+
int read_len = 0;
181+
read_len = esp_transport_read(websocket_transport.get(), &buffer[read_len], WS_BUFFER_SIZE - read_len, timeout);
182+
183+
fmt::print("Read result: {}\n", read_len);
184+
REQUIRE(read_len > 0); // Ensure data is read
185+
186+
std::string response(buffer, read_len);
187+
REQUIRE(response == "Test");
188+
}
189+
190+
SECTION("Happy flow with fragmented reads byte by byte") {
191+
// Set the callback function for mock_read
192+
mock_read_Stub(mock_valid_read_fragmented_callback);
193+
mock_poll_read_Stub(mock_valid_poll_read_fragmented_callback);
194+
esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0);
195+
mock_destroy_ExpectAnyArgsAndReturn(ESP_OK);
184196

185197
REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) == 0);
186198

@@ -195,43 +207,111 @@ void test_ws_connect(bool expect_valid_connection,
195207

196208
std::string response(buffer, read_len);
197209
REQUIRE(response == "Test");
198-
199210
}
200211
}
201212

202-
// Happy flow
203-
TEST_CASE("WebSocket Transport Connection", "[websocket_transport]")
213+
TEST_CASE("WebSocket Transport Connection", "[failure]")
204214
{
205-
test_ws_connect(true, mock_valid_read_callback);
206-
}
215+
constexpr static auto timeout = 50;
216+
constexpr static auto port = 8080;
217+
constexpr static auto host = "localhost";
218+
// Initialize the parent handle
219+
unique_transport parent_handle{esp_transport_init(), esp_transport_destroy};
220+
REQUIRE(parent_handle);
207221

208-
// Happy flow with fragmented reads byte by byte
209-
TEST_CASE("ws connect and reads by fragments", "[websocket_transport]")
210-
{
211-
test_ws_connect(true, mock_valid_read_fragmented_callback, mock_valid_poll_read_fragmented_callback);
212-
}
222+
// Set mock functions for parent handle
223+
esp_transport_set_func(parent_handle.get(), mock_connect, mock_read, mock_write, mock_close, mock_poll_read, mock_poll_write, mock_destroy);
213224

214-
// Some corner cases where we expect the ws connection to fail
225+
unique_transport websocket_transport{esp_transport_ws_init(parent_handle.get()), esp_transport_destroy};
226+
REQUIRE(websocket_transport);
215227

216-
TEST_CASE("ws connect fails (0 len response)", "[websocket_transport]")
217-
{
218-
test_ws_connect(false, [](esp_transport_handle_t h, char *buf, int len, int tout, int n) {
219-
return 0;
220-
});
221-
}
228+
fmt::print("Attempting to connect to WebSocket\n");
229+
esp_crypto_sha1_ExpectAnyArgsAndReturn(0);
230+
esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0);
222231

223-
TEST_CASE("ws connect fails (invalid response)", "[websocket_transport]")
224-
{
225-
test_ws_connect(false, [](esp_transport_handle_t h, char *buf, int len, int tout, int n) {
226-
int resp_len = len/2;
227-
std::memset(buf, '?', resp_len);
228-
return resp_len;
229-
});
230-
}
232+
// Set the callback function for mock_write
233+
mock_write_Stub(mock_write_callback);
234+
mock_connect_ExpectAndReturn(parent_handle.get(), host, port, timeout, ESP_OK);
231235

232-
TEST_CASE("ws connect fails (big response)", "[websocket_transport]")
233-
{
234-
test_ws_connect(false, [](esp_transport_handle_t h, char *buf, int len, int tout, int n) {
235-
return WS_BUFFER_SIZE;
236-
});
236+
SECTION("ws connect fails (0 len response)") {
237+
// Set the callback function for mock_read
238+
mock_read_Stub([](esp_transport_handle_t h, char *buf, int len, int tout, int n) {
239+
return 0;
240+
});
241+
mock_poll_read_Stub(mock_poll_read_callback);
242+
esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0);
243+
mock_destroy_ExpectAnyArgsAndReturn(ESP_OK);
244+
245+
// check that the connect() function fails
246+
REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) != 0);
247+
}
248+
249+
SECTION("ws connect fails (invalid response)") {
250+
// Set the callback function for mock_read
251+
mock_read_Stub([](esp_transport_handle_t h, char *buf, int len, int tout, int n) {
252+
int resp_len = len / 2;
253+
std::memset(buf, '?', resp_len);
254+
return resp_len;
255+
});
256+
mock_poll_read_Stub(mock_poll_read_callback);
257+
esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0);
258+
mock_destroy_ExpectAnyArgsAndReturn(ESP_OK);
259+
260+
// check that the connect() function fails
261+
REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) != 0);
262+
}
263+
264+
SECTION("ws connect fails (big response)") {
265+
// Set the callback function for mock_read
266+
mock_read_Stub([](esp_transport_handle_t h, char *buf, int len, int tout, int n) {
267+
return WS_BUFFER_SIZE;
268+
});
269+
mock_poll_read_Stub(mock_poll_read_callback);
270+
esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0);
271+
mock_destroy_ExpectAnyArgsAndReturn(ESP_OK);
272+
273+
// check that the connect() function fails
274+
REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) != 0);
275+
}
276+
277+
SECTION("ws connect receives redirection response") {
278+
// Set the callback function for mock_read
279+
mock_read_Stub( [](esp_transport_handle_t h, char *buf, int len, int tout, int n) {
280+
char response[WS_BUFFER_SIZE];
281+
int response_length = snprintf(response, WS_BUFFER_SIZE,
282+
"HTTP/1.1 301 Moved Permanently\r\n"
283+
"Location: ws://newhost:8080\r\n"
284+
"\r\n");
285+
std::memcpy(buf, response, response_length);
286+
return response_length;
287+
});
288+
mock_poll_read_Stub(mock_poll_read_callback);
289+
esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0);
290+
mock_destroy_ExpectAnyArgsAndReturn(ESP_OK);
291+
292+
// check that the connect() function returns redir status
293+
REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) == 301);
294+
// Assert the expected HTTP status code
295+
REQUIRE((esp_transport_ws_get_upgrade_request_status(websocket_transport.get())) == 301);
296+
}
297+
298+
SECTION("ws connect receives redirection response without location uri") {
299+
// Set the callback function for mock_read
300+
mock_read_Stub( [](esp_transport_handle_t h, char *buf, int len, int tout, int n) {
301+
char response[WS_BUFFER_SIZE];
302+
int response_length = snprintf(response, WS_BUFFER_SIZE,
303+
"HTTP/1.1 301 Moved Permanently\r\n"
304+
"\r\n");
305+
std::memcpy(buf, response, response_length);
306+
return response_length;
307+
});
308+
mock_poll_read_Stub(mock_poll_read_callback);
309+
esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0);
310+
mock_destroy_ExpectAnyArgsAndReturn(ESP_OK);
311+
312+
// check that the connect() function fails
313+
REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) == -1);
314+
// Assert the expected HTTP status code
315+
REQUIRE((esp_transport_ws_get_upgrade_request_status(websocket_transport.get())) == 301);
316+
}
237317
}

components/tcp_transport/include/esp_transport_ws.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
extern "C" {
1515
#endif
1616

17+
1718
typedef enum ws_transport_opcodes {
1819
WS_TRANSPORT_OPCODES_CONT = 0x00,
1920
WS_TRANSPORT_OPCODES_TEXT = 0x01,
@@ -152,7 +153,7 @@ bool esp_transport_ws_get_fin_flag(esp_transport_handle_t t);
152153
/**
153154
* @brief Returns the HTTP status code of the websocket handshake
154155
*
155-
* This API should be called after the connection atempt otherwise its result is meaningless
156+
* This API should be called after the connection attempt otherwise its result is meaningless
156157
*
157158
* @param t websocket transport handle
158159
*
@@ -162,6 +163,17 @@ bool esp_transport_ws_get_fin_flag(esp_transport_handle_t t);
162163
*/
163164
int esp_transport_ws_get_upgrade_request_status(esp_transport_handle_t t);
164165

166+
/**
167+
* @brief Returns websocket redir host for the last connection attempt
168+
*
169+
* @param t websocket transport handle
170+
*
171+
* @return
172+
* - URI of the redirection host
173+
* - NULL if no redirection was attempted
174+
*/
175+
char* esp_transport_ws_get_redir_uri(esp_transport_handle_t t);
176+
165177
/**
166178
* @brief Returns websocket op-code for last received data
167179
*

components/tcp_transport/transport_ws.c

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,17 @@ static const char *TAG = "transport_ws";
3737
#define WS_SIZE16 126
3838
#define WS_SIZE64 127
3939
#define MAX_WEBSOCKET_HEADER_SIZE 16
40-
#define WS_RESPONSE_OK 101
4140
#define WS_TRANSPORT_MAX_CONTROL_FRAME_BUFFER_LEN 125
4241

42+
// HTTP status codes for redirection as described in RFC 9110.
43+
#define WS_HTTP_CODE_MOVED_PERMANENTLY 301
44+
#define WS_HTTP_CODE_FOUND 302
45+
#define WS_HTTP_CODE_SEE_OTHER 303
46+
#define WS_HTTP_CODE_TEMPORARY_REDIRECT 307
47+
#define WS_HTTP_CODE_PERMANENT_REDIRECT 308
48+
// Grouped HTTP status codes for redirection types.
49+
#define WS_HTTP_PERMANENT_REDIRECT(code) ((code == WS_HTTP_CODE_MOVED_PERMANENTLY) || (code == WS_HTTP_CODE_PERMANENT_REDIRECT))
50+
#define WS_HTTP_TEMPORARY_REDIRECT(code) ((code == WS_HTTP_CODE_FOUND) || (code == WS_HTTP_CODE_SEE_OTHER) || (code == WS_HTTP_CODE_TEMPORARY_REDIRECT))
4351

4452
typedef struct {
4553
uint8_t opcode;
@@ -62,6 +70,7 @@ typedef struct {
6270
bool propagate_control_frames;
6371
ws_transport_frame_state_t frame_state;
6472
esp_transport_handle_t parent;
73+
char *redir_host;
6574
} transport_ws_t;
6675

6776
/**
@@ -306,6 +315,13 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int
306315
if (ws->http_status_code == -1) {
307316
ESP_LOGE(TAG, "HTTP upgrade failed");
308317
return -1;
318+
} else if (WS_HTTP_TEMPORARY_REDIRECT(ws->http_status_code) || WS_HTTP_PERMANENT_REDIRECT(ws->http_status_code)) {
319+
ws->redir_host = get_http_header(ws->buffer, "Location:");
320+
if (ws->redir_host == NULL) {
321+
ESP_LOGE(TAG, "Location header not found");
322+
return -1;
323+
}
324+
return ws->http_status_code;
309325
}
310326

311327
char *server_key = get_http_header(ws->buffer, "Sec-WebSocket-Accept:");
@@ -343,6 +359,7 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int
343359
} else {
344360
#ifdef CONFIG_WS_DYNAMIC_BUFFER
345361
free(ws->buffer);
362+
ws->redir_host = NULL;
346363
ws->buffer = NULL;
347364
#endif
348365
ws->buffer_len = 0;
@@ -897,6 +914,22 @@ int esp_transport_ws_get_upgrade_request_status(esp_transport_handle_t t)
897914
return ws->http_status_code;
898915
}
899916

917+
char* esp_transport_ws_get_redir_uri(esp_transport_handle_t t)
918+
{
919+
if (!t) {
920+
ESP_LOGE(TAG, "Invalid Transport handle (null)");
921+
return NULL;
922+
}
923+
924+
transport_ws_t *ws = esp_transport_get_context_data(t);
925+
if (!ws) {
926+
ESP_LOGE(TAG, "Invalid ws context data (null)");
927+
return NULL;
928+
}
929+
930+
return ws->redir_host;
931+
}
932+
900933
ws_transport_opcodes_t esp_transport_ws_get_read_opcode(esp_transport_handle_t t)
901934
{
902935
transport_ws_t *ws = esp_transport_get_context_data(t);

0 commit comments

Comments
 (0)