@@ -63,6 +63,8 @@ typedef struct {
6363 char * sub_protocol ;
6464 char * user_agent ;
6565 char * headers ;
66+ ws_header_hook_t header_hook ;
67+ void * header_user_context ;
6668 char * auth ;
6769 char * buffer ; /*!< Initial HTTP connection buffer, which may include data beyond the handshake headers, such as the next WebSocket packet*/
6870 size_t buffer_len ; /*!< The buffer length */
@@ -144,31 +146,6 @@ static int esp_transport_read_internal(transport_ws_t *ws, char *buffer, int len
144146 return to_read ;
145147}
146148
147- static char * trimwhitespace (char * str )
148- {
149- char * end ;
150-
151- // Trim leading space
152- while (isspace ((unsigned char )* str )) {
153- str ++ ;
154- }
155-
156- if (* str == 0 ) {
157- return str ;
158- }
159-
160- // Trim trailing space
161- end = str + strlen (str ) - 1 ;
162- while (end > str && isspace ((unsigned char )* end )) {
163- end -- ;
164- }
165-
166- // Write new null terminator
167- * (end + 1 ) = '\0' ;
168-
169- return str ;
170- }
171-
172149static int get_http_status_code (const char * buffer )
173150{
174151 const char http [] = "HTTP/" ;
@@ -189,26 +166,14 @@ static int get_http_status_code(const char *buffer)
189166 return -1 ;
190167}
191168
192- static char * get_http_header (char * buffer , const char * key )
193- {
194- char * found = strcasestr (buffer , key );
195- if (found ) {
196- found += strlen (key );
197- char * found_end = strstr (found , "\r\n" );
198- if (found_end ) {
199- * found_end = '\0' ; // terminal string
200-
201- return trimwhitespace (found );
202- }
203- }
204- return NULL ;
205- }
206-
207169static int ws_connect (esp_transport_handle_t t , const char * host , int port , int timeout_ms )
208170{
209171 transport_ws_t * ws = esp_transport_get_context_data (t );
210172 const char delimiter [] = "\r\n\r\n" ;
211173
174+ free (ws -> redir_host );
175+ ws -> redir_host = NULL ;
176+
212177 if (esp_transport_connect (ws -> parent , host , port , timeout_ms ) < 0 ) {
213178 ESP_LOGE (TAG , "Error connecting to host %s:%d" , host , port );
214179 return -1 ;
@@ -327,16 +292,67 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int
327292 if (ws -> http_status_code == -1 ) {
328293 ESP_LOGE (TAG , "HTTP upgrade failed" );
329294 return -1 ;
330- } else if (WS_HTTP_TEMPORARY_REDIRECT (ws -> http_status_code ) || WS_HTTP_PERMANENT_REDIRECT (ws -> http_status_code )) {
331- ws -> redir_host = get_http_header (ws -> buffer , "Location:" );
332- if (ws -> redir_host == NULL ) {
295+ }
296+
297+ const char * location = NULL ;
298+ int location_len = 0 ;
299+
300+ const char * server_key = NULL ;
301+ int server_key_len = 0 ;
302+ const char * header_cursor = strnstr (ws -> buffer , "\r\n" , header_len );
303+ if (!header_cursor ){
304+ ESP_LOGE (TAG , "HTTP Header locate failed" );
305+ return -1 ;
306+ }
307+ header_cursor += strlen ("\r\n" );
308+
309+ while (header_cursor < delim_ptr ){
310+ const char * end_of_line = strnstr (header_cursor , "\r\n" , header_len - (header_cursor - ws -> buffer ));
311+ if (!end_of_line ){
312+ ESP_LOGE (TAG , "HTTP Header walk failed" );
313+ return -1 ;
314+ }
315+ else if (end_of_line == header_cursor ){
316+ ESP_LOGD (TAG , "HTTP Header walk found end" );
317+ break ;
318+ }
319+ int line_len = end_of_line - header_cursor ;
320+ ESP_LOGD (TAG , "HTTP Header walk line:%.*s" , line_len , header_cursor );
321+
322+ // Check for Sec-WebSocket-Accept header
323+ const char * header_sec_websocket_accept = "Sec-WebSocket-Accept: " ;
324+ size_t header_sec_websocket_accept_len = strlen (header_sec_websocket_accept );
325+ if (line_len >= header_sec_websocket_accept_len && !strncasecmp (header_cursor , header_sec_websocket_accept , header_sec_websocket_accept_len )) {
326+ ESP_LOGD (TAG , "found server-key" );
327+ server_key = header_cursor + header_sec_websocket_accept_len ;
328+ server_key_len = line_len - header_sec_websocket_accept_len ;
329+ }
330+ else if (ws -> header_hook ) {
331+ ws -> header_hook (ws -> header_user_context , header_cursor , line_len );
332+ }
333+
334+ // Check for Location: header
335+ const char * header_location = "Location: " ;
336+ size_t header_location_len = strlen (header_location );
337+ if (line_len >= header_location_len && !strncasecmp (header_cursor , header_location , header_location_len )) {
338+ location = header_cursor + header_location_len ;
339+ location_len = line_len - header_location_len ;
340+ }
341+
342+ // Adjust cursor to the start of the next line
343+ header_cursor += line_len ;
344+ header_cursor += strlen ("\r\n" );
345+ }
346+
347+ if (WS_HTTP_TEMPORARY_REDIRECT (ws -> http_status_code ) || WS_HTTP_PERMANENT_REDIRECT (ws -> http_status_code )) {
348+ if (location == NULL || location_len <= 0 ) {
333349 ESP_LOGE (TAG , "Location header not found" );
334350 return -1 ;
335351 }
352+ ws -> redir_host = strndup (location , location_len );
336353 return ws -> http_status_code ;
337354 }
338355
339- char * server_key = get_http_header (ws -> buffer , "Sec-WebSocket-Accept:" );
340356 if (server_key == NULL ) {
341357 ESP_LOGE (TAG , "Sec-WebSocket-Accept not found" );
342358 return -1 ;
@@ -357,7 +373,7 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int
357373 esp_crypto_base64_encode (expected_server_key , sizeof (expected_server_key ), & outlen , expected_server_sha1 , sizeof (expected_server_sha1 ));
358374 expected_server_key [ (outlen < sizeof (expected_server_key )) ? outlen : (sizeof (expected_server_key ) - 1 ) ] = 0 ;
359375 ESP_LOGD (TAG , "server key=%s, send_key=%s, expected_server_key=%s" , (char * )server_key , (char * )client_key , expected_server_key );
360- if (strcmp ((char * )expected_server_key , (char * )server_key ) != 0 ) {
376+ if (strncmp ((char * )expected_server_key , (char * )server_key , server_key_len ) != 0 ) {
361377 ESP_LOGE (TAG , "Invalid websocket key" );
362378 return -1 ;
363379 }
@@ -371,7 +387,6 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int
371387 } else {
372388#ifdef CONFIG_WS_DYNAMIC_BUFFER
373389 free (ws -> buffer );
374- ws -> redir_host = NULL ;
375390 ws -> buffer = NULL ;
376391#endif
377392 ws -> buffer_len = 0 ;
@@ -713,6 +728,7 @@ static esp_err_t ws_destroy(esp_transport_handle_t t)
713728{
714729 transport_ws_t * ws = esp_transport_get_context_data (t );
715730 free (ws -> buffer );
731+ free (ws -> redir_host );
716732 free (ws -> path );
717733 free (ws -> sub_protocol );
718734 free (ws -> user_agent );
@@ -862,6 +878,23 @@ esp_err_t esp_transport_ws_set_headers(esp_transport_handle_t t, const char *hea
862878 return ESP_OK ;
863879}
864880
881+ esp_err_t esp_transport_ws_set_header_hook (esp_transport_handle_t t , ws_header_hook_t hook , void * user_context )
882+ {
883+ if (t == NULL ) {
884+ return ESP_ERR_INVALID_ARG ;
885+ }
886+ if (hook == NULL ) {
887+ ESP_LOGE (TAG , "Header hook is NULL" );
888+ return ESP_ERR_INVALID_ARG ;
889+ }
890+ ESP_LOGV (TAG , "User has context: %s" , user_context != NULL ? "true" : "false" );
891+
892+ transport_ws_t * ws = esp_transport_get_context_data (t );
893+ ws -> header_hook = hook ;
894+ ws -> header_user_context = user_context ;
895+ return ESP_OK ;
896+ }
897+
865898esp_err_t esp_transport_ws_set_auth (esp_transport_handle_t t , const char * auth )
866899{
867900 if (t == NULL ) {
@@ -927,6 +960,10 @@ esp_err_t esp_transport_ws_set_config(esp_transport_handle_t t, const esp_transp
927960 err = esp_transport_ws_set_headers (t , config -> headers );
928961 ESP_TRANSPORT_ERR_OK_CHECK (TAG , err , return err ;)
929962 }
963+ if (config -> header_hook || config -> header_user_context ) {
964+ err = esp_transport_ws_set_header_hook (t , config -> header_hook , config -> header_user_context );
965+ ESP_TRANSPORT_ERR_OK_CHECK (TAG , err , return err ;)
966+ }
930967 if (config -> auth ) {
931968 err = esp_transport_ws_set_auth (t , config -> auth );
932969 ESP_TRANSPORT_ERR_OK_CHECK (TAG , err , return err ;)
0 commit comments