diff --git a/CMakeLists.txt b/CMakeLists.txt index b460254..b37980f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -80,6 +80,7 @@ cmake_dependent_option(SDLNET_RELOCATABLE "Create relocatable SDL_net package" " option(SDLNET_WERROR "Treat warnings as errors" OFF) option(SDLNET_SAMPLES "Build the SDL3_net sample program(s)" ${SDLNET_SAMPLES_DEFAULT}) +option(SDLNET_WS_SAMPLE "Build the SDL3_net sample web socket server" OFF) cmake_dependent_option(SDLNET_SAMPLES_INSTALL "Install the SDL3_net sample program(s)" OFF "SDLNET_SAMPLES;SDLNET_INSTALL" OFF) # Save BUILD_SHARED_LIBS variable @@ -317,3 +318,21 @@ if(SDLNET_SAMPLES) # Build at least one example in C90 set_property(TARGET get-local-addrs PROPERTY C_STANDARD 90) endif() + +if(SDLNET_WS_SAMPLE) + find_package(OpenSSL REQUIRED) + + add_executable(ws-server examples/web-socket-server.c) + sdl_add_warning_options(ws-server WARNING_AS_ERROR ${SDLTTF_WERROR}) + sdl_target_link_options_no_undefined(ws-server) + target_link_libraries(ws-server PRIVATE SDL3_net::${sdl3_net_target_name}) + target_link_libraries(ws-server PRIVATE ${sdl3_target_name}) + target_link_libraries(ws-server PRIVATE OpenSSL::SSL) + set_property(TARGET ws-server PROPERTY C_STANDARD 99) + set_property(TARGET ws-server PROPERTY C_EXTENSIONS FALSE) + if(SDLNET_SAMPLES_INSTALL) + install(TARGETS ws-server + RUNTIME DESTINATION "${CMAKE_INSTALL_LIBEXECDIR}/installed-tests/SDL3_net" + ) + endif() +endif() \ No newline at end of file diff --git a/examples/web-socket-server.c b/examples/web-socket-server.c new file mode 100644 index 0000000..95464c8 --- /dev/null +++ b/examples/web-socket-server.c @@ -0,0 +1,281 @@ +#define SDL_WEBSOCKET_ACCEPT_KEY_FUNCTION + +#include +#include +#include + +#include +#include + +bool onPreamble(NET_WSStream *, const char*, const char*, const char*, void*); +bool onHeader(NET_WSStream *, const char*, const char*, void*); +bool onOpen(NET_WSStream *, void*); +bool onData(NET_WSStream *, Uint8, void*, int); +void onClose(NET_WSStream *, void*); + +Uint16 server_port = 2382; + +int main(int argc, char **argv) +{ + const char *interface = NULL; + int simulate_failure = 0; + + for (int i = 1; i < argc; i++) { + const char *arg = argv[i]; + if ((SDL_strcmp(arg, "--port") == 0) && (i < (argc-1))) { + server_port = (Uint16) SDL_atoi(argv[++i]); + } else if ((SDL_strcmp(arg, "--simulate-failure") == 0) && (i < (argc-1))) { + simulate_failure = (int) SDL_atoi(argv[++i]); + } else { + interface = arg; + } + } + + if (!NET_Init()) { + SDL_Log("NET_Init failed: %s\n", SDL_GetError()); + SDL_Quit(); + return 1; + } + + if (interface) { + SDL_Log("Attempting to listen on interface '%s', port %d", interface, (int) server_port); + } else { + SDL_Log("Attempting to listen on all interfaces, port %d", (int) server_port); + } + + NET_Address *server_addr = NULL; + if (interface) { + server_addr = NET_ResolveHostname(interface); + if (!server_addr || (NET_WaitUntilResolved(server_addr, -1) != NET_SUCCESS)) { + SDL_Log("Failed to resolve interface for '%s': %s", interface, SDL_GetError()); + if (server_addr) { + NET_UnrefAddress(server_addr); + } + NET_Quit(); + SDL_Quit(); + return 1; + } else { + SDL_Log("Interface '%s' resolves to '%s' ...", interface, NET_GetAddressString(server_addr)); + } + } + + NET_Server *server = NET_CreateServer(server_addr, server_port); + if (!server) { + SDL_Log("Failed to create server: %s", SDL_GetError()); + } else { + SDL_Log("Server is ready! Open http://%s:%d in your browser", + interface == NULL ? "localhost" : interface, (int) server_port); + int num_vsockets = 1; + void *vsockets[128]; + SDL_zeroa(vsockets); + vsockets[0] = server; + while (NET_WaitUntilInputAvailable(vsockets, num_vsockets, -1) >= 0) { + NET_StreamSocket *streamsocket = NULL; + if (!NET_AcceptClient(server, &streamsocket)) { + SDL_Log("NET_AcceptClient failed: %s", SDL_GetError()); + break; + } else if (streamsocket) { // new connection! + SDL_Log("New connection from %s!", NET_GetAddressString(NET_GetStreamSocketAddress(streamsocket))); + if (num_vsockets >= (int) (SDL_arraysize(vsockets) - 1)) { + SDL_Log(" (too many connections, though, so dropping immediately.)"); + NET_DestroyStreamSocket(streamsocket); + } else { + if (simulate_failure) { + NET_SimulateStreamPacketLoss(streamsocket, simulate_failure); + } + NET_WSStream * ws = NET_CreateWSStream(streamsocket, onPreamble, onHeader, onOpen, onData, onClose, streamsocket); + if (!ws) { + SDL_Log("NET_CreateWSStream: %s\n", SDL_GetError()); + break; + } + vsockets[num_vsockets++] = ws; + } + } + + for (int i = 1; i < num_vsockets; i++) { + NET_WSStream * ws = (NET_WSStream *) vsockets[i]; + if(!ws || !NET_UpdateWSStream(ws)){ + SDL_Log("Dropping connection to '%s'\n", NET_GetAddressString(NET_GetWSStreamAddress(ws))); + NET_DestroyWSStream(ws); + vsockets[i] = NULL; + if (i < (num_vsockets - 1)) { + SDL_memmove(&vsockets[i], &vsockets[i+1], sizeof (vsockets[0]) * ((num_vsockets - i) - 1)); + } + num_vsockets--; + i--; + } + } + } + + SDL_Log("Destroying server..."); + NET_DestroyServer(server); + } + + SDL_Log("Shutting down..."); + NET_Quit(); + SDL_Quit(); + return 0; +} + +const char* indexFormat = "" +"" +" SDL3 Web Socket Server Test" +"" +"" +" " +"
" +"
" +"
" +" " +" " +"
" +"
" +" " +" " +"
" +"
" +"
    " +"
" +"
" +"" +""; + +bool onPreamble(NET_WSStream *ws, const char *method, const char *route, const char *protocol, void *userdata) +{ + (void)ws; + bool isWebSocket = false; + bool logPreamble = false; + char header[128]; + NET_StreamSocket *streamsocket = (NET_StreamSocket *)userdata; + if (SDL_strcmp(method, "GET") == 0 && SDL_strcmp(route, "/") == 0) { + char response[2048]; + + const int responseSize = SDL_snprintf(response, sizeof(response), + indexFormat, NET_GetAddressString(NET_GetStreamSocketAddress(streamsocket)), server_port); + + const int headerSize = SDL_snprintf(header, sizeof(header), + "HTTP/1.1 200 OK\r\n" + "Connection: close\r\n" + "Content-Type: text/html\r\n" + "Content-Length: %d\r\n" + "\r\n", + responseSize); + + NET_WriteToStreamSocket(streamsocket, header, headerSize); + NET_WriteToStreamSocket(streamsocket, response, responseSize); + logPreamble = true; + } else if (SDL_strcmp(method, "GET") == 0 && SDL_strcmp(route, "/ws") == 0){ + isWebSocket = true; + logPreamble = true; + } else { + const int headerSize = SDL_snprintf(header, sizeof(header), + "HTTP/1.1 400 Bad Request\r\n" + "Connection: close\r\n" + "\r\n"); + NET_WriteToStreamSocket(streamsocket, header, headerSize); + } + + if (logPreamble) { + SDL_Log("Method: %s; Route: %s; Protocol: %s\n", method, route, protocol); + } + return isWebSocket; +} + +bool onHeader(NET_WSStream *ws, const char *key, const char *value, void *userdata) +{ + (void)ws; + (void)userdata; + SDL_Log("Header %s=%s\n", key, value); + return true; +} + +bool onOpen(NET_WSStream *ws, void *userdata) +{ + (void)ws; + (void)userdata; + return true; +} + +bool onData(NET_WSStream *ws, Uint8 opcode, void *buf, int len) +{ + if(opcode == NET_WS_OP_CODE_TEXT) { + SDL_Log("Received: %.*s\n", len, (char*)buf); + + } else { + SDL_Log("Received: %d bytes\n", len); + } + return NET_SendPayloadToWSStream(ws, opcode, buf, len); +} + +void onClose(NET_WSStream *ws, void *userdata) +{ + (void)ws; + (void)userdata; +} + +bool NET_ConvertToSecWebSocketAcceptKey(SDL_INOUT_Z_CAP(maxlen) char *wsKeyPlusMagicString, int maxlen) +{ + char *buffer = SDL_strdup(wsKeyPlusMagicString); + if (!buffer) { + return false; + } + + // Prepare to perform SHA1 hash + EVP_MD_CTX *ctx = EVP_MD_CTX_create(); + const EVP_MD *md = EVP_sha1(); + EVP_DigestInit_ex(ctx, md, NULL); + + // SHA1 hash the key + magic string + EVP_DigestUpdate(ctx, wsKeyPlusMagicString, strlen(wsKeyPlusMagicString)); + + unsigned int len; + EVP_DigestFinal_ex(ctx, (unsigned char*)buffer, &len); + EVP_MD_CTX_destroy(ctx); + EVP_cleanup(); + + if((int)len >= maxlen) { + return false; + } + + // Base64 encode the contents of buffer and place into key + magic string + EVP_EncodeBlock((unsigned char*)wsKeyPlusMagicString, (unsigned char*)buffer, len); + SDL_free(buffer); + SDL_Log("Key=%s\n", wsKeyPlusMagicString); + return true; +} \ No newline at end of file diff --git a/examples/web-socket-server.html b/examples/web-socket-server.html new file mode 100644 index 0000000..f46c2c3 --- /dev/null +++ b/examples/web-socket-server.html @@ -0,0 +1,58 @@ + + + SDL3 Web Socket Server Test + + + +
+
+
+ + +
+
+ + +
+
+
    +
+
+ + \ No newline at end of file diff --git a/include/SDL3_net/SDL_net.h b/include/SDL3_net/SDL_net.h index 176722b..adf34b6 100644 --- a/include/SDL3_net/SDL_net.h +++ b/include/SDL3_net/SDL_net.h @@ -1125,6 +1125,94 @@ extern SDL_DECLSPEC void SDLCALL NET_SimulateStreamPacketLoss(NET_StreamSocket * */ extern SDL_DECLSPEC void SDLCALL NET_DestroyStreamSocket(NET_StreamSocket *sock); /* Destroy your sockets when finished with them. Does not block, handles shutdown internally. */ +/** + * An object that represents a web socket server connection to another system. + * + * This is meant to be a state machine that parses the incoming data as a web socket connection. + * + * \since This datatype is available since SDL_net 3.0.0. + */ +typedef struct NET_WSStream NET_WSStream; + +#define NET_WS_OP_CODE_CONTINUE 0x0 +#define NET_WS_OP_CODE_TEXT 0x1 +#define NET_WS_OP_CODE_BINARY 0x2 +#define NET_WS_OP_CODE_CLOSE 0x8 +#define NET_WS_OP_CODE_PING 0x9 +#define NET_WS_OP_CODE_PONG 0xA + +/** + * Callback that will be called when an HTTP Request sends its method, route, and protocol. Returning false + * will close the web socket connection. + */ +typedef bool (*NET_OnWSPreamble)(NET_WSStream *ws, const char *method, const char *route, const char* protocol, void *); + +/** + * Callback that will be called for each header received in an HTTP Request. Returning false + * will close web socket connection. + */ +typedef bool (*NET_OnWSHeader)(NET_WSStream *ws, const char *key, const char *value, void *); + +/** + * Callback that will be called when before the HTTP response establishing the web socket connection + * has been sent. Returning false will close web socket connection and not send the HTTP Response. + */ +typedef bool (*NET_OnWSOpen)(NET_WSStream *ws, void *); + +/** + * Callback when a web socket frame has been received in its entirety. + */ +typedef bool (*NET_OnWSData)(NET_WSStream *ws, Uint8 opcode, void *, int); + +/** + * Callback that will be called when the server receives a web socket frame with the close op code. + */ +typedef void (*NET_OnWSClose)(NET_WSStream *ws, void *); + +/** + * WebSocket protocol requires that client's accept key be the + * 'Sec-WebSocket-Key' value concatenated with a magic string + * '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'. We must take the SHA1 hash of that + * concatenation. Then, take the base64 encoding of the hash. While base64 is trivial, + * SHA1 hashing is not. So by default, this SDL function will always return NULL. + * The user should define the SDL_WEBSOCKET_ACCEPT_KEY_FUNCTION macro and + * include an implementation of this function in their code. + * + * TODO: Find a simple implementation to include into SDL. + */ +extern SDL_DECLSPEC bool NET_ConvertToSecWebSocketAcceptKey(SDL_INOUT_Z_CAP(maxlen) char* wsKeyPlusMagicString, int maxlen); + +/** + * The web socket stream will use a NET_StreamSocket to send and receive data. The NET_WSStream now owns + * a reference to the NET_StreamSocket and will clean it up it is destroyed. + */ +extern SDL_DECLSPEC NET_WSStream * NET_CreateWSStream(NET_StreamSocket *sock, + NET_OnWSPreamble, NET_OnWSHeader, NET_OnWSOpen, NET_OnWSData, NET_OnWSClose, void *); + +/** + * Shortcut for a simple web socket connection stream that doesn't handle any other event accept when it receives data. + */ +extern SDL_DECLSPEC NET_WSStream * NET_CreateSimpleWSStream(NET_StreamSocket *sock, NET_OnWSData, void *); + +extern SDL_DECLSPEC NET_Address * NET_GetWSStreamAddress(NET_WSStream *); + +/** + * This will update the web socket state machine. + * + * First, the initial HTTP request must be parsed and verified that it is a valid web socket connection. + * Then, an HTTP response is sent validating that the web socket connection can be established. + * Once established, all incoming and outgoing data must be sent as a web socket data frame. + * + * The client must adhere to the web socket protocol. Typically, the client should just be the + * web socket connection from a client's web browser. + * + * \since This datatype is available since SDL_net 3.0.0. + */ +extern SDL_DECLSPEC bool NET_UpdateWSStream(NET_WSStream *ws); + +extern SDL_DECLSPEC bool NET_SendPayloadToWSStream(NET_WSStream *ws, Uint8 opcode, void *buf, int len); + +extern SDL_DECLSPEC void NET_DestroyWSStream(NET_WSStream *ws); /* Datagram (UDP) API... */ diff --git a/src/SDL_net.c b/src/SDL_net.c index 5a6a6a5..2e6ab46 100644 --- a/src/SDL_net.c +++ b/src/SDL_net.c @@ -197,7 +197,8 @@ typedef enum NET_SocketType { SOCKETTYPE_STREAM, SOCKETTYPE_DATAGRAM, - SOCKETTYPE_SERVER + SOCKETTYPE_SERVER, + SOCKETTYPE_WEBSOCKET } NET_SocketType; @@ -1569,6 +1570,359 @@ void NET_DestroyStreamSocket(NET_StreamSocket *sock) } } +struct NET_WSStream +{ + NET_SocketType socktype; + NET_StreamSocket *stream; + NET_OnWSPreamble onPreamble; + NET_OnWSHeader onHeader; + NET_OnWSData onData; + NET_OnWSOpen onOpen; + NET_OnWSClose onClose; + Uint8 *pending_input_buffer; + int pending_input_len; + int pending_input_allocation; + bool established_connection; + void *userdata; +}; + +NET_WSStream * NET_CreateWSStream(NET_StreamSocket *stream, + NET_OnWSPreamble onPreamble, + NET_OnWSHeader onHeader, + NET_OnWSOpen onOpen, + NET_OnWSData onData, + NET_OnWSClose onClose, + void *userdata) +{ + NET_WSStream *ws = (NET_WSStream *)SDL_calloc(1, sizeof(NET_WSStream)); + if(!ws) { + return NULL; + } + + ws->socktype = SOCKETTYPE_WEBSOCKET; + ws->stream = stream; + ws->onPreamble = onPreamble; + ws->onHeader = onHeader; + ws->onOpen = onOpen; + ws->onData = onData; + ws->onClose = onClose; + ws->userdata = userdata; + return ws; +} + +NET_WSStream * NET_CreateSimpleWSStream(NET_StreamSocket *sock, NET_OnWSData onData, void *userdata) +{ + return NET_CreateWSStream(sock, NULL, NULL, NULL, onData, NULL, userdata); +} + +NET_Address * NET_GetWSStreamAddress(NET_WSStream * ws) { + return NET_GetStreamSocketAddress(ws->stream); +} + +bool NET_WSStreamSendBadRequest(NET_StreamSocket *sock) +{ + char buffer[128]; + SDL_snprintf(buffer, sizeof(buffer), + "HTTP/1.1 400 Bad Request\r\n" + "Connection: close\r\n" + "\r\n"); + return NET_WriteToStreamSocket(sock, buffer, sizeof(buffer)); +} + +bool NET_SendPayloadToWSStream(NET_WSStream *ws, Uint8 opcode, void* buf, int len) +{ + Uint8 bytes[2]; + bytes[0] = 0b10000000 | opcode; + if (len < 126) { + bytes[1] = len; + if (!NET_WriteToStreamSocket(ws->stream, bytes, sizeof(bytes))) { + return false; + } + } else if (len < UINT16_MAX) { + bytes[1] = 126; + Uint16 payloadLength = len; + payloadLength = SDL_Swap16BE(payloadLength); + if (!NET_WriteToStreamSocket(ws->stream, bytes, sizeof(bytes)) || + !NET_WriteToStreamSocket(ws->stream, &payloadLength, sizeof(payloadLength))) { + return false; + } + } else { + bytes[1] = 127; + Uint64 payloadLength = len; + payloadLength = SDL_Swap64BE(payloadLength); + if (!NET_WriteToStreamSocket(ws->stream, bytes, sizeof(bytes)) || + !NET_WriteToStreamSocket(ws->stream, &payloadLength, sizeof(payloadLength))) { + return false; + } + } + + return NET_WriteToStreamSocket(ws->stream, buf, len); +} + +bool NET_WSStreamParseInput(NET_WSStream *ws) +{ + // bit 0 = FIN + // bits 1 - 3 = RSV (ignore) + // bits 4 - 7 = Op Code + // bit 8 = Mask + // bit 9 - 15 = Payload Length + if (ws->pending_input_len < 2) { + return true; + } + + // TODO: Handle clients that send partial payloads by setting FIN to 0 + const bool complete = (ws->pending_input_buffer[0] & 0b10000000) != 0; + if (!complete) { + // Send close message to client + NET_SendPayloadToWSStream(ws, NET_WS_OP_CODE_CLOSE, NULL, 0); + return false; + } + + const Uint16 opcode = ws->pending_input_buffer[0] & 0b1111; + + // As part of the web socket specification, client's must always send masked + // payloads to a server. + const bool masked = (ws->pending_input_buffer[1] & 0b10000000) != 0; + if (!masked) { + NET_SendPayloadToWSStream(ws, NET_WS_OP_CODE_CLOSE, NULL, 0); + return false; + } + + Uint8 *maskingKey = &ws->pending_input_buffer[2]; + Uint64 payloadLength = ws->pending_input_buffer[1] & 0b01111111; + if (payloadLength == 127) { + // The next 8 bytes are the actual payload length + // Must use big endian + if ((Uint64)ws->pending_input_len < 2 + sizeof(Uint64)) { + return true; + } + SDL_memcpy(&payloadLength, &ws->pending_input_buffer[2], sizeof(Uint64)); + payloadLength = SDL_Swap64BE(payloadLength); + maskingKey += sizeof(Uint64); + } else if (payloadLength == 126) { + // The next 2 bytes are the actual payload length + // Must use big endian + if ((Uint64)ws->pending_input_len < 2 + sizeof(Uint16)) { + return true; + } + SDL_memcpy(&payloadLength, &ws->pending_input_buffer[2], sizeof(Uint16)); + payloadLength = SDL_Swap16BE(payloadLength); + maskingKey += sizeof(Uint16); + } + + Uint8* payloadStart = maskingKey + 4; + Uint8 *payloadEnd = payloadStart + payloadLength; + + // Check if the entire payload has been received + if(payloadEnd > ws->pending_input_buffer + ws->pending_input_len) { + return true; + } + + // The next four bytes are the masking key. Then, the next payload length bytes + // are the actual payload. + // Must decode every byte in the payload using byte XOR mask byte + for (Uint64 i = 0; i < payloadLength; ++i) { + payloadStart[i] ^= maskingKey[i % 4]; + } + + switch (opcode) { + case NET_WS_OP_CODE_TEXT: + case NET_WS_OP_CODE_BINARY: + if(ws->onData && !ws->onData(ws, opcode, payloadStart, payloadLength)){ + NET_SendPayloadToWSStream(ws, NET_WS_OP_CODE_CLOSE, NULL, 0); + return false; + } + break; + case NET_WS_OP_CODE_CLOSE: + return false; + case NET_WS_OP_CODE_PING: + NET_SendPayloadToWSStream(ws, NET_WS_OP_CODE_PONG, payloadStart, payloadLength); + break; + case NET_WS_OP_CODE_PONG: + break; + default: + // Unknown op code + NET_SendPayloadToWSStream(ws, NET_WS_OP_CODE_CLOSE, NULL, 0); + return false; + } + + const Uint64 totalRead = payloadEnd - ws->pending_input_buffer; + const Uint64 remaining = ws->pending_input_len - totalRead; + SDL_memmove(ws->pending_input_buffer, payloadEnd, remaining); + ws->pending_input_len = remaining; + return true; +} + +bool NET_UpdateWSStream(NET_WSStream *ws) +{ + if(!ws || ws->socktype != SOCKETTYPE_WEBSOCKET) { + SDL_InvalidParamError("ws"); + return false; + } + + char buffer[4096]; + const int bytesRead = NET_ReadFromStreamSocket(ws->stream, buffer, sizeof(buffer)); + switch(bytesRead) { + case -1: + return false; + case 0: + return true; + default: + break; + } + + const int min_alloc = ws->pending_input_len + bytesRead; + if (min_alloc > ws->pending_input_allocation) { + int newlen = sizeof(buffer) + ws->pending_input_allocation; + void *ptr = SDL_realloc(ws->pending_input_buffer, newlen); + if (!ptr) { + return false; + } + ws->pending_input_buffer = (Uint8 *) ptr; + ws->pending_input_allocation = newlen; + } + + SDL_memcpy(ws->pending_input_buffer + ws->pending_input_len, buffer, bytesRead); + ws->pending_input_len += bytesRead; + + if (ws->established_connection) { + return NET_WSStreamParseInput(ws); + } + + // If '\r\n' isn't found, then the HTTP Request has not been sent in its entireity. + char *start = (char *)ws->pending_input_buffer; + char *end = SDL_strnstr(start, "\r\n\r\n", ws->pending_input_len); + if (end == NULL) { + return true; + } + + // Insert a null-terminator before the last empty line + end += 2; + *end = '\0'; + + char *endOfMethod = SDL_strchr(start, ' '); + if (!endOfMethod) { + NET_WSStreamSendBadRequest(ws->stream); + return false; + } + *endOfMethod = '\0'; + const char *method = start; + start = endOfMethod + 1; + + char *endOfRoute = SDL_strchr(start, ' '); + if (!endOfRoute) { + NET_WSStreamSendBadRequest(ws->stream); + return false; + } + *endOfRoute = '\0'; + const char *route = start; + start = endOfRoute + 1; + + char *endOfProtocol = SDL_strstr(start, "\r\n"); + if (!endOfProtocol) { + NET_WSStreamSendBadRequest(ws->stream); + return false; + } + *endOfProtocol = '\0'; + const char *protocol = start; + start = endOfProtocol + 2; + + if (ws->onPreamble && !ws->onPreamble(ws, method, route, protocol, ws->userdata)) { + return false; + } + + const char *upgrade = NULL; + const char *connection = NULL; + const char *wsKey = NULL; + + while (*start) { + end = SDL_strstr(start, ": "); + if (!end) { + SDL_SetError("HTTP Request is missing ': ' for header"); + return false; + } + *end = '\0'; + const char *key = start; + start = end + 2; + + end = SDL_strstr(start, "\r\n"); + if (!end) { + SDL_SetError("HTTP Request is missing '\\r\\n' for the end of the header"); + return false; + } + *end = '\0'; + const char *value = start; + start = end + 2; + + if( ws->onHeader && !ws->onHeader(ws, key, value, ws->userdata)) { + return false; + } + + if (SDL_strcmp(key, "Upgrade") == 0) { + upgrade = value; + } else if (SDL_strcmp(key, "Connection") == 0){ + connection = value; + } else if (SDL_strcmp(key, "Sec-WebSocket-Key") == 0){ + wsKey = value; + } + } + + if (!wsKey || !upgrade || !connection) { + NET_WSStreamSendBadRequest(ws->stream); + return false; + } + + if (ws->onOpen && !ws->onOpen(ws, ws->userdata)) { + return false; + } + + // Clear the input buffer since it should only contain the HTTP request + ws->pending_input_len = 0; + + char acceptKey[256]; + + // Web Socket Key + Magic string defined in the Web Socket protocol + SDL_snprintf(acceptKey, sizeof(acceptKey), "%s258EAFA5-E914-47DA-95CA-C5AB0DC85B11", wsKey); + + if (!NET_ConvertToSecWebSocketAcceptKey(acceptKey, sizeof(acceptKey))) { + SDL_SetError("Failed to create web socket accept key"); + return false; + } + + char response[256]; + int written = SDL_snprintf(response, sizeof(response), + "HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: %s\r\n" + "Connection: %s\r\n" + "Sec-WebSocket-Accept: %s\r\n" + "\r\n", + upgrade, connection, acceptKey); + + ws->established_connection = true; + return NET_WriteToStreamSocket(ws->stream, response, written); +} + +#if !SDL_WEBSOCKET_ACCEPT_KEY_FUNCTION +// This should be implemented by the user if they want web socket support +bool NET_ConvertToSecWebSocketAcceptKey(SDL_INOUT_Z_CAP(maxlen) char *key, int maxlen) { + (void)key; + (void)maxlen; + return false; +} +#endif + +void NET_DestroyWSStream(NET_WSStream *ws) +{ + if (ws) { + if (ws->onClose) { + ws->onClose(ws, ws->userdata); + } + SDL_free(ws->pending_input_buffer); + NET_DestroyStreamSocket(ws->stream); + SDL_free(ws); + } +} + typedef struct NET_DatagramSocketHandle { Socket handle; @@ -1963,6 +2317,7 @@ typedef union NET_GenericSocket NET_SocketType socktype; NET_StreamSocket stream; NET_DatagramSocket dgram; + NET_WSStream ws; NET_Server server; } NET_GenericSocket; @@ -1987,6 +2342,7 @@ int NET_WaitUntilInputAvailable(void **vsockets, int numsockets, int timeoutms) const NET_GenericSocket *sock = sockets[i]; switch (sock->socktype) { case SOCKETTYPE_STREAM: + addStreamHandles: numhandles++; break; case SOCKETTYPE_DATAGRAM: @@ -1995,6 +2351,9 @@ int NET_WaitUntilInputAvailable(void **vsockets, int numsockets, int timeoutms) case SOCKETTYPE_SERVER: numhandles += sock->server.num_handles; break; + case SOCKETTYPE_WEBSOCKET: + sock = (NET_GenericSocket *)sock->ws.stream; + goto addStreamHandles; } } @@ -2018,6 +2377,7 @@ int NET_WaitUntilInputAvailable(void **vsockets, int numsockets, int timeoutms) switch (sock->socktype) { case SOCKETTYPE_STREAM: + addStreamSocket: pfd->fd = sock->stream.handle; if (sock->stream.status == NET_WAITING) { pfd->events = POLLOUT; // marked as writable when connection is complete. @@ -2048,6 +2408,10 @@ int NET_WaitUntilInputAvailable(void **vsockets, int numsockets, int timeoutms) pfd++; } break; + + case SOCKETTYPE_WEBSOCKET: + sock = (NET_GenericSocket *)sock->ws.stream; + goto addStreamSocket; } } @@ -2066,6 +2430,7 @@ int NET_WaitUntilInputAvailable(void **vsockets, int numsockets, int timeoutms) switch (sock->socktype) { case SOCKETTYPE_STREAM: { + pumpStreamSocket: SDL_assert(pfd->fd == sock->stream.handle); const bool failed = ((pfd->revents & (POLLERR|POLLHUP|POLLNVAL)) != 0) ? true : false; const bool writable = (pfd->revents & POLLOUT) ? true : false; @@ -2129,6 +2494,10 @@ int NET_WaitUntilInputAvailable(void **vsockets, int numsockets, int timeoutms) } } break; + + case SOCKETTYPE_WEBSOCKET: + sock = (NET_GenericSocket *)sock->ws.stream; + goto pumpStreamSocket; } if (count_it) { diff --git a/src/SDL_net.sym b/src/SDL_net.sym index d7b8a93..277a32d 100644 --- a/src/SDL_net.sym +++ b/src/SDL_net.sym @@ -2,13 +2,17 @@ SDL3_net_0.0.0 { global: NET_AcceptClient; NET_CompareAddresses; + NET_ConvertToSecWebSocketAcceptKey; NET_CreateClient; NET_CreateDatagramSocket; NET_CreateServer; + NET_CreateSimpleWSStream; + NET_CreateWSStream; NET_DestroyDatagram; NET_DestroyDatagramSocket; NET_DestroyServer; NET_DestroyStreamSocket; + NET_DestroyWSStream; NET_FreeLocalAddresses; NET_GetAddressStatus; NET_GetAddressString; @@ -16,6 +20,7 @@ SDL3_net_0.0.0 { NET_GetLocalAddresses; NET_GetStreamSocketAddress; NET_GetStreamSocketPendingWrites; + NET_GetWSStreamAddress; NET_Version; NET_Init; NET_Quit; @@ -24,9 +29,11 @@ SDL3_net_0.0.0 { NET_RefAddress; NET_ResolveHostname; NET_SendDatagram; + NET_SendPayloadToWSStream; NET_SimulateAddressResolutionLoss; NET_SimulateDatagramPacketLoss; NET_SimulateStreamPacketLoss; + NET_UpdateWSStream; NET_UnrefAddress; NET_WaitUntilConnected; NET_WaitUntilInputAvailable;