diff --git a/.circleci/config.yml b/.circleci/config.yml index 0d1f05b1e0..5f99b50eb8 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -247,6 +247,36 @@ jobs: name: Run Tests command: zig-out/bin/test 2>&1 | cat + test_webzockets: + executor: linux-executor + steps: + - checkout + - attach_workspace: + at: workspace + - restore_cache: + key: linux-x86_64-0.14.1-{{ checksum "build.zig.zon" }}-selfhosted + - run: + name: Build and Run Webzockets Tests + command: | + cd src/rpc/webzockets + ../../../workspace/zig/zig build test -Dcpu=x86_64_v3 --summary all + + test_webzockets_macos: + executor: macos-executor + steps: + - checkout + - attach_workspace: + at: workspace + - restore_cache: + key: macos-aarch64-0.14.1-{{ checksum "build.zig.zon" }} + - run: + name: Build and Run Webzockets Tests + command: | + ulimit -Hn unlimited + ulimit -Sn unlimited + cd src/rpc/webzockets + ../../../workspace/zig/zig build test -Dcpu=apple_m4 --summary all + test_kcov_linux: executor: linux-executor steps: @@ -357,6 +387,9 @@ workflows: - test_linux: requires: - build_linux + - test_webzockets: + requires: + - setup_zig_linux - test_kcov_linux: requires: - setup_zig_linux @@ -393,3 +426,6 @@ workflows: - test_macos_hashmap_ledger: requires: - build_check_macos + - test_webzockets_macos: + requires: + - setup_zig_macos diff --git a/build.zig b/build.zig index f59dc31609..06c3bc271a 100644 --- a/build.zig +++ b/build.zig @@ -230,6 +230,7 @@ pub fn build(b: *Build) !void { const poseidon_mod = b.dependency("poseidon", dep_opts).module("poseidon"); const xev_mod = b.dependency("xev", dep_opts).module("xev"); const pretty_table_mod = b.dependency("prettytable", dep_opts).module("prettytable"); + const webzockets_mod = b.dependency("webzockets", dep_opts).module("webzockets"); const lsquic_dep = b.dependency("lsquic", .{ .target = config.target, @@ -295,6 +296,7 @@ pub fn build(b: *Build) !void { .{ .name = "sqlite", .module = sqlite_mod }, .{ .name = "ssl", .module = ssl_mod }, .{ .name = "tracy", .module = tracy_mod }, + .{ .name = "webzockets", .module = webzockets_mod }, .{ .name = "xev", .module = xev_mod }, .{ .name = "zstd", .module = zstd_mod }, .{ .name = "table", .module = gh_table }, diff --git a/build.zig.zon b/build.zig.zon index 1116eb6528..9375167513 100644 --- a/build.zig.zon +++ b/build.zig.zon @@ -26,8 +26,8 @@ .hash = "lsquic-0.1.0-oZGGDdOZHQFsyhwdJq9eAB2GFIPzpNm90y4-8ncaYNku", }, .xev = .{ - .url = "git+https://github.com/Syndica/libxev#bfb37ec5ad81a92f3fdc41f0d36e605a0490374d", - .hash = "libxev-0.0.0-86vtc5XGEwDhneYf_GapeLISR0pVPeBA5tSlqRS__1d-", + .url = "git+https://github.com/Syndica/libxev#c16162f0f9047f27d6230c548e1e25070b9be163", + .hash = "libxev-0.0.0-86vtcyQcEwAxRuv5vgdyq8LCRKHNGLgtKfS4KNHdic2V", }, .prettytable = .{ .url = "git+https://github.com/dying-will-bullet/prettytable-zig#46b6ad9b5970def35fa43c9613cd244f28862fa9", @@ -49,6 +49,9 @@ .url = "git+https://github.com/Syndica/tracy-zig#fd2576ccbe1b2eff5557ad6d91f2834311c8009c", .hash = "zig_tracy-0.13.0-4TLLRxFkAADP0fRkT2yR8jDk-2I4xg12GbfAF8Pgln_m", }, + .webzockets = .{ + .path = "src/rpc/webzockets", + }, .sqlite = .{ .url = "https://www.sqlite.org/2025/sqlite-amalgamation-3490200.zip", .hash = "N-V-__8AAH-mpwB7g3MnqYU-ooUBF1t99RP27dZ9addtMVXD", diff --git a/docs/docusaurus/docs/code/webzockets.md b/docs/docusaurus/docs/code/webzockets.md new file mode 100644 index 0000000000..22981c7aab --- /dev/null +++ b/docs/docusaurus/docs/code/webzockets.md @@ -0,0 +1,265 @@ +# Webzockets + +A WebSocket (RFC 6455) library 
for Zig 0.14.1, built on `libxev`. Server and client. No hidden allocators — all memory is allocated through caller-provided allocators with caller-configured pool sizes and limits. Sends are zero-copy. Server connections are memory-pooled; client connections are caller-owned. + +## Quick Start + +See [examples/echo_server.zig](examples/echo_server.zig) and [examples/simple_client.zig](examples/simple_client.zig). + +## Usage Rules + +### Buffer Lifetime + +- **Server `sendText`/`sendBinary`**: zero-copy. Keep buffer alive until `onWriteComplete`. +- **Client `sendText`/`sendBinary`**: masks in-place (XOR). Don't read/free/reuse until `onWriteComplete`. +- **Read data in callbacks**: transient — points into internal buffers reused after callback returns. Copy if needed. +- **`sendPing`/`sendPong`**: copies internally. Buffer can be freed immediately. Does not trigger `onWriteComplete`. + +### Write Concurrency + +- **One data write at a time.** Second `sendText`/`sendBinary` before completion returns `error.WriteBusy`. Queue from `onWriteComplete`. +- **Control frames** (`sendPing`/`sendPong`) use a separate internal queue (256 bytes). `error.QueueFull` on overflow. + +### Connection State + +- **Sends on non-`.open` connections** return `error.InvalidState`. `close()` no-ops if already closing. +- **`onClose` fires exactly once.** Server connections are pool-released afterward — don't reference them. Client connections require caller `deinit()`. +- **`onWriteComplete` fires even on disconnect** (so callers can free buffers). +- **Idle timeout** (optional, server only) sends `close(.going_away, "")`, following normal close handshake. +- **Close-handshake timeout** force-disconnects if peer doesn't respond. + +### Handler Lifecycle (Server) + +- **`Handler.init()`** runs before the 101 response. Return error to reject (socket closed, no HTTP response). +- **Handler Context:** handlers declare `pub const Context = T` (or `void`). If non-void, set `Config.handler_context: *T`; it’s passed to `Handler.init` as the second parameter. The pointer must remain valid for any handshake/connection that might call `init` or `onHandshakeFailed`. +- **`onHandshakeFailed`** (optional): called if the handshake fails _after_ `init` succeeds (e.g., connection pool exhausted, write error, server shutdown). Use it to clean up resources allocated in `init`. Neither `onOpen` nor `onClose` will fire. +- **Handler is embedded by value** in the pooled connection — no self-referential fields. + +### Platform + +- **macOS (kqueue) / Linux (epoll):** `xev.Loop` must be initialized with a `ThreadPool`. libxev dispatches socket close operations to the thread pool on these backends; without one, closes fail and FDs leak. Both `Server.init` and `Client.init` assert that the thread pool is set. + +### Timers + +- **Idle timeout** (`idle_timeout_ms`, server only, default `null`): sends close on inactivity. Resets on each read. +- **Close-handshake timeout** (`close_timeout_ms`, default 5000ms): force-disconnects if peer doesn't complete close handshake. +- **libxev tip:** prefer `Timer.cancel()` over raw `.cancel` completions (different behavior across backends). Note: cancellation still delivers the original callback with `error.Canceled`. + +### Event Loop + +Single-threaded. All callbacks run on the `loop.run()` thread. No locking needed; handlers must not block. 
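A minimal loop bootstrap satisfying the thread-pool requirement above (a sketch mirroring the setup used by the bundled examples and Autobahn runners):

```zig
const xev = @import("xev");

var thread_pool = xev.ThreadPool.init(.{});
defer thread_pool.deinit();
defer thread_pool.shutdown();

// kqueue/epoll dispatch socket closes to this pool; Server.init and
// Client.init assert that it is set.
var loop = try xev.Loop.init(.{ .thread_pool = &thread_pool });
defer loop.deinit();

// ... init Server/Client against &loop, then:
try loop.run(.until_done);
```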
+ +### Client PRNG + +- Client handshake key generation and RFC 6455 masking require a caller-provided `ClientMaskPRNG` (a thin wrapper around `std.Random.DefaultCsprng`). +- `ClientMaskPRNG` is **not thread-safe**; only use it from the `loop.run()` thread and do not share it across loops/threads. +- The pointer must remain valid and **must not move** for the lifetime of any `ClientConnection` using it. + +### UTF-8 Validation + +The library delivers text messages to `onMessage` without validating UTF-8. Per RFC 6455 §8.1, endpoints must close the connection on invalid UTF-8 in text frames. Validate in your handler: + +See [autobahn/server/server.zig](autobahn/server/server.zig) for a complete example (required to pass Autobahn section 6.x tests). + +## Architecture + +``` + ┌──────────────────────┐ + │ User Application │ + │ (defines Handler) │ + └──┬───────┬────────┬──┘ + │ │ │ + ┌─────────────────┘ │ └─────────────────┐ + ▼ │ ▼ +┌─────────────────────┐ │ ┌─────────────────────┐ +│ Server │ │ │ Client (transient)│ +│ TCP accept loop + │ │ │ TCP connect + │ +│ memory pools │ │ │ handshake, then │ +└──────────┬──────────┘ │ │ can be discarded │ + │ │ └──────────┬──────────┘ + ┌──────┴──────┐ │ ▼ + ▼ ▼ │ ┌─────────────────────┐ + ┌────────┐ ┌────────┐ │ │ ClientConnection │ + │ Hand- │ │ Hand- │ │ │ (caller-provided) │ + │ shake │ │ shake │ │ └──────────┬──────────┘ + │ pooled │ │ pooled │ │ │ + └───┬────┘ └───┬────┘ │ │ + ▼ ▼ │ │ + ┌────────┐ ┌────────┐ │ │ + │ Conn │ │ Conn │ │ │ + │ pooled │ │ pooled │ │ │ + └───┬────┘ └───┬────┘ │ │ + └─────┬─────┘ │ │ + └─────────────────────────┼──────────────────────────┘ + ▼ + ┌────────────────────────────────────┐ + │ libxev Event Loop │ + └────────────────────────────────────┘ +``` + +**Server-side:** Each `Handshake` and `Connection` is a self-contained pooled type with its own read buffer and back-pointer to the server. + +**Client-side:** The client is transient — connects TCP, handshakes, initializes a caller-provided `*ClientConnection`, then can be discarded. 
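A sketch of bootstrapping the server-side flow, adapted from the Autobahn echo server in this repository (`EchoServer` is the type built in the API reference below; the allocator, loop, and port are assumptions of this sketch):

```zig
var server = try EchoServer.init(allocator, &loop, .{
    .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, 8080),
    .handler_context = {}, // assumes EchoHandler.Context == void
});
defer server.deinit();

// Start the accept loop; handshakes and connections come from the pools.
server.accept();
try loop.run(.until_done);
```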
+ +### Connection Lifecycle + +**Server:** +`TCP accept → Handshake (pool) → read HTTP upgrade → validate → Handler.init() → 101 response → Connection (pool) → onOpen → read loop (parse/unmask/reassemble/dispatch) → close handshake → onClose → release to pool` + +**Client:** +`Client.connect() → write HTTP upgrade → read 101 → validate Sec-WebSocket-Accept → init ClientConnection (zero-copy handoff of leftover bytes) → onOpen → read loop (parse/reassemble/dispatch) → close handshake → onClose → deinit` + +## File Structure + +``` +src/ +├── root.zig Public API re-exports +├── types.zig Protocol types, enums, error sets +├── mask.zig XOR masking (SIMD-accelerated) +├── frame.zig Frame parsing/encoding (RFC 6455 §5) +├── http.zig HTTP parsing/encoding +├── reader.zig Frame reader with buffer management +├── control_queue.zig Ring buffer for outbound control frames +├── server/ +│ ├── server.zig TCP listener, accept loop, graceful shutdown +│ ├── slot_pool.zig Memory pool with active count tracking +│ ├── handshake.zig HTTP upgrade handshake (poolable) +│ └── connection.zig WebSocket state machine (poolable) +└── client/ + ├── client.zig Transient: connect, handshake, init connection + ├── handshake.zig Client-side HTTP upgrade state machine + └── connection.zig WebSocket state machine (caller-owned) + +examples/ Echo server and client examples +e2e_tests/ Client-server integration tests + server/ Server behavior tests + client/ Client behavior tests + support/ Shared test helpers, raw client +autobahn/ Autobahn conformance suite runners +``` + +## API Reference + +### Server Config + +```zig +const EchoServer = ws.Server(EchoHandler, 4096); +// Handler ^read buf sz + +const Config = struct { + address: std.net.Address, + tcp_accept_backlog: u31 = 128, + max_message_size: usize = 16 * 1024 * 1024, + initial_handshake_pool_size: usize = 16, + initial_connection_pool_size: usize = 64, + max_handshakes: ?usize = null, + max_connections: ?usize = null, + idle_timeout_ms: ?u32 = null, + close_timeout_ms: u32 = 5_000, + handler_context: …, // if Handler.Context != void: *Handler.Context, else: void ({}) +}; +``` + +### Client Config + +```zig +const SimpleClient = ws.Client(ClientHandler, 4096); +// Handler ^read buf sz + +const Config = struct { + address: std.net.Address, + path: []const u8 = "/", + max_message_size: usize = 16 * 1024 * 1024, + close_timeout_ms: u32 = 5_000, +}; +``` + +The client is a transient value type. `init` doesn't allocate. 
Caller provides a `*ClientConnection`, allocator, and `*ClientMaskPRNG`: + +```zig +var seed: [ws.ClientMaskPRNG.secret_seed_length]u8 = undefined; +std.crypto.random.bytes(&seed); +var csprng = ws.ClientMaskPRNG.init(seed); + +var conn: SimpleClient.Conn = undefined; +var client = SimpleClient.init(allocator, &loop, &handler, &conn, &csprng, .{ + .address = std.net.Address.parseIp4("127.0.0.1", 8080) catch unreachable, + .path = "/", + .max_message_size = 16 * 1024 * 1024, + .close_timeout_ms = 5_000, +}); +try client.connect(); +// After handshake, `conn` is live — client can be discarded +``` + +### Handler Interface + +```zig +// Required +fn onMessage(self: *Handler, conn: *Conn, message: Message) void +fn onWriteComplete(self: *Handler, conn: *Conn) void +fn onClose(self: *Handler, conn: *Conn) void + +// Optional +fn onOpen(self: *Handler, conn: *Conn) void +fn onPing(self: *Handler, conn: *Conn, data: []const u8) void +fn onPong(self: *Handler, conn: *Conn, data: []const u8) void + +// Optional +fn onBytesRead(self: *Handler, conn: *Conn, bytes_read: usize) void + +// Optional (client-only) +fn onSocketClose(self: *Handler) void + +// Optional (server-only) +fn onHandshakeFailed(self: *Handler) void + +// Server-only (required) +pub const Context = void; // or a real type T +fn init(request: http.Request, context: if (Context == void) void else *Context) !Handler +``` + +If `onPing` is not declared, auto-pong replies with latest-wins semantics. If declared, auto-pong is disabled — handler must call `conn.sendPong()`. + +`onBytesRead` fires on every TCP read completion regardless of whether reads are paused. Combine with `peekBufferedBytes()` to observe raw data arrival (e.g. for byte-counting or deciding when to pause). + +Server `init` runs before 101. Return error to reject. `onHandshakeFailed` fires if the handshake fails after `init` succeeds (pool exhaustion, write error, shutdown); use it to clean up `init`-allocated resources. + +### Connection Methods + +```zig +fn sendText(data) !void // server: []const u8 (zero-copy), client: []u8 (zero-copy, masked in-place) +fn sendBinary(data) !void // same as above +fn sendPing(data) !void // copies internally, max 125 bytes +fn sendPong(data) !void // copies internally, max 125 bytes +fn close(code: CloseCode, reason: []const u8) void +fn pauseReads() void // pause frame dispatch; TCP reads continue until buffer full +fn resumeReads() void // resume dispatch; drains buffered frames +fn peekBufferedBytes() []const u8 // raw bytes in read buffer (transient slice) +``` + +## Tests + +Unit tests colocated in source files. E2E tests in `e2e_tests/`. + +```bash +zig build test --summary all +``` + +## Autobahn Testsuite + +Industry-standard WebSocket conformance suite. **Requires Docker.** + +```bash +bash autobahn/server/run.sh # Results: autobahn/server/reports/index.html +bash autobahn/client/run.sh # Results: autobahn/client/reports/index.html +``` + +**Excluded:** 12.x / 13.x (permessage-deflate not implemented) + +## Current Limitations + +- **No custom response headers in the upgrade response (server):** The 101 response is fixed — no way to add `Sec-WebSocket-Protocol` or other headers. +- **No permessage-deflate (compression):** RFC 7692 is not implemented. Adds complexity around buffer ownership for the send API since compressed frames can't be zero-copy in the same way. +- **No DNS resolution (client):** `Config.address` takes a `std.net.Address` (IP only). 
The `Host` header is formatted from this address, but real-world servers typically expect the domain name. +- **No TLS:** Most important for the client — servers can sit behind a TLS terminator. diff --git a/scripts/style.py b/scripts/style.py index eb8138a675..4ed2590deb 100644 --- a/scripts/style.py +++ b/scripts/style.py @@ -31,6 +31,13 @@ def main(): def get_files(args): files_to_check = [] + excluded_dirs = { + ".git", + ".zig-cache", + "zig-cache", + "zig-out", + "__pycache__", + } dirs = [*args.dirs] while len(dirs) > 0: d = dirs.pop() @@ -40,6 +47,8 @@ def get_files(args): for file in files: full_path = os.path.join(d, file) if os.path.isdir(full_path): + if file in excluded_dirs: + continue dirs.append(full_path) else: if file.endswith(".zig"): diff --git a/src/rpc/lib.zig b/src/rpc/lib.zig index dbe76ca405..48e9bfbed9 100644 --- a/src/rpc/lib.zig +++ b/src/rpc/lib.zig @@ -1,3 +1,4 @@ +pub const webzockets = @import("webzockets"); pub const client = @import("client.zig"); pub const http = @import("http.zig"); pub const methods = @import("methods.zig"); diff --git a/src/rpc/webzockets/.gitignore b/src/rpc/webzockets/.gitignore new file mode 100644 index 0000000000..62c9f82fbc --- /dev/null +++ b/src/rpc/webzockets/.gitignore @@ -0,0 +1,13 @@ +# build outputs +/.zig-cache +/zig-cache +/zig-out + +# desktop app config +.DS_Store +.vscode/ + +# Autobahn testsuite reports +/autobahn/client/reports +/autobahn/server/reports +/autobahn/server/server.log diff --git a/src/rpc/webzockets/README.md b/src/rpc/webzockets/README.md new file mode 100644 index 0000000000..22981c7aab --- /dev/null +++ b/src/rpc/webzockets/README.md @@ -0,0 +1,265 @@ +# Webzockets + +A WebSocket (RFC 6455) library for Zig 0.14.1, built on `libxev`. Server and client. No hidden allocators — all memory is allocated through caller-provided allocators with caller-configured pool sizes and limits. Sends are zero-copy. Server connections are memory-pooled; client connections are caller-owned. + +## Quick Start + +See [examples/echo_server.zig](examples/echo_server.zig) and [examples/simple_client.zig](examples/simple_client.zig). + +## Usage Rules + +### Buffer Lifetime + +- **Server `sendText`/`sendBinary`**: zero-copy. Keep buffer alive until `onWriteComplete`. +- **Client `sendText`/`sendBinary`**: masks in-place (XOR). Don't read/free/reuse until `onWriteComplete`. +- **Read data in callbacks**: transient — points into internal buffers reused after callback returns. Copy if needed. +- **`sendPing`/`sendPong`**: copies internally. Buffer can be freed immediately. Does not trigger `onWriteComplete`. + +### Write Concurrency + +- **One data write at a time.** Second `sendText`/`sendBinary` before completion returns `error.WriteBusy`. Queue from `onWriteComplete`. +- **Control frames** (`sendPing`/`sendPong`) use a separate internal queue (256 bytes). `error.QueueFull` on overflow. + +### Connection State + +- **Sends on non-`.open` connections** return `error.InvalidState`. `close()` no-ops if already closing. +- **`onClose` fires exactly once.** Server connections are pool-released afterward — don't reference them. Client connections require caller `deinit()`. +- **`onWriteComplete` fires even on disconnect** (so callers can free buffers). +- **Idle timeout** (optional, server only) sends `close(.going_away, "")`, following normal close handshake. +- **Close-handshake timeout** force-disconnects if peer doesn't respond. + +### Handler Lifecycle (Server) + +- **`Handler.init()`** runs before the 101 response. 
Return error to reject (socket closed, no HTTP response). +- **Handler Context:** handlers declare `pub const Context = T` (or `void`). If non-void, set `Config.handler_context: *T`; it’s passed to `Handler.init` as the second parameter. The pointer must remain valid for any handshake/connection that might call `init` or `onHandshakeFailed`. +- **`onHandshakeFailed`** (optional): called if the handshake fails _after_ `init` succeeds (e.g., connection pool exhausted, write error, server shutdown). Use it to clean up resources allocated in `init`. Neither `onOpen` nor `onClose` will fire. +- **Handler is embedded by value** in the pooled connection — no self-referential fields. + +### Platform + +- **macOS (kqueue) / Linux (epoll):** `xev.Loop` must be initialized with a `ThreadPool`. libxev dispatches socket close operations to the thread pool on these backends; without one, closes fail and FDs leak. Both `Server.init` and `Client.init` assert that the thread pool is set. + +### Timers + +- **Idle timeout** (`idle_timeout_ms`, server only, default `null`): sends close on inactivity. Resets on each read. +- **Close-handshake timeout** (`close_timeout_ms`, default 5000ms): force-disconnects if peer doesn't complete close handshake. +- **libxev tip:** prefer `Timer.cancel()` over raw `.cancel` completions (different behavior across backends). Note: cancellation still delivers the original callback with `error.Canceled`. + +### Event Loop + +Single-threaded. All callbacks run on the `loop.run()` thread. No locking needed; handlers must not block. + +### Client PRNG + +- Client handshake key generation and RFC 6455 masking require a caller-provided `ClientMaskPRNG` (a thin wrapper around `std.Random.DefaultCsprng`). +- `ClientMaskPRNG` is **not thread-safe**; only use it from the `loop.run()` thread and do not share it across loops/threads. +- The pointer must remain valid and **must not move** for the lifetime of any `ClientConnection` using it. + +### UTF-8 Validation + +The library delivers text messages to `onMessage` without validating UTF-8. Per RFC 6455 §8.1, endpoints must close the connection on invalid UTF-8 in text frames. Validate in your handler: + +See [autobahn/server/server.zig](autobahn/server/server.zig) for a complete example (required to pass Autobahn section 6.x tests). + +## Architecture + +``` + ┌──────────────────────┐ + │ User Application │ + │ (defines Handler) │ + └──┬───────┬────────┬──┘ + │ │ │ + ┌─────────────────┘ │ └─────────────────┐ + ▼ │ ▼ +┌─────────────────────┐ │ ┌─────────────────────┐ +│ Server │ │ │ Client (transient)│ +│ TCP accept loop + │ │ │ TCP connect + │ +│ memory pools │ │ │ handshake, then │ +└──────────┬──────────┘ │ │ can be discarded │ + │ │ └──────────┬──────────┘ + ┌──────┴──────┐ │ ▼ + ▼ ▼ │ ┌─────────────────────┐ + ┌────────┐ ┌────────┐ │ │ ClientConnection │ + │ Hand- │ │ Hand- │ │ │ (caller-provided) │ + │ shake │ │ shake │ │ └──────────┬──────────┘ + │ pooled │ │ pooled │ │ │ + └───┬────┘ └───┬────┘ │ │ + ▼ ▼ │ │ + ┌────────┐ ┌────────┐ │ │ + │ Conn │ │ Conn │ │ │ + │ pooled │ │ pooled │ │ │ + └───┬────┘ └───┬────┘ │ │ + └─────┬─────┘ │ │ + └─────────────────────────┼──────────────────────────┘ + ▼ + ┌────────────────────────────────────┐ + │ libxev Event Loop │ + └────────────────────────────────────┘ +``` + +**Server-side:** Each `Handshake` and `Connection` is a self-contained pooled type with its own read buffer and back-pointer to the server. 
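A sketch of bootstrapping this server-side flow, adapted from the Autobahn echo server in this repository (`EchoServer` is the type built in the API reference below; the allocator, loop, and port are assumptions of this sketch):

```zig
var server = try EchoServer.init(allocator, &loop, .{
    .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, 8080),
    .handler_context = {}, // assumes EchoHandler.Context == void
});
defer server.deinit();

// Start the accept loop; handshakes and connections come from the pools.
server.accept();
try loop.run(.until_done);
```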
+ +**Client-side:** The client is transient — connects TCP, handshakes, initializes a caller-provided `*ClientConnection`, then can be discarded. + +### Connection Lifecycle + +**Server:** +`TCP accept → Handshake (pool) → read HTTP upgrade → validate → Handler.init() → 101 response → Connection (pool) → onOpen → read loop (parse/unmask/reassemble/dispatch) → close handshake → onClose → release to pool` + +**Client:** +`Client.connect() → write HTTP upgrade → read 101 → validate Sec-WebSocket-Accept → init ClientConnection (zero-copy handoff of leftover bytes) → onOpen → read loop (parse/reassemble/dispatch) → close handshake → onClose → deinit` + +## File Structure + +``` +src/ +├── root.zig Public API re-exports +├── types.zig Protocol types, enums, error sets +├── mask.zig XOR masking (SIMD-accelerated) +├── frame.zig Frame parsing/encoding (RFC 6455 §5) +├── http.zig HTTP parsing/encoding +├── reader.zig Frame reader with buffer management +├── control_queue.zig Ring buffer for outbound control frames +├── server/ +│ ├── server.zig TCP listener, accept loop, graceful shutdown +│ ├── slot_pool.zig Memory pool with active count tracking +│ ├── handshake.zig HTTP upgrade handshake (poolable) +│ └── connection.zig WebSocket state machine (poolable) +└── client/ + ├── client.zig Transient: connect, handshake, init connection + ├── handshake.zig Client-side HTTP upgrade state machine + └── connection.zig WebSocket state machine (caller-owned) + +examples/ Echo server and client examples +e2e_tests/ Client-server integration tests + server/ Server behavior tests + client/ Client behavior tests + support/ Shared test helpers, raw client +autobahn/ Autobahn conformance suite runners +``` + +## API Reference + +### Server Config + +```zig +const EchoServer = ws.Server(EchoHandler, 4096); +// Handler ^read buf sz + +const Config = struct { + address: std.net.Address, + tcp_accept_backlog: u31 = 128, + max_message_size: usize = 16 * 1024 * 1024, + initial_handshake_pool_size: usize = 16, + initial_connection_pool_size: usize = 64, + max_handshakes: ?usize = null, + max_connections: ?usize = null, + idle_timeout_ms: ?u32 = null, + close_timeout_ms: u32 = 5_000, + handler_context: …, // if Handler.Context != void: *Handler.Context, else: void ({}) +}; +``` + +### Client Config + +```zig +const SimpleClient = ws.Client(ClientHandler, 4096); +// Handler ^read buf sz + +const Config = struct { + address: std.net.Address, + path: []const u8 = "/", + max_message_size: usize = 16 * 1024 * 1024, + close_timeout_ms: u32 = 5_000, +}; +``` + +The client is a transient value type. `init` doesn't allocate. 
Caller provides a `*ClientConnection`, allocator, and `*ClientMaskPRNG`: + +```zig +var seed: [ws.ClientMaskPRNG.secret_seed_length]u8 = undefined; +std.crypto.random.bytes(&seed); +var csprng = ws.ClientMaskPRNG.init(seed); + +var conn: SimpleClient.Conn = undefined; +var client = SimpleClient.init(allocator, &loop, &handler, &conn, &csprng, .{ + .address = std.net.Address.parseIp4("127.0.0.1", 8080) catch unreachable, + .path = "/", + .max_message_size = 16 * 1024 * 1024, + .close_timeout_ms = 5_000, +}); +try client.connect(); +// After handshake, `conn` is live — client can be discarded +``` + +### Handler Interface + +```zig +// Required +fn onMessage(self: *Handler, conn: *Conn, message: Message) void +fn onWriteComplete(self: *Handler, conn: *Conn) void +fn onClose(self: *Handler, conn: *Conn) void + +// Optional +fn onOpen(self: *Handler, conn: *Conn) void +fn onPing(self: *Handler, conn: *Conn, data: []const u8) void +fn onPong(self: *Handler, conn: *Conn, data: []const u8) void + +// Optional +fn onBytesRead(self: *Handler, conn: *Conn, bytes_read: usize) void + +// Optional (client-only) +fn onSocketClose(self: *Handler) void + +// Optional (server-only) +fn onHandshakeFailed(self: *Handler) void + +// Server-only (required) +pub const Context = void; // or a real type T +fn init(request: http.Request, context: if (Context == void) void else *Context) !Handler +``` + +If `onPing` is not declared, auto-pong replies with latest-wins semantics. If declared, auto-pong is disabled — handler must call `conn.sendPong()`. + +`onBytesRead` fires on every TCP read completion regardless of whether reads are paused. Combine with `peekBufferedBytes()` to observe raw data arrival (e.g. for byte-counting or deciding when to pause). + +Server `init` runs before 101. Return error to reject. `onHandshakeFailed` fires if the handshake fails after `init` succeeds (pool exhaustion, write error, shutdown); use it to clean up `init`-allocated resources. + +### Connection Methods + +```zig +fn sendText(data) !void // server: []const u8 (zero-copy), client: []u8 (zero-copy, masked in-place) +fn sendBinary(data) !void // same as above +fn sendPing(data) !void // copies internally, max 125 bytes +fn sendPong(data) !void // copies internally, max 125 bytes +fn close(code: CloseCode, reason: []const u8) void +fn pauseReads() void // pause frame dispatch; TCP reads continue until buffer full +fn resumeReads() void // resume dispatch; drains buffered frames +fn peekBufferedBytes() []const u8 // raw bytes in read buffer (transient slice) +``` + +## Tests + +Unit tests colocated in source files. E2E tests in `e2e_tests/`. + +```bash +zig build test --summary all +``` + +## Autobahn Testsuite + +Industry-standard WebSocket conformance suite. **Requires Docker.** + +```bash +bash autobahn/server/run.sh # Results: autobahn/server/reports/index.html +bash autobahn/client/run.sh # Results: autobahn/client/reports/index.html +``` + +**Excluded:** 12.x / 13.x (permessage-deflate not implemented) + +## Current Limitations + +- **No custom response headers in the upgrade response (server):** The 101 response is fixed — no way to add `Sec-WebSocket-Protocol` or other headers. +- **No permessage-deflate (compression):** RFC 7692 is not implemented. Adds complexity around buffer ownership for the send API since compressed frames can't be zero-copy in the same way. +- **No DNS resolution (client):** `Config.address` takes a `std.net.Address` (IP only). 
The `Host` header is formatted from this address, but real-world servers typically expect the domain name. +- **No TLS:** Most important for the client — servers can sit behind a TLS terminator. diff --git a/src/rpc/webzockets/autobahn/client/client.zig b/src/rpc/webzockets/autobahn/client/client.zig new file mode 100644 index 0000000000..23c307285e --- /dev/null +++ b/src/rpc/webzockets/autobahn/client/client.zig @@ -0,0 +1,458 @@ +const std = @import("std"); +const xev = @import("xev"); +const ws = @import("webzockets_lib"); + +pub const std_options: std.Options = .{ + .log_level = .info, +}; + +const log = std.log.scoped(.autobahn_client); + +/// Embedded read buffer size per connection. +const read_buf_size: usize = 4096; + +/// Maximum reassembled message size — Autobahn sends up to ~16MB. +const max_message_size: usize = 20 * 1024 * 1024; + +const AutobahnClient = ws.Client(AutobahnClientHandler, read_buf_size); + +/// Execution phase for a single fuzzingserver connection. +const Phase = enum { + get_case_count, + run_case, + update_reports, + update_reports_periodic, +}; + +/// Orchestrates sequential Autobahn case execution and report updates. +const AutobahnRunner = struct { + loop: *xev.Loop, + allocator: std.mem.Allocator, + current_case: usize, + total_cases: usize, + conn: AutobahnClient.Conn, + client: AutobahnClient, + handler: AutobahnClientHandler, + phase: Phase, + path_buf: [256]u8 = undefined, + csprng: ws.ClientMaskPRNG, + retry_count: usize = 0, + retry_timer: xev.Timer = .{}, + retry_timer_completion: xev.Completion = undefined, + + const max_retries = 20; + const retry_delay_ms = 3000; + + fn init(allocator: std.mem.Allocator, loop: *xev.Loop) AutobahnRunner { + var seed: [ws.ClientMaskPRNG.secret_seed_length]u8 = undefined; + std.crypto.random.bytes(&seed); + return .{ + .loop = loop, + .allocator = allocator, + .current_case = 1, + .total_cases = 0, + .conn = undefined, + .client = undefined, + .handler = undefined, + .phase = .get_case_count, + .csprng = ws.ClientMaskPRNG.init(seed), + }; + } + + fn retryTimerCallback( + self_opt: ?*AutobahnRunner, + _: *xev.Loop, + _: *xev.Completion, + r: xev.Timer.RunError!void, + ) xev.CallbackAction { + r catch { + log.err("retry timer failed", .{}); + return .disarm; + }; + const self = self_opt.?; + self.getCaseCount() catch |err| { + log.err("getCaseCount failed: {}", .{err}); + }; + return .disarm; + } + + fn deinit(_: *AutobahnRunner) void {} + + /// First step: connect to /getCaseCount to discover how many cases there are. + fn getCaseCount(self: *AutobahnRunner) !void { + log.debug("getCaseCount: connecting to /getCaseCount", .{}); + self.phase = .get_case_count; + self.handler = .{ .runner = self }; + self.client = AutobahnClient.init( + self.allocator, + self.loop, + &self.handler, + &self.conn, + &self.csprng, + .{ + .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, 9001), + .path = "/getCaseCount", + .max_message_size = max_message_size, + }, + ); + try self.client.connect(); + log.debug("getCaseCount: connect submitted to loop", .{}); + } + + /// Start the next test case, or trigger report generation when done. 
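    /// Called from `onConnectionDone` after the previous socket fully closes;
    /// advances the case counter before connecting.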
+ fn startNextCase(self: *AutobahnRunner) !void { + if (self.current_case > self.total_cases) { + // All cases done — connect to updateReports and finish + log.info("All {d} cases complete, generating report...", .{self.total_cases}); + try self.connectUpdateReports(false); + return; + } + + const case_num = self.current_case; + self.current_case += 1; + + log.info("Running case {d}/{d}", .{ case_num, self.total_cases }); + + // Build path: /runCase?case=N&agent=webzockets + const path = std.fmt.bufPrint( + &self.path_buf, + "/runCase?case={d}&agent=webzockets", + .{case_num}, + ) catch { + log.debug("startNextCase: ERROR — failed to format path for case {d}", .{case_num}); + return; + }; + + self.phase = .run_case; + self.handler = .{ .runner = self }; + self.client = AutobahnClient.init( + self.allocator, + self.loop, + &self.handler, + &self.conn, + &self.csprng, + .{ + .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, 9001), + .path = path, + .max_message_size = max_message_size, + }, + ); + try self.client.connect(); + } + + /// Connect to /updateReports to tell the fuzzingserver to generate HTML. + fn connectUpdateReports(self: *AutobahnRunner, periodic: bool) !void { + log.debug("connectUpdateReports: periodic={}", .{periodic}); + self.phase = if (periodic) .update_reports_periodic else .update_reports; + self.handler = .{ .runner = self }; + self.client = AutobahnClient.init( + self.allocator, + self.loop, + &self.handler, + &self.conn, + &self.csprng, + .{ + .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, 9001), + .path = "/updateReports?agent=webzockets", + .max_message_size = max_message_size, + }, + ); + try self.client.connect(); + } + + /// Called from handler when a connection's socket is fully closed. + /// It is safe to overwrite `self.client` here because xev captures all + /// completion state before invoking callbacks, and no code in the return + /// path accesses the old client after this function returns. + fn onConnectionDone(self: *AutobahnRunner) void { + const was_opened = self.handler.opened; + log.debug("onConnectionDone: phase={s}, opened={}", .{ @tagName(self.phase), was_opened }); + // conn.deinit() releases buffers back to the pool. + if (was_opened) { + self.conn.deinit(); + } + + // Retry getCaseCount if it failed (server may not be fully ready yet). + if (self.phase == .get_case_count and self.total_cases == 0) { + self.retry_count += 1; + if (self.retry_count > max_retries) { + log.err("getCaseCount failed after {d} retries, giving up.", .{max_retries}); + return; + } + if (!was_opened) { + log.warn("WebSocket handshake to /getCaseCount failed, " ++ + "retrying in {d}s ({d}/{d})...", .{ + retry_delay_ms / 1000, + self.retry_count, + max_retries, + }); + } else { + log.warn("Got 0 cases from fuzzingserver, " ++ + "retrying in {d}s ({d}/{d})...", .{ + retry_delay_ms / 1000, + self.retry_count, + max_retries, + }); + } + // Schedule retry via xev timer + self.retry_timer = xev.Timer.init() catch { + log.err("failed to create retry timer", .{}); + return; + }; + self.retry_timer_completion = .{}; + self.retry_timer.run( + self.loop, + &self.retry_timer_completion, + retry_delay_ms, + AutobahnRunner, + self, + retryTimerCallback, + ); + return; + } + + if (self.phase == .get_case_count) { + log.info("Fuzzingserver reports {d} test cases", .{self.total_cases}); + } + + // Proceed to next step. 
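        // .run_case interleaves a periodic report update every 10 completed
        // cases so partial results are persisted; the final .update_reports
        // connection ends the run.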
+ const next_res = switch (self.phase) { + .get_case_count => self.startNextCase(), + .run_case => if (self.current_case > 1 and (self.current_case - 1) % 10 == 0) + self.connectUpdateReports(true) + else + self.startNextCase(), + .update_reports_periodic => self.startNextCase(), + .update_reports => { + log.info("Report generation complete.", .{}); + return; + }, + }; + + next_res catch |err| { + log.err("failed to start next step: {}", .{err}); + }; + } +}; + +/// Echo handler used for individual Autobahn test-case connections. +const AutobahnClientHandler = struct { + const PendingMessage = struct { + data: []u8, + is_text: bool, + next: ?*PendingMessage = null, + }; + + runner: *AutobahnRunner, + opened: bool = false, + + /// Message currently in-flight; freed in onWriteComplete/onClose. + sent_data: ?[]u8 = null, + /// Pending outbound messages while a write is in flight. + queue_head: ?*PendingMessage = null, + /// Tail pointer for O(1) queue append. + queue_tail: ?*PendingMessage = null, + + pub fn onOpen(self: *AutobahnClientHandler, conn: *AutobahnClient.Conn) void { + self.opened = true; + log.debug("handler.onOpen: phase={s}", .{@tagName(self.runner.phase)}); + if (self.runner.phase == .update_reports or self.runner.phase == .update_reports_periodic) { + // For updateReports, just close immediately after connection opens + log.debug("handler.onOpen: updateReports — closing immediately", .{}); + conn.close(.normal, ""); + } + } + + /// Explicitly handle pings so every ping gets its own pong response. + /// Without this, the library's auto-pong uses "latest wins" semantics, + /// which is spec-compliant but fails Autobahn test 2.10 (expects a + /// pong for each of 10 rapidly sent pings). + pub fn onPing(_: *AutobahnClientHandler, conn: *AutobahnClient.Conn, data: []const u8) void { + conn.sendPong(data) catch |err| { + log.err("sendPong failed: {}", .{err}); + }; + } + + pub fn onMessage( + self: *AutobahnClientHandler, + conn: *AutobahnClient.Conn, + message: ws.Message, + ) void { + log.debug("handler.onMessage: phase={s}, type={s}, len={d}", .{ + @tagName(self.runner.phase), + @tagName(message.type), + message.data.len, + }); + + switch (self.runner.phase) { + .get_case_count => { + // Server sends case count as a text message + if (message.type == .text) { + const trimmed = std.mem.trim(u8, message.data, &[_]u8{ ' ', '\t', '\r', '\n' }); + log.debug("handler.onMessage: getCaseCount body=\"{s}\"", .{trimmed}); + self.runner.total_cases = std.fmt.parseInt(usize, trimmed, 10) catch 0; + log.debug("handler.onMessage: parsed total_cases={d}", .{ + self.runner.total_cases, + }); + } + }, + .run_case => { + switch (message.type) { + .text => { + if (!std.unicode.utf8ValidateSlice(message.data)) { + conn.close(.invalid_payload, "Invalid UTF-8"); + return; + } + self.enqueue(conn, message.data, true); + }, + .binary => self.enqueue(conn, message.data, false), + else => {}, + } + }, + else => {}, + } + } + + fn enqueue( + self: *AutobahnClientHandler, + conn: *AutobahnClient.Conn, + data: []const u8, + is_text: bool, + ) void { + const allocator = self.runner.allocator; + const copy = allocator.dupe(u8, data) catch return; + const msg = allocator.create(PendingMessage) catch { + allocator.free(copy); + return; + }; + msg.* = .{ + .data = copy, + .is_text = is_text, + }; + // Append to tail + if (self.queue_tail) |tail| { + tail.next = msg; + } else { + self.queue_head = msg; + } + self.queue_tail = msg; + self.drainQueue(conn); + } + + fn drainQueue(self: *AutobahnClientHandler, conn: 
*AutobahnClient.Conn) void { + while (self.queue_head) |msg| { + if (self.sent_data != null) return; // write in flight + // Pop from head + self.queue_head = msg.next; + if (self.queue_head == null) self.queue_tail = null; + + if (msg.is_text) { + conn.sendText(msg.data) catch { + self.runner.allocator.free(msg.data); + self.runner.allocator.destroy(msg); + continue; + }; + } else { + conn.sendBinary(msg.data) catch { + self.runner.allocator.free(msg.data); + self.runner.allocator.destroy(msg); + continue; + }; + } + self.sent_data = msg.data; + self.runner.allocator.destroy(msg); + return; + } + } + + pub fn onWriteComplete(self: *AutobahnClientHandler, conn: *AutobahnClient.Conn) void { + log.debug("handler.onWriteComplete: phase={s}", .{@tagName(self.runner.phase)}); + if (self.sent_data) |data| { + self.runner.allocator.free(data); + self.sent_data = null; + } + self.drainQueue(conn); + } + + pub fn onClose(self: *AutobahnClientHandler, _: *AutobahnClient.Conn) void { + log.debug("handler.onClose: phase={s}", .{@tagName(self.runner.phase)}); + const allocator = self.runner.allocator; + if (self.sent_data) |data| { + allocator.free(data); + self.sent_data = null; + } + while (self.queue_head) |msg| { + self.queue_head = msg.next; + allocator.free(msg.data); + allocator.destroy(msg); + } + self.queue_tail = null; + } + + pub fn onSocketClose(self: *AutobahnClientHandler) void { + log.debug("handler.onSocketClose: phase={s}", .{@tagName(self.runner.phase)}); + self.runner.onConnectionDone(); + } +}; + +fn run(allocator: std.mem.Allocator) !void { + // Wait for Docker fuzzingserver to start + log.info("Waiting for fuzzingserver on port 9001...", .{}); + const max_retries = 60; + var attempt: usize = 0; + while (attempt < max_retries) : (attempt += 1) { + // Try connecting to see if the server is up + const stream = std.net.tcpConnectToHost(allocator, "127.0.0.1", 9001) catch { + log.warn("attempt {d}/{d} — not ready, retrying in 3s...", .{ + attempt + 1, + max_retries, + }); + std.time.sleep(3 * std.time.ns_per_s); + continue; + }; + stream.close(); + break; + } + if (attempt == max_retries) { + log.err("fuzzingserver did not become ready after {d} attempts.", .{max_retries}); + return error.ServerNotReady; + } + log.info("Fuzzingserver is up.", .{}); + + // Init event loop + var thread_pool = xev.ThreadPool.init(.{}); + defer thread_pool.deinit(); + defer thread_pool.shutdown(); + + var loop = try xev.Loop.init(.{ .thread_pool = &thread_pool }); + defer loop.deinit(); + + // Init runner — first connection will be getCaseCount + var runner = AutobahnRunner.init(allocator, &loop); + defer runner.deinit(); + try runner.getCaseCount(); + + log.debug("main: entering loop.run(.until_done)", .{}); + + // Run until all cases are done + try loop.run(.until_done); + + log.info("Autobahn client test run complete.", .{}); + log.info("Check autobahn/client/reports/index.html for results.", .{}); +} + +pub fn main() !void { + var gpa = std.heap.GeneralPurposeAllocator(.{}){}; + const allocator = gpa.allocator(); + + run(allocator) catch |err| { + if (gpa.deinit() == .leak) { + log.err("GPA detected memory leaks while exiting with error: {}", .{err}); + } + return err; + }; + + if (gpa.deinit() == .leak) { + return error.MemoryLeakDetected; + } +} diff --git a/src/rpc/webzockets/autobahn/client/config.json b/src/rpc/webzockets/autobahn/client/config.json new file mode 100644 index 0000000000..981b6d8aef --- /dev/null +++ b/src/rpc/webzockets/autobahn/client/config.json @@ -0,0 +1,7 @@ +{ + "url": 
"ws://127.0.0.1:9001", + "outdir": "/ab/reports", + "cases": ["*"], + "exclude-cases": ["12.*", "13.*"], + "exclude-agent-cases": {} +} diff --git a/src/rpc/webzockets/autobahn/client/run.sh b/src/rpc/webzockets/autobahn/client/run.sh new file mode 100755 index 0000000000..25cf28f236 --- /dev/null +++ b/src/rpc/webzockets/autobahn/client/run.sh @@ -0,0 +1,41 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_DIR="$(cd "$SCRIPT_DIR/../.." && pwd)" + +# Cleanup any previous run +docker stop fuzzingserver 2>/dev/null || true +trap "docker stop fuzzingserver 2>/dev/null || true" EXIT + +# Start fuzzingserver in background +echo "Starting Autobahn fuzzingserver on port 9001..." +bash "$SCRIPT_DIR/start_server.sh" & + +# Build +cd "$PROJECT_DIR" +echo "Building Autobahn client runner..." +zig build -Doptimize=ReleaseSafe + +# Run client (it has its own retry loop for server startup) +echo "Running Autobahn client..." +"$PROJECT_DIR/zig-out/bin/autobahn-client" + +echo "" +echo "Autobahn client testsuite complete." + +# Give fuzzingserver time to flush report files to disk +sleep 3 + +echo "Open $SCRIPT_DIR/reports/index.html to view results." + +# Check for failures +if [ -f "$SCRIPT_DIR/reports/index.json" ]; then + if grep -q "FAILED" "$SCRIPT_DIR/reports/index.json"; then + echo "" + echo "WARNING: Some tests FAILED! Check the report for details." + exit 1 + else + echo "All tests passed!" + fi +fi diff --git a/src/rpc/webzockets/autobahn/client/start_server.sh b/src/rpc/webzockets/autobahn/client/start_server.sh new file mode 100755 index 0000000000..dca0f8b7cc --- /dev/null +++ b/src/rpc/webzockets/autobahn/client/start_server.sh @@ -0,0 +1,10 @@ +#!/bin/bash +# Start the Autobahn fuzzingserver in Docker (foreground, Ctrl-C to stop) +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" + +docker run --rm \ + -v "$SCRIPT_DIR:/ab" \ + -p 9001:9001 \ + --name fuzzingserver \ + crossbario/autobahn-testsuite \ + /opt/pypy/bin/wstest --mode fuzzingserver --spec /ab/config.json diff --git a/src/rpc/webzockets/autobahn/server/config.json b/src/rpc/webzockets/autobahn/server/config.json new file mode 100644 index 0000000000..527653ffd1 --- /dev/null +++ b/src/rpc/webzockets/autobahn/server/config.json @@ -0,0 +1,9 @@ +{ + "outdir": "/ab/reports/", + "options": { "failByDrop": false }, + "servers": [ + { "agent": "webzockets", "url": "ws://host.docker.internal:9001" } + ], + "cases": ["*"], + "exclude-cases": ["12.*", "13.*"] +} diff --git a/src/rpc/webzockets/autobahn/server/run.sh b/src/rpc/webzockets/autobahn/server/run.sh new file mode 100755 index 0000000000..6796dd27e1 --- /dev/null +++ b/src/rpc/webzockets/autobahn/server/run.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_DIR="$(cd "$SCRIPT_DIR/../.." && pwd)" + +# Build +cd "$PROJECT_DIR" +echo "Building echo server..." +zig build -Doptimize=ReleaseSafe + +# Start server in background +SERVER_LOG="$SCRIPT_DIR/server.log" +echo "Starting echo server on port 9001 (logging to $SERVER_LOG)..." +"$PROJECT_DIR/zig-out/bin/autobahn-server" >"$SERVER_LOG" 2>&1 & +SERVER_PID=$! +trap "kill $SERVER_PID 2>/dev/null || true" EXIT + +sleep 2 +if ! kill -0 "$SERVER_PID" 2>/dev/null; then + echo "ERROR: Echo server failed to start" + exit 1 +fi +echo "Server running (PID $SERVER_PID)" + +# Run Autobahn fuzzingclient +echo "Starting Autobahn fuzzingclient..." 
+docker run --rm \ + -v "$SCRIPT_DIR:/ab" \ + --name fuzzingclient \ + crossbario/autobahn-testsuite \ + /opt/pypy/bin/wstest --mode fuzzingclient --spec /ab/config.json + +echo "" +echo "Autobahn server testsuite complete." +echo "Open $SCRIPT_DIR/reports/index.html to view results." diff --git a/src/rpc/webzockets/autobahn/server/server.zig b/src/rpc/webzockets/autobahn/server/server.zig new file mode 100644 index 0000000000..f240c02dd4 --- /dev/null +++ b/src/rpc/webzockets/autobahn/server/server.zig @@ -0,0 +1,169 @@ +const std = @import("std"); +const xev = @import("xev"); +const ws = @import("webzockets_lib"); + +pub const std_options: std.Options = .{ + .log_level = .info, +}; + +const log = std.log.scoped(.autobahn_server); + +/// Embedded read buffer size per connection. +const read_buf_size: usize = 4096; + +/// Maximum reassembled message size — Autobahn sends up to ~16MB. +const max_message_size: usize = 20 * 1024 * 1024; + +const AutobahnServer = ws.Server(AutobahnHandler, read_buf_size); + +/// Echo handler for the Autobahn testsuite. +/// +/// Echoes text and binary messages back verbatim. Validates UTF-8 on text +/// messages and closes with code 1007 (Invalid Payload Data) on failure, +/// which is required to pass Autobahn section 6.x tests. We do it here rather +/// than in the library as it's not something you necessarily want in all servers. +/// +/// Uses an intrusive singly-linked list for O(1) queue operations. +const AutobahnHandler = struct { + pub const Context = void; + + const PendingMessage = struct { + data: []const u8, + is_text: bool, + next: ?*PendingMessage = null, + }; + + /// Data currently being written (freed in onWriteComplete). + sent_data: ?[]const u8 = null, + /// Head of pending message queue. + queue_head: ?*PendingMessage = null, + /// Tail of pending message queue (for O(1) append). + queue_tail: ?*PendingMessage = null, + + pub fn init(_: ws.http.Request, _: void) !AutobahnHandler { + return .{}; + } + + /// Explicitly handle pings so every ping gets its own pong response. + /// Without this, the library's auto-pong uses "latest wins" semantics, + /// which is spec-compliant but fails Autobahn test 2.10 (expects a + /// pong for each of 10 rapidly sent pings). 
+ pub fn onPing(_: *AutobahnHandler, conn: *AutobahnServer.Conn, data: []const u8) void { + conn.sendPong(data) catch |err| { + log.err("sendPong failed: {}", .{err}); + }; + } + + pub fn onMessage(self: *AutobahnHandler, conn: *AutobahnServer.Conn, message: ws.Message) void { + switch (message.type) { + .text => { + if (!std.unicode.utf8ValidateSlice(message.data)) { + conn.close(.invalid_payload, "Invalid UTF-8"); + return; + } + self.enqueue(conn, message.data, true); + }, + .binary => self.enqueue(conn, message.data, false), + else => {}, + } + } + + fn enqueue( + self: *AutobahnHandler, + conn: *AutobahnServer.Conn, + data: []const u8, + is_text: bool, + ) void { + const copy = conn.allocator.dupe(u8, data) catch return; + const msg = conn.allocator.create(PendingMessage) catch { + conn.allocator.free(copy); + return; + }; + msg.* = .{ + .data = copy, + .is_text = is_text, + }; + // Append to tail + if (self.queue_tail) |tail| { + tail.next = msg; + } else { + self.queue_head = msg; + } + self.queue_tail = msg; + self.drainQueue(conn); + } + + fn drainQueue(self: *AutobahnHandler, conn: *AutobahnServer.Conn) void { + while (self.queue_head) |msg| { + if (self.sent_data != null) return; // write in flight + // Pop from head + self.queue_head = msg.next; + if (self.queue_head == null) self.queue_tail = null; + + if (msg.is_text) { + conn.sendText(msg.data) catch { + conn.allocator.free(msg.data); + conn.allocator.destroy(msg); + continue; + }; + } else { + conn.sendBinary(msg.data) catch { + conn.allocator.free(msg.data); + conn.allocator.destroy(msg); + continue; + }; + } + self.sent_data = msg.data; + conn.allocator.destroy(msg); + return; + } + } + + pub fn onWriteComplete(self: *AutobahnHandler, conn: *AutobahnServer.Conn) void { + if (self.sent_data) |data| { + conn.allocator.free(data); + self.sent_data = null; + } + self.drainQueue(conn); + } + + pub fn onClose(self: *AutobahnHandler, conn: *AutobahnServer.Conn) void { + if (self.sent_data) |data| { + conn.allocator.free(data); + } + while (self.queue_head) |msg| { + self.queue_head = msg.next; + conn.allocator.free(msg.data); + conn.allocator.destroy(msg); + } + self.queue_tail = null; + } +}; + +pub fn main() !void { + const address = std.net.Address.initIp4(.{ 0, 0, 0, 0 }, 9001); + + var thread_pool = xev.ThreadPool.init(.{}); + defer thread_pool.deinit(); + defer thread_pool.shutdown(); + + var loop = try xev.Loop.init(.{ .thread_pool = &thread_pool }); + defer loop.deinit(); + + var server = try AutobahnServer.init( + std.heap.c_allocator, + &loop, + .{ + .address = address, + .handler_context = {}, + .max_message_size = max_message_size, + }, + ); + defer server.deinit(); + + server.accept(); + + log.info("Autobahn echo server listening on ws://0.0.0.0:9001", .{}); + + try loop.run(.until_done); +} diff --git a/src/rpc/webzockets/autobahn/summarize_reports.py b/src/rpc/webzockets/autobahn/summarize_reports.py new file mode 100755 index 0000000000..d9ed7193ca --- /dev/null +++ b/src/rpc/webzockets/autobahn/summarize_reports.py @@ -0,0 +1,231 @@ +#!/usr/bin/env python3 +""" +Summarize Autobahn WebSocket test suite reports. + +Usage: + python3 summarize_reports.py client [--verbose] [--json] + python3 summarize_reports.py server [--verbose] [--json] + +Reads the index.json produced by the Autobahn test suite and prints a +human-readable summary of pass/fail/non-strict/informational results +grouped by test category. 
+""" + +import argparse +import json +import os +import sys +from collections import Counter, defaultdict +from pathlib import Path + +# Autobahn test category descriptions +CATEGORY_NAMES = { + "1": "Framing", + "2": "Pings/Pongs", + "3": "Reserved Bits", + "4": "Opcodes", + "5": "Fragmentation", + "6": "UTF-8 Handling", + "7": "Close Handling", + "9": "Limits/Performance", + "10": "Misc", + "12": "WebSocket Compression (different payloads)", + "13": "WebSocket Compression (different parameters)", +} + +# Behavior result ordering (best to worst) +BEHAVIOR_ORDER = ["OK", "NON-STRICT", "INFORMATIONAL", "UNIMPLEMENTED", "FAILED"] + + +def load_index(reports_dir: Path) -> dict: + index_path = reports_dir / "index.json" + if not index_path.exists(): + print(f"Error: {index_path} not found", file=sys.stderr) + sys.exit(1) + with open(index_path) as f: + data = json.load(f) + # The index has agent names as top-level keys; pick the first one + agents = list(data.keys()) + if len(agents) == 0: + print("Error: empty index.json", file=sys.stderr) + sys.exit(1) + agent_name = agents[0] + return agent_name, data[agent_name] + + +def case_sort_key(case_id: str): + return [int(p) for p in case_id.split(".")] + + +def summarize(cases: dict, verbose: bool = False): + total = len(cases) + behaviors = Counter(c["behavior"] for c in cases.values()) + close_behaviors = Counter(c["behaviorClose"] for c in cases.values()) + + # Group by category + categories = defaultdict(lambda: defaultdict(int)) + for case_id, c in cases.items(): + cat = case_id.split(".")[0] + categories[cat]["total"] += 1 + categories[cat][c["behavior"]] += 1 + + # Collect non-OK cases + non_ok = [] + for case_id, c in sorted(cases.items(), key=lambda x: case_sort_key(x[0])): + if c["behavior"] != "OK": + non_ok.append((case_id, c)) + + return { + "total": total, + "behaviors": dict(behaviors), + "close_behaviors": dict(close_behaviors), + "categories": dict(categories), + "non_ok": non_ok, + } + + +def print_summary(agent_name: str, summary: dict, verbose: bool = False): + total = summary["total"] + behaviors = summary["behaviors"] + non_ok = summary["non_ok"] + categories = summary["categories"] + + ok_count = behaviors.get("OK", 0) + failed_count = behaviors.get("FAILED", 0) + non_strict_count = behaviors.get("NON-STRICT", 0) + info_count = behaviors.get("INFORMATIONAL", 0) + + # Header + print("=" * 70) + print(f" Autobahn WebSocket Test Report — {agent_name}") + print("=" * 70) + print() + + # Overall pass rate + strict_pass = ok_count + pass_rate = (strict_pass / total * 100) if total else 0 + loose_pass = ok_count + non_strict_count + info_count + loose_rate = (loose_pass / total * 100) if total else 0 + + print(f" Total cases: {total}") + print(f" Passed (OK): {ok_count}") + print(f" Non-strict: {non_strict_count}") + print(f" Informational: {info_count}") + print(f" Failed: {failed_count}") + print() + print(f" Strict pass rate: {pass_rate:.1f}% ({strict_pass}/{total})") + print(f" Effective rate: {loose_rate:.1f}% ({loose_pass}/{total} — OK + NON-STRICT + INFO)") + print() + + # Per-category breakdown + print("-" * 70) + print(f" {'Category':<35} {'OK':>5} {'N-S':>5} {'INFO':>5} {'FAIL':>5} {'Total':>6}") + print("-" * 70) + + for cat in sorted(categories.keys(), key=int): + c = categories[cat] + name = CATEGORY_NAMES.get(cat, "Unknown") + label = f"{cat}. 
{name}" + cat_ok = c.get("OK", 0) + cat_ns = c.get("NON-STRICT", 0) + cat_info = c.get("INFORMATIONAL", 0) + cat_fail = c.get("FAILED", 0) + cat_total = c["total"] + marker = "" + if cat_fail > 0: + marker = " ✗" + elif cat_ns > 0: + marker = " ~" + print(f" {label:<35} {cat_ok:>5} {cat_ns:>5} {cat_info:>5} {cat_fail:>5} {cat_total:>6}{marker}") + + print("-" * 70) + print() + + # Non-OK details + if non_ok: + print("Non-OK test cases:") + print() + for case_id, c in non_ok: + status = c["behavior"] + close = c["behaviorClose"] + duration = c["duration"] + icon = {"NON-STRICT": "~", "INFORMATIONAL": "ℹ", "FAILED": "✗"}.get(status, "?") + print(f" [{icon}] Case {case_id}: {status} (close: {close}, {duration}ms)") + + if verbose: + print() + print("Detailed descriptions of non-OK cases:") + print() + reports_dir = None # Will be set by caller if needed + else: + print("All test cases passed with OK! 🎉") + + print() + + +def print_verbose_details(reports_dir: Path, non_ok: list): + """Print detailed descriptions for non-OK cases by reading individual report files.""" + for case_id, c in non_ok: + report_file = reports_dir / c.get("reportfile", "") + if report_file.exists(): + with open(report_file) as f: + detail = json.load(f) + desc = detail.get("description", "N/A") + # Strip HTML tags for readability + import re + desc = re.sub(r"<[^>]+>", " ", desc).strip() + desc = re.sub(r"\s+", " ", desc) + expectation = detail.get("expectation", "N/A") + expectation = re.sub(r"<[^>]+>", " ", expectation).strip() + expectation = re.sub(r"\s+", " ", expectation) + print(f" Case {case_id}:") + print(f" Description: {desc}") + print(f" Expectation: {expectation}") + print() + + +def print_json_summary(agent_name: str, summary: dict): + output = { + "agent": agent_name, + "total": summary["total"], + "behaviors": summary["behaviors"], + "close_behaviors": summary["close_behaviors"], + "categories": {}, + "non_ok": [], + } + for cat in sorted(summary["categories"].keys(), key=int): + c = summary["categories"][cat] + name = CATEGORY_NAMES.get(cat, "Unknown") + output["categories"][cat] = {"name": name, **dict(c)} + for case_id, c in summary["non_ok"]: + output["non_ok"].append({"case": case_id, **c}) + print(json.dumps(output, indent=2)) + + +def main(): + parser = argparse.ArgumentParser(description="Summarize Autobahn WebSocket test reports") + parser.add_argument( + "mode", + choices=["client", "server"], + help="Which report to summarize: client or server", + ) + parser.add_argument("--verbose", "-v", action="store_true", help="Show detailed descriptions of non-OK cases") + parser.add_argument("--json", "-j", action="store_true", help="Output as JSON instead of human-readable text") + args = parser.parse_args() + + script_dir = Path(__file__).parent + reports_dir = script_dir / args.mode / "reports" + + agent_name, cases = load_index(reports_dir) + summary = summarize(cases, verbose=args.verbose) + + if args.json: + print_json_summary(agent_name, summary) + else: + print_summary(agent_name, summary, verbose=args.verbose) + if args.verbose and summary["non_ok"]: + print_verbose_details(reports_dir, summary["non_ok"]) + + +if __name__ == "__main__": + main() diff --git a/src/rpc/webzockets/build.zig b/src/rpc/webzockets/build.zig new file mode 100644 index 0000000000..9692a6493f --- /dev/null +++ b/src/rpc/webzockets/build.zig @@ -0,0 +1,163 @@ +const std = @import("std"); + +pub fn build(b: *std.Build) void { + const target = b.standardTargetOptions(.{}); + + const optimize = 
b.standardOptimizeOption(.{}); + + const xev_dep = b.dependency("libxev", .{ .target = target, .optimize = optimize }); + + const lib_mod = b.createModule(.{ + .root_source_file = b.path("src/root.zig"), + .target = target, + .optimize = optimize, + .imports = &.{ + .{ .name = "xev", .module = xev_dep.module("xev") }, + }, + }); + + const lib = b.addLibrary(.{ + .linkage = .static, + .name = "webzockets", + .root_module = lib_mod, + }); + + _ = b.addModule("webzockets", .{ + .root_source_file = b.path("src/root.zig"), + .imports = &.{ + .{ .name = "xev", .module = xev_dep.module("xev") }, + }, + }); + + b.installArtifact(lib); + + const echo_server_mod = b.createModule(.{ + .root_source_file = b.path("examples/echo_server.zig"), + .target = target, + .optimize = optimize, + .imports = &.{ + .{ .name = "xev", .module = xev_dep.module("xev") }, + }, + }); + echo_server_mod.addImport("webzockets_lib", lib_mod); + + const echo_server_exe = b.addExecutable(.{ + .name = "echo-server", + .root_module = echo_server_mod, + }); + b.installArtifact(echo_server_exe); + + const run_echo_server_cmd = b.addRunArtifact(echo_server_exe); + run_echo_server_cmd.step.dependOn(b.getInstallStep()); + if (b.args) |args| { + run_echo_server_cmd.addArgs(args); + } + const run_echo_server_step = b.step("run-echo-server", "Run the example echo server"); + run_echo_server_step.dependOn(&run_echo_server_cmd.step); + + const echo_client_mod = b.createModule(.{ + .root_source_file = b.path("examples/simple_client.zig"), + .target = target, + .optimize = optimize, + .imports = &.{ + .{ .name = "xev", .module = xev_dep.module("xev") }, + }, + }); + echo_client_mod.addImport("webzockets_lib", lib_mod); + + const echo_client_exe = b.addExecutable(.{ + .name = "simple-client", + .root_module = echo_client_mod, + }); + b.installArtifact(echo_client_exe); + + const run_echo_client_cmd = b.addRunArtifact(echo_client_exe); + run_echo_client_cmd.step.dependOn(b.getInstallStep()); + if (b.args) |args| { + run_echo_client_cmd.addArgs(args); + } + const run_echo_client_step = b.step("run-simple-client", "Run the example simple client"); + run_echo_client_step.dependOn(&run_echo_client_cmd.step); + + // Autobahn testsuite echo server + const autobahn_mod = b.createModule(.{ + .root_source_file = b.path("autobahn/server/server.zig"), + .target = target, + .optimize = optimize, + .imports = &.{ + .{ .name = "xev", .module = xev_dep.module("xev") }, + }, + }); + autobahn_mod.addImport("webzockets_lib", lib_mod); + + const autobahn_exe = b.addExecutable(.{ + .name = "autobahn-server", + .root_module = autobahn_mod, + }); + + b.installArtifact(autobahn_exe); + + const run_autobahn_cmd = b.addRunArtifact(autobahn_exe); + run_autobahn_cmd.step.dependOn(b.getInstallStep()); + if (b.args) |args| { + run_autobahn_cmd.addArgs(args); + } + + const run_autobahn_step = b.step("run-autobahn", "Run the Autobahn testsuite echo server"); + run_autobahn_step.dependOn(&run_autobahn_cmd.step); + + // Autobahn testsuite client runner + const autobahn_client_mod = b.createModule(.{ + .root_source_file = b.path("autobahn/client/client.zig"), + .target = target, + .optimize = optimize, + .imports = &.{ + .{ .name = "xev", .module = xev_dep.module("xev") }, + }, + }); + autobahn_client_mod.addImport("webzockets_lib", lib_mod); + + const autobahn_client_exe = b.addExecutable(.{ + .name = "autobahn-client", + .root_module = autobahn_client_mod, + }); + + b.installArtifact(autobahn_client_exe); + + const run_autobahn_client_cmd = 
b.addRunArtifact(autobahn_client_exe); + run_autobahn_client_cmd.step.dependOn(b.getInstallStep()); + if (b.args) |args| { + run_autobahn_client_cmd.addArgs(args); + } + + const run_autobahn_client_step = b.step( + "run-autobahn-client", + "Run the Autobahn testsuite client runner", + ); + run_autobahn_client_step.dependOn(&run_autobahn_client_cmd.step); + + const lib_unit_tests = b.addTest(.{ + .root_module = lib_mod, + }); + + const run_lib_unit_tests = b.addRunArtifact(lib_unit_tests); + + const e2e_mod = b.createModule(.{ + .root_source_file = b.path("e2e_tests/tests.zig"), + .target = target, + .optimize = optimize, + .imports = &.{ + .{ .name = "xev", .module = xev_dep.module("xev") }, + }, + }); + e2e_mod.addImport("webzockets_lib", lib_mod); + + const e2e_tests = b.addTest(.{ + .root_module = e2e_mod, + }); + const run_e2e_tests = b.addRunArtifact(e2e_tests); + + const test_step = b.step("test", "Run tests"); + test_step.dependOn(&run_lib_unit_tests.step); + test_step.dependOn(&run_e2e_tests.step); +} diff --git a/src/rpc/webzockets/build.zig.zon b/src/rpc/webzockets/build.zig.zon new file mode 100644 index 0000000000..ee6ce9613c --- /dev/null +++ b/src/rpc/webzockets/build.zig.zon @@ -0,0 +1,25 @@ +.{ + .name = .webzockets, + + .version = "0.0.0", + + .fingerprint = 0x3c694f7003ef0582, // Changing this has security and trust implications. + + .minimum_zig_version = "0.14.1", + + .dependencies = .{ + .libxev = .{ + .url = "git+https://github.com/Syndica/libxev#c16162f0f9047f27d6230c548e1e25070b9be163", + .hash = "libxev-0.0.0-86vtcyQcEwAxRuv5vgdyq8LCRKHNGLgtKfS4KNHdic2V", + }, + }, + .paths = .{ + "build.zig", + "build.zig.zon", + "src", + "examples", + "e2e_tests", + "autobahn", + "README.md", + }, +} diff --git a/src/rpc/webzockets/e2e_tests/client/close_tests.zig b/src/rpc/webzockets/e2e_tests/client/close_tests.zig new file mode 100644 index 0000000000..d716b6ca0d --- /dev/null +++ b/src/rpc/webzockets/e2e_tests/client/close_tests.zig @@ -0,0 +1,34 @@ +const std = @import("std"); +const testing = std.testing; + +const servers = @import("../support/test_servers.zig"); +const clients = @import("../support/test_clients.zig"); +const FdLeakDetector = @import("../support/fd_leak.zig").FdLeakDetector; + +test "e2e client: server-initiated close" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + // Start a server that sends close immediately on open + const ts = try servers.startCloseOnOpenServer(testing.allocator); + defer ts.stop(); + + var handler: clients.ServerCloseHandler = .{}; + + var env: clients.TestEnv = undefined; + try env.start(); + defer env.deinit(); + + var conn: clients.TestServerCloseClient.Conn = undefined; + var client = env.initClient(clients.TestServerCloseClient, &handler, &conn, .{ + .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, ts.port), + }); + try client.connect(); + try env.loop.run(.until_done); + + // Client should have connected successfully + try testing.expect(handler.open_called); + // Client should have received the server's close and cleaned up + try testing.expect(handler.close_called); + conn.deinit(); +} diff --git a/src/rpc/webzockets/e2e_tests/client/connection_tests.zig b/src/rpc/webzockets/e2e_tests/client/connection_tests.zig new file mode 100644 index 0000000000..8c82913dde --- /dev/null +++ b/src/rpc/webzockets/e2e_tests/client/connection_tests.zig @@ -0,0 +1,285 @@ +const std = @import("std"); +const testing = std.testing; + +const http = @import("webzockets_lib").http; +const servers = 
@import("../support/test_servers.zig"); +const clients = @import("../support/test_clients.zig"); +const FdLeakDetector = @import("../support/fd_leak.zig").FdLeakDetector; + +test "e2e client: connection to non-existent server" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + // Use an ephemeral port that nothing is listening on. + // Bind a socket to get an OS-assigned port, then close it immediately. + const addr = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, 0); + const sock = try std.posix.socket(std.posix.AF.INET, std.posix.SOCK.STREAM, 0); + try std.posix.bind(sock, @ptrCast(&addr.any), addr.getOsSockLen()); + var bound_addr: std.posix.sockaddr.storage = undefined; + var addr_len: std.posix.socklen_t = @sizeOf(@TypeOf(bound_addr)); + try std.posix.getsockname(sock, @ptrCast(&bound_addr), &addr_len); + const sa4: *const std.posix.sockaddr.in = @ptrCast(@alignCast(&bound_addr)); + const unused_port = std.mem.bigToNative(u16, sa4.port); + std.posix.close(sock); + + var handler: clients.ConnectFailHandler = .{}; + + var env: clients.TestEnv = undefined; + try env.start(); + defer env.deinit(); + + var conn: clients.TestConnectFailClient.Conn = undefined; + var client = env.initClient(clients.TestConnectFailClient, &handler, &conn, .{ + .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, unused_port), + }); + + try client.connect(); + try env.loop.run(.until_done); + + // Connection should have failed — onOpen should NOT have been called + try testing.expect(!handler.open_called); + // onSocketClose should have been called + try testing.expect(handler.socket_close_called); +} + +test "e2e client: connection refused by server (handler rejects)" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startRejectServer(testing.allocator); + defer ts.stop(); + + var handler: clients.ConnectFailHandler = .{}; + + var env: clients.TestEnv = undefined; + try env.start(); + defer env.deinit(); + + var conn: clients.TestConnectFailClient.Conn = undefined; + var client = env.initClient(clients.TestConnectFailClient, &handler, &conn, .{ + .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, ts.port), + }); + + try client.connect(); + try env.loop.run(.until_done); + + // Server rejected the upgrade — onOpen should NOT have been called + try testing.expect(!handler.open_called); + // onSocketClose should have been called (handshake failure) + try testing.expect(handler.socket_close_called); +} + +test "e2e client: 10 concurrent clients to same server" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + const num_clients = 10; + + var handlers: [num_clients]clients.EchoTestHandler = undefined; + var conns: [num_clients]clients.TestEchoClient.Conn = undefined; + var client_instances: [num_clients]clients.TestEchoClient = undefined; + + // Initialize all handlers with unique messages + var msg_bufs: [num_clients][16]u8 = undefined; + for (0..num_clients) |i| { + const msg = std.fmt.bufPrint(&msg_bufs[i], "client_{d}", .{i}) catch unreachable; + handlers[i] = .{ + .send_kind = .text, + .send_data = msg, + .allocator = testing.allocator, + }; + } + defer for (&handlers) |*h| h.deinit(); + + var env: clients.TestEnv = undefined; + try env.start(); + defer env.deinit(); + + // Create and connect all clients + for (0..num_clients) |i| { + client_instances[i] = env.initClient(clients.TestEchoClient, 
&handlers[i], &conns[i], .{ + .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, ts.port), + }); + try client_instances[i].connect(); + } + + try env.loop.run(.until_done); + + // Verify all clients connected and received their echoed messages + for (0..num_clients) |i| { + defer if (handlers[i].open_called) conns[i].deinit(); + try testing.expect(handlers[i].open_called); + const expected = std.fmt.bufPrint(&msg_bufs[i], "client_{d}", .{i}) catch unreachable; + const received_data = handlers[i].received_data orelse return error.NoData; + try testing.expectEqualSlices(u8, expected, received_data); + } +} + +test "e2e client: bare LF response doesn't crash client" { + // A malicious/broken server sends a 101 response where the status line + // uses \r\n but headers use bare \n, terminated by \n\n. The client must + // reject gracefully without crashing. + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const fake = try startWithResponseFn(buildBareLfResponse); + defer fake.stop(); + + var handler: clients.ConnectFailHandler = .{}; + + var env: clients.TestEnv = undefined; + try env.start(); + defer env.deinit(); + + var conn: clients.TestConnectFailClient.Conn = undefined; + var client = env.initClient(clients.TestConnectFailClient, &handler, &conn, .{ + .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, fake.port), + }); + + try client.connect(); + try env.loop.run(.until_done); + + // The handshake should have failed — onOpen must NOT have been called. + try testing.expect(!handler.open_called); + // onSocketClose should have been called (handshake failure path). + try testing.expect(handler.socket_close_called); +} + +test "e2e client: fully bare LF response (no \\r\\n at all) doesn't crash client" { + // Same as above but the entire response uses bare \n — no \r\n anywhere. + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const fake = try startWithResponseFn(buildFullyBareLfResponse); + defer fake.stop(); + + var handler: clients.ConnectFailHandler = .{}; + + var env: clients.TestEnv = undefined; + try env.start(); + defer env.deinit(); + + var conn: clients.TestConnectFailClient.Conn = undefined; + var client = env.initClient(clients.TestConnectFailClient, &handler, &conn, .{ + .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, fake.port), + }); + + try client.connect(); + try env.loop.run(.until_done); + + try testing.expect(!handler.open_called); + try testing.expect(handler.socket_close_called); +} + +/// Build a 101 response with \r\n on the status line but bare \n on headers. +fn buildBareLfResponse(buf: []u8, accept_key: []const u8) []const u8 { + var fbs = std.io.fixedBufferStream(buf); + const w = fbs.writer(); + w.writeAll("HTTP/1.1 101 Switching Protocols\r\n") catch unreachable; + w.writeAll("Upgrade: websocket\n") catch unreachable; + w.writeAll("Connection: Upgrade\n") catch unreachable; + w.print("Sec-WebSocket-Accept: {s}\n", .{accept_key}) catch unreachable; + w.writeAll("\n") catch unreachable; + return buf[0..fbs.pos]; +} + +/// Build a 101 response where the entire response uses bare \n (no \r\n at all). 
+fn buildFullyBareLfResponse(buf: []u8, accept_key: []const u8) []const u8 { + var fbs = std.io.fixedBufferStream(buf); + const w = fbs.writer(); + w.writeAll("HTTP/1.1 101 Switching Protocols\n") catch unreachable; + w.writeAll("Upgrade: websocket\n") catch unreachable; + w.writeAll("Connection: Upgrade\n") catch unreachable; + w.print("Sec-WebSocket-Accept: {s}\n", .{accept_key}) catch unreachable; + w.writeAll("\n") catch unreachable; + return buf[0..fbs.pos]; +} + +const FakeServer = struct { + listener: std.posix.socket_t, + port: u16, + thread: std.Thread, + + fn stop(self: *const FakeServer) void { + std.posix.close(self.listener); + self.thread.join(); + } +}; + +/// Shared implementation: start a TCP listener on an ephemeral port and spawn +/// a background thread that accepts one connection, reads the client's upgrade +/// request to extract Sec-WebSocket-Key, then responds using `responseFn`. +fn startWithResponseFn(comptime responseFn: fn ([]u8, []const u8) []const u8) !FakeServer { + const addr = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, 0); + const listener = std.posix.socket(std.posix.AF.INET, std.posix.SOCK.STREAM, 0) catch + @panic("failed to create socket"); + std.posix.bind(listener, @ptrCast(&addr.any), addr.getOsSockLen()) catch + @panic("failed to bind"); + std.posix.listen(listener, 1) catch @panic("failed to listen"); + + // Get the assigned port. + var bound_addr: std.posix.sockaddr.storage = undefined; + var addr_len: std.posix.socklen_t = @sizeOf(@TypeOf(bound_addr)); + std.posix.getsockname(listener, @ptrCast(&bound_addr), &addr_len) catch + @panic("failed to getsockname"); + const sa4: *const std.posix.sockaddr.in = @ptrCast(@alignCast(&bound_addr)); + const port = std.mem.bigToNative(u16, sa4.port); + + const thread = std.Thread.spawn(.{}, acceptAndRespond, .{ listener, responseFn }) catch + @panic("failed to spawn thread"); + + return FakeServer{ .listener = listener, .port = port, .thread = thread }; +} + +/// Thread function: accept one connection, read the upgrade request, extract +/// the Sec-WebSocket-Key, compute the accept key, send the response, then close. +fn acceptAndRespond( + listener: std.posix.socket_t, + comptime responseFn: fn ([]u8, []const u8) []const u8, +) void { + var client_addr: std.posix.sockaddr.storage = undefined; + var client_addr_len: std.posix.socklen_t = @sizeOf(@TypeOf(client_addr)); + const conn_fd = std.posix.accept( + listener, + @ptrCast(&client_addr), + &client_addr_len, + 0, + ) catch return; + defer std.posix.close(conn_fd); + const stream = std.net.Stream{ .handle = conn_fd }; + + // Read the client's upgrade request. + var req_buf: [4096]u8 = undefined; + var total: usize = 0; + while (total < req_buf.len) { + const n = stream.read(req_buf[total..]) catch return; + if (n == 0) return; + total += n; + if (std.mem.indexOf(u8, req_buf[0..total], "\r\n\r\n") != null) break; + } + + // Extract Sec-WebSocket-Key from the request. + const key = extractWebSocketKey(req_buf[0..total]) orelse return; + + // Compute the accept key. + var accept_buf: [28]u8 = undefined; + const accept_key = http.computeAcceptKey(&accept_buf, key); + + // Build and send the bare-LF response. + var resp_buf: [512]u8 = undefined; + const response = responseFn(&resp_buf, accept_key); + stream.writeAll(response) catch return; +} + +/// Extract the Sec-WebSocket-Key header value from a raw HTTP request. 
+fn extractWebSocketKey(request: []const u8) ?[]const u8 { + const needle = "Sec-WebSocket-Key: "; + const start = std.mem.indexOf(u8, request, needle) orelse return null; + const value_start = start + needle.len; + const remaining = request[value_start..]; + const end = std.mem.indexOf(u8, remaining, "\r\n") orelse return null; + return remaining[0..end]; +} diff --git a/src/rpc/webzockets/e2e_tests/client/max_message_tests.zig b/src/rpc/webzockets/e2e_tests/client/max_message_tests.zig new file mode 100644 index 0000000000..83f149fe46 --- /dev/null +++ b/src/rpc/webzockets/e2e_tests/client/max_message_tests.zig @@ -0,0 +1,39 @@ +const std = @import("std"); +const testing = std.testing; + +const servers = @import("../support/test_servers.zig"); +const clients = @import("../support/test_clients.zig"); +const FdLeakDetector = @import("../support/fd_leak.zig").FdLeakDetector; + +test "e2e client: max_message_size enforcement closes with 1009" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + // OversizedServer sends a 2048-byte binary message on open. + // The client is configured with max_message_size = 1024, so it should + // reject the message and close the connection with code 1009 (message too big). + const ts = try servers.startOversizedServer(testing.allocator); + defer ts.stop(); + + var handler: clients.MaxMessageHandler = .{}; + + var env: clients.TestEnv = undefined; + try env.start(); + defer env.deinit(); + + var conn: clients.TestMaxMessageClient.Conn = undefined; + var client = env.initClient(clients.TestMaxMessageClient, &handler, &conn, .{ + .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, ts.port), + .max_message_size = 1024, + }); + try client.connect(); + try env.loop.run(.until_done); + + // Client should have connected successfully + try testing.expect(handler.open_called); + // onClose should have been called (client closed due to oversized message) + try testing.expect(handler.close_called); + // The oversized message should NOT have been delivered to onMessage + try testing.expect(!handler.message_received); + conn.deinit(); +} diff --git a/src/rpc/webzockets/e2e_tests/client/pause_resume_tests.zig b/src/rpc/webzockets/e2e_tests/client/pause_resume_tests.zig new file mode 100644 index 0000000000..88c6a534b0 --- /dev/null +++ b/src/rpc/webzockets/e2e_tests/client/pause_resume_tests.zig @@ -0,0 +1,231 @@ +const std = @import("std"); +const testing = std.testing; + +const servers = @import("../support/test_servers.zig"); +const server_handlers = @import("../support/server_handlers.zig"); +const clients = @import("../support/test_clients.zig"); +const FdLeakDetector = @import("../support/fd_leak.zig").FdLeakDetector; + +test "e2e client pause/resume: sequential processing of server burst" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const messages = [_][]const u8{ "msg-1", "msg-2", "msg-3" }; + var send_ctx: server_handlers.SendMessagesOnOpenHandler.Context = .{ .messages = &messages }; + const ts = try servers.startSendMessagesOnOpenServer(testing.allocator, &send_ctx); + defer ts.stop(); + + // Each server-to-client text frame for "msg-N" (5 bytes payload) is 7 bytes + // on the wire: 2-byte header (FIN+opcode, length) + 5-byte payload. + // 3 messages × 7 bytes = 21 bytes minimum before we resume. 
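+    // For reference, a worked example (illustrative, from the RFC 6455 frame
+    // layout, not a captured trace): the first frame, "msg-1", sent unmasked
+    // server-to-client, would be
+    //   0x81 0x05 'm' 's' 'g' '-' '1'
+    // where 0x81 = FIN | text opcode and 0x05 = mask bit clear, payload len 5.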
+ var handler: clients.PauseUntilBufferedClientHandler = .{ + .allocator = testing.allocator, + .expected_messages = messages.len, + .resume_threshold = 21, + .results = std.ArrayList( + clients.PauseUntilBufferedClientHandler.RecvResult, + ).init(testing.allocator), + }; + defer handler.deinit(); + + var env: clients.TestEnv = undefined; + try env.start(); + defer env.deinit(); + + var conn: clients.TestPauseUntilBufferedClient.Conn = undefined; + var client = env.initClient(clients.TestPauseUntilBufferedClient, &handler, &conn, .{ + .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, ts.port), + }); + + try client.connect(); + try env.loop.run(.until_done); + + try testing.expect(handler.open_called); + try testing.expect(handler.close_called); + try testing.expectEqual(messages.len, handler.results.items.len); + for (messages, handler.results.items) |expected, result| { + try testing.expectEqualSlices(u8, expected, result.data); + } + + conn.deinit(); +} + +test "e2e client pause/resume: pause mid-stream stops dispatch then delivers on resume" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const messages = [_][]const u8{ "msg-1", "msg-2", "msg-3", "msg-4" }; + var send_ctx: server_handlers.SendMessagesOnOpenHandler.Context = .{ .messages = &messages }; + const ts = try servers.startSendMessagesOnOpenServer(testing.allocator, &send_ctx); + defer ts.stop(); + + // Server sends 4 unmasked text frames: "msg-1".."msg-4" (5 bytes each). + // Each frame = 2 + 5 = 7 bytes → 4 × 7 = 28 bytes threshold. + var handler: clients.PauseMidStreamClientHandler = .{ + .allocator = testing.allocator, + .expected_messages = messages.len, + .resume_threshold = 28, + .results = std.ArrayList( + clients.PauseMidStreamClientHandler.RecvResult, + ).init(testing.allocator), + }; + defer handler.deinit(); + + var env: clients.TestEnv = undefined; + try env.start(); + defer env.deinit(); + + var conn: clients.TestPauseMidStreamClient.Conn = undefined; + var client = env.initClient(clients.TestPauseMidStreamClient, &handler, &conn, .{ + .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, ts.port), + }); + + try client.connect(); + try env.loop.run(.until_done); + + try testing.expect(handler.open_called); + try testing.expect(handler.close_called); + try testing.expectEqual(messages.len, handler.results.items.len); + for (messages, handler.results.items) |expected, result| { + try testing.expectEqualSlices(u8, expected, result.data); + } + + conn.deinit(); +} + +test "e2e client pause/resume: close while client is paused" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const messages = [_][]const u8{"msg-1"}; + var send_ctx: server_handlers.SendMessagesOnOpenHandler.Context = .{ .messages = &messages }; + const ts = try servers.startSendMessagesOnOpenServer(testing.allocator, &send_ctx); + defer ts.stop(); + + // Server sends "msg-1" (2 + 5 = 7 bytes) then close frame with code 1000 + // + reason "done" (2 + 6 = 8 bytes). Total threshold: 15 bytes. 
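+    // Worked out (illustrative, per the RFC 6455 close-frame layout): that
+    // close frame would be
+    //   0x88 0x06 0x03 0xE8 'd' 'o' 'n' 'e'
+    // with 0x88 = FIN | close opcode, 0x06 = payload length (2-byte code +
+    // 4-byte reason), and 0x03E8 = 1000 in network byte order.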
+ var handler: clients.PauseUntilBufferedClientHandler = .{ + .allocator = testing.allocator, + .expected_messages = 0, + .resume_threshold = 15, + .results = std.ArrayList( + clients.PauseUntilBufferedClientHandler.RecvResult, + ).init(testing.allocator), + }; + defer handler.deinit(); + + var env: clients.TestEnv = undefined; + try env.start(); + defer env.deinit(); + + var conn: clients.TestPauseUntilBufferedClient.Conn = undefined; + var client = env.initClient(clients.TestPauseUntilBufferedClient, &handler, &conn, .{ + .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, ts.port), + }); + + try client.connect(); + try env.loop.run(.until_done); + + try testing.expect(handler.open_called); + try testing.expect(handler.close_called); + try testing.expectEqual(@as(usize, 1), handler.results.items.len); + try testing.expectEqualSlices(u8, "msg-1", handler.results.items[0].data); + + conn.deinit(); +} + +test "e2e client pause/resume: no re-entrant onMessage dispatch" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const messages = [_][]const u8{ "a", "b", "c", "done" }; + var send_ctx: server_handlers.SendMessagesOnOpenHandler.Context = .{ .messages = &messages }; + const ts = try servers.startSendMessagesOnOpenServer(testing.allocator, &send_ctx); + defer ts.stop(); + + // Server sends 4 unmasked text frames: "a" (3B), "b" (3B), "c" (3B), "done" (6B) = 15 bytes. + var handler: clients.ReentrancyDetectClientHandler = .{ + .allocator = testing.allocator, + .resume_threshold = 15, + .results = std.ArrayList( + clients.ReentrancyDetectClientHandler.RecvResult, + ).init(testing.allocator), + }; + defer handler.deinit(); + + var env: clients.TestEnv = undefined; + try env.start(); + defer env.deinit(); + + var conn: clients.TestReentrancyDetectClient.Conn = undefined; + var client = env.initClient(clients.TestReentrancyDetectClient, &handler, &conn, .{ + .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, ts.port), + }); + + try client.connect(); + try env.loop.run(.until_done); + + try testing.expect(handler.open_called); + try testing.expect(handler.close_called); + try testing.expect(!handler.reentrant_detected); + try testing.expectEqual(messages.len, handler.results.items.len); + for (messages, handler.results.items) |expected, result| { + try testing.expectEqualSlices(u8, expected, result.data); + } + + conn.deinit(); +} + +test "e2e client pause/resume: buffer fills while paused (small read buffer)" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + var send_ctx: server_handlers.SendMessagesOnOpenHandler.Context = .{ + .messages = &servers.small_buf_slices, + }; + const ts = try servers.startSendMessagesOnOpenServer(testing.allocator, &send_ctx); + defer ts.stop(); + + const msg_len = servers.small_buf_msg_len; + const msg_count = servers.small_buf_msg_count; + + // 256-byte client read buffer. Server sends 12 unmasked text frames, + // each with 20-byte payload = 22 bytes per frame. 12 × 22 = 264 bytes + // total — exceeds the 256-byte buffer. + // Threshold of 256 ensures we only resume once the buffer is completely + // full and reads have been stopped (freeSpace() == 0). 
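+    // (The 2-byte header assumption holds because a 20-byte payload fits the
+    // 7-bit length field; per RFC 6455, payloads of 126 bytes or more would
+    // add extended-length bytes and change this arithmetic.)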
+ var handler: clients.PauseUntilBufferedClientHandler = .{ + .allocator = testing.allocator, + .expected_messages = msg_count, + .resume_threshold = 256, + .results = std.ArrayList( + clients.PauseUntilBufferedClientHandler.RecvResult, + ).init(testing.allocator), + }; + defer handler.deinit(); + + var env: clients.TestEnv = undefined; + try env.start(); + defer env.deinit(); + + var conn: clients.TestPauseUntilBufferedSmallBufClient.Conn = undefined; + var client = env.initClient(clients.TestPauseUntilBufferedSmallBufClient, &handler, &conn, .{ + .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, ts.port), + }); + + try client.connect(); + try env.loop.run(.until_done); + + try testing.expect(handler.open_called); + try testing.expect(handler.close_called); + try testing.expectEqual(@as(usize, msg_count), handler.results.items.len); + + for (0..msg_count) |i| { + var expected: [msg_len]u8 = undefined; + const byte = @as(u8, @truncate('A' + i)); + @memset(&expected, byte); + try testing.expectEqualSlices(u8, &expected, handler.results.items[i].data); + } + + conn.deinit(); +} diff --git a/src/rpc/webzockets/e2e_tests/client/ping_pong_tests.zig b/src/rpc/webzockets/e2e_tests/client/ping_pong_tests.zig new file mode 100644 index 0000000000..820bc6c281 --- /dev/null +++ b/src/rpc/webzockets/e2e_tests/client/ping_pong_tests.zig @@ -0,0 +1,102 @@ +const std = @import("std"); +const testing = std.testing; + +const servers = @import("../support/test_servers.zig"); +const clients = @import("../support/test_clients.zig"); +const FdLeakDetector = @import("../support/fd_leak.zig").FdLeakDetector; + +test "e2e client: auto-pong response to server ping" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + // PingOnOpenServer sends ping("hello") on open. The client's handler + // (ServerCloseHandler) does not declare onPing, so the library auto-pongs. + // The server receives the pong and closes the connection. If auto-pong + // were broken the connection would not complete cleanly. + const ts = try servers.startPingOnOpenServer(testing.allocator); + defer ts.stop(); + + var handler: clients.ServerCloseHandler = .{}; + + var env: clients.TestEnv = undefined; + try env.start(); + defer env.deinit(); + + var conn: clients.TestServerCloseClient.Conn = undefined; + var client = env.initClient(clients.TestServerCloseClient, &handler, &conn, .{ + .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, ts.port), + }); + + try client.connect(); + try env.loop.run(.until_done); + + try testing.expect(handler.open_called); + try testing.expect(handler.close_called); + conn.deinit(); +} + +test "e2e client: onPong callback fires on unsolicited pong" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + // PongOnOpenServer sends an unsolicited pong("hello") on open. + // The client's PongTrackingHandler captures the pong data via onPong + // and closes the connection. 
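+    // (Per RFC 6455 §5.5.3 an unsolicited pong is legal and no response is
+    // expected. Illustrative unmasked wire bytes for pong("hello"):
+    //   0x8A 0x05 'h' 'e' 'l' 'l' 'o'
+    // with 0x8A = FIN | pong opcode.)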
+ const ts = try servers.startPongOnOpenServer(testing.allocator); + defer ts.stop(); + + var handler: clients.PongTrackingHandler = .{ + .allocator = testing.allocator, + }; + defer handler.deinit(); + + var env: clients.TestEnv = undefined; + try env.start(); + defer env.deinit(); + + var conn: clients.TestPongTrackingClient.Conn = undefined; + var client = env.initClient(clients.TestPongTrackingClient, &handler, &conn, .{ + .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, ts.port), + }); + + try client.connect(); + try env.loop.run(.until_done); + + try testing.expect(handler.open_called); + try testing.expect(handler.pong_received); + const pong_data = handler.pong_data orelse return error.NoData; + try testing.expectEqualSlices(u8, "hello", pong_data); + conn.deinit(); +} + +test "e2e client: explicit onPing handler sends pong" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + // PingOnOpenServer sends ping("hello") on open. The client's + // ExplicitPongHandler declares onPing, so the library does NOT auto-pong. + // The handler manually calls sendPong in onPing. The server receives the + // pong and closes. This verifies that declaring onPing disables auto-pong + // and the handler can manage pong responses itself. + const ts = try servers.startPingOnOpenServer(testing.allocator); + defer ts.stop(); + + var handler: clients.ExplicitPongHandler = .{}; + + var env: clients.TestEnv = undefined; + try env.start(); + defer env.deinit(); + + var conn: clients.TestExplicitPongClient.Conn = undefined; + var client = env.initClient(clients.TestExplicitPongClient, &handler, &conn, .{ + .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, ts.port), + }); + + try client.connect(); + try env.loop.run(.until_done); + + try testing.expect(handler.open_called); + try testing.expect(handler.ping_received); + try testing.expect(handler.close_called); + conn.deinit(); +} diff --git a/src/rpc/webzockets/e2e_tests/client/tests.zig b/src/rpc/webzockets/e2e_tests/client/tests.zig new file mode 100644 index 0000000000..a0c1c1579b --- /dev/null +++ b/src/rpc/webzockets/e2e_tests/client/tests.zig @@ -0,0 +1,8 @@ +comptime { + _ = @import("connection_tests.zig"); + _ = @import("close_tests.zig"); + _ = @import("ping_pong_tests.zig"); + _ = @import("max_message_tests.zig"); + _ = @import("pause_resume_tests.zig"); + _ = @import("timeout_tests.zig"); +} diff --git a/src/rpc/webzockets/e2e_tests/client/timeout_tests.zig b/src/rpc/webzockets/e2e_tests/client/timeout_tests.zig new file mode 100644 index 0000000000..c6ae3d0e33 --- /dev/null +++ b/src/rpc/webzockets/e2e_tests/client/timeout_tests.zig @@ -0,0 +1,220 @@ +const std = @import("std"); +const testing = std.testing; +const ws = @import("webzockets_lib"); + +const servers = @import("../support/test_servers.zig"); +const clients = @import("../support/test_clients.zig"); +const FdLeakDetector = @import("../support/fd_leak.zig").FdLeakDetector; + +test "e2e client timeout: close timeout force-disconnects unresponsive peer" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + var srv = try UnresponsiveServer.start(); + defer srv.stop(); + + var handler: clients.CloseOnOpenHandler = .{}; + + var env: clients.TestEnv = undefined; + try env.start(); + defer env.deinit(); + + var conn: clients.TestCloseClient.Conn = undefined; + var client = env.initClient(clients.TestCloseClient, &handler, &conn, .{ + .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, srv.port), + .close_timeout_ms = 
200, + }); + try client.connect(); + try env.loop.run(.until_done); + + // Server never responded to close → close timer fired → force disconnect + try testing.expect(handler.open_called); + try testing.expect(handler.close_called); + conn.deinit(); +} + +test "e2e client timeout: normal close completes before timeout" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + var handler: clients.EchoTestHandler = .{ + .allocator = testing.allocator, + .send_kind = .text, + .send_data = "hello", + }; + defer handler.deinit(); + + var env: clients.TestEnv = undefined; + try env.start(); + defer env.deinit(); + + var conn: clients.TestEchoClient.Conn = undefined; + var client = env.initClient(clients.TestEchoClient, &handler, &conn, .{ + .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, ts.port), + // Close timeout configured but server responds promptly, so it + // should never fire. + .close_timeout_ms = 500, + }); + try client.connect(); + try env.loop.run(.until_done); + + // Normal echo + close should work fine with close timer enabled + try testing.expect(handler.open_called); + const received_data = handler.received_data orelse return error.NoData; + try testing.expectEqualSlices(u8, "hello", received_data); + conn.deinit(); +} + +test "e2e client timeout: close in onOpen arms close timer" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + // Echo server will respond to close, so timer gets cancelled cleanly + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + var handler: clients.CloseOnOpenHandler = .{}; + + var env: clients.TestEnv = undefined; + try env.start(); + defer env.deinit(); + + var conn: clients.TestCloseClient.Conn = undefined; + var client = env.initClient(clients.TestCloseClient, &handler, &conn, .{ + .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, ts.port), + .close_timeout_ms = 200, + }); + try client.connect(); + try env.loop.run(.until_done); + + try testing.expect(handler.open_called); + conn.deinit(); +} + +test "e2e client timeout: server-initiated close does not arm close timer" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + // Server closes immediately on open (server-initiated close). + // Client echoes and disconnects — no close timer needed. + const ts = try servers.startCloseOnOpenServer(testing.allocator); + defer ts.stop(); + + var handler: clients.ServerCloseHandler = .{}; + + var env: clients.TestEnv = undefined; + try env.start(); + defer env.deinit(); + + var conn: clients.TestServerCloseClient.Conn = undefined; + var client = env.initClient(clients.TestServerCloseClient, &handler, &conn, .{ + .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, ts.port), + .close_timeout_ms = 200, + }); + try client.connect(); + try env.loop.run(.until_done); + + try testing.expect(handler.open_called); + try testing.expect(handler.close_called); + conn.deinit(); +} + +/// Minimal raw TCP "server" that completes the WebSocket handshake but then +/// ignores all further data (never responds to close frames). Used to test +/// client-side close-handshake timeout behavior. 
+const UnresponsiveServer = struct { + listener: std.posix.socket_t, + port: u16, + thread: std.Thread, + + fn start() !UnresponsiveServer { + // Low-level posix to test exactly the handshake and timeout behavior we want + const address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, 0); + const listener = try std.posix.socket( + address.any.family, + std.posix.SOCK.STREAM, + std.posix.IPPROTO.TCP, + ); + errdefer std.posix.close(listener); + + try std.posix.setsockopt( + listener, + std.posix.SOL.SOCKET, + std.posix.SO.REUSEADDR, + &std.mem.toBytes(@as(c_int, 1)), + ); + try std.posix.bind(listener, &address.any, address.getOsSockLen()); + try std.posix.listen(listener, 1); + + var bound_addr: std.posix.sockaddr.storage = undefined; + var addr_len: std.posix.socklen_t = @sizeOf(@TypeOf(bound_addr)); + try std.posix.getsockname(listener, @ptrCast(&bound_addr), &addr_len); + const sa4: *const std.posix.sockaddr.in = @ptrCast(@alignCast(&bound_addr)); + const port = std.mem.bigToNative(u16, sa4.port); + + const thread = try std.Thread.spawn(.{}, acceptAndHandshake, .{listener}); + return .{ + .listener = listener, + .port = port, + .thread = thread, + }; + } + + fn acceptAndHandshake(listener: std.posix.socket_t) void { + var addr: std.posix.sockaddr.storage = undefined; + var addr_len: std.posix.socklen_t = @sizeOf(@TypeOf(addr)); + const conn_fd = std.posix.accept(listener, @ptrCast(&addr), &addr_len, 0) catch return; + const stream = std.net.Stream{ .handle = conn_fd }; + defer stream.close(); + + // Safety net: if the client never disconnects (test bug), don't block forever. + const timeout: std.posix.timeval = .{ .sec = 5, .usec = 0 }; + std.posix.setsockopt( + conn_fd, + std.posix.SOL.SOCKET, + std.posix.SO.RCVTIMEO, + std.mem.asBytes(&timeout), + ) catch {}; + + // Read the HTTP upgrade request + var buf: [4096]u8 = undefined; + var total: usize = 0; + while (total < buf.len) { + const n = stream.read(buf[total..]) catch return; + if (n == 0) return; + total += n; + if (std.mem.indexOf(u8, buf[0..total], "\r\n\r\n")) |_| break; + } + + // Extract the Sec-WebSocket-Key + const request = buf[0..total]; + const key_header = "Sec-WebSocket-Key: "; + const key_start = (std.mem.indexOf(u8, request, key_header) orelse return) + key_header.len; + const key_end = std.mem.indexOf(u8, request[key_start..], "\r\n") orelse return; + const client_key = request[key_start..][0..key_end]; + + // Compute accept key using the library's utility + var accept_key_buf: [28]u8 = undefined; + const accept_key = ws.http.computeAcceptKey(&accept_key_buf, client_key); + + // Send 101 Switching Protocols + var resp_buf: [256]u8 = undefined; + const resp = std.fmt.bufPrint(&resp_buf, "HTTP/1.1 101 Switching Protocols\r\n" ++ + "Upgrade: websocket\r\n" ++ + "Connection: Upgrade\r\n" ++ + "Sec-WebSocket-Accept: {s}\r\n\r\n", .{accept_key}) catch return; + stream.writeAll(resp) catch return; + + // Block until client force-disconnects; never respond to close frames. 
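+        // (What arrives here is the client's close frame, masked as all
+        // client-to-server frames must be per RFC 6455. Assuming a bare
+        // code-1000 close, that is 0x88 0x82, a 4-byte mask key, then the
+        // two code bytes XOR'd with the mask; it is read and dropped.)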
+        // NOTE: timeout set on socket so this won't block forever if there is a bug
+        _ = stream.read(&buf) catch {};
+    }
+
+    fn stop(self: *UnresponsiveServer) void {
+        std.posix.close(self.listener);
+        self.thread.join();
+    }
+};
diff --git a/src/rpc/webzockets/e2e_tests/server/buffer_tier_tests.zig b/src/rpc/webzockets/e2e_tests/server/buffer_tier_tests.zig
new file mode 100644
index 0000000000..0f3f0d1bea
--- /dev/null
+++ b/src/rpc/webzockets/e2e_tests/server/buffer_tier_tests.zig
@@ -0,0 +1,185 @@
+const std = @import("std");
+const ws = @import("webzockets_lib");
+
+const testing = std.testing;
+const servers = @import("../support/test_servers.zig");
+const clients = @import("../support/test_clients.zig");
+const FdLeakDetector = @import("../support/fd_leak.zig").FdLeakDetector;
+
+test "e2e: medium message requiring pooled buffer (8KB)" {
+    try runBufferTierTest(8 * 1024, null);
+}
+
+test "e2e: large message requiring pooled buffer (32KB)" {
+    try runBufferTierTest(32 * 1024, null);
+}
+
+test "e2e: message at pool buffer boundary (64KB)" {
+    try runBufferTierTest(64 * 1024, 128 * 1024);
+}
+
+test "e2e: large message requiring dynamic allocation (128KB)" {
+    try runBufferTierTest(128 * 1024, 256 * 1024);
+}
+
+test "e2e: very large message requiring dynamic allocation (256KB)" {
+    try runBufferTierTest(256 * 1024, 512 * 1024);
+}
+
+test "e2e: buffer tier retained after large messages" {
+    const fd_check = FdLeakDetector.baseline();
+    defer fd_check.assertNoLeaks();
+
+    const ts = try servers.startTestServer(testing.allocator);
+    defer ts.stop();
+
+    const medium_size = 8 * 1024;
+    const medium_msg = try testing.allocator.alloc(u8, medium_size);
+    defer testing.allocator.free(medium_msg);
+    @memset(medium_msg, 'M');
+
+    const large_size = 100 * 1024;
+    const large_msg = try testing.allocator.alloc(u8, large_size);
+    defer testing.allocator.free(large_msg);
+    @memset(large_msg, 'L');
+
+    const messages = [_]clients.SequenceHandler.MsgSpec{
+        .{ .data = "small", .is_binary = false },
+        .{ .data = medium_msg, .is_binary = true },
+        .{ .data = "still pooled", .is_binary = false },
+        .{ .data = large_msg, .is_binary = true },
+        .{ .data = "still dynamic", .is_binary = false },
+    };
+
+    var handler: clients.SequenceHandler = .{
+        .messages = &messages,
+        .results = std.ArrayList(clients.SequenceHandler.RecvResult).init(testing.allocator),
+        .allocator = testing.allocator,
+    };
+    defer handler.deinit();
+
+    var env: clients.TestEnv = undefined;
+    try env.start();
+    defer env.deinit();
+
+    var conn: TestLargeSeqClient.Conn = undefined;
+    var client = env.initClient(TestLargeSeqClient, &handler, &conn, .{
+        .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, ts.port),
+        .max_message_size = 128 * 1024,
+    });
+    defer if (handler.open_called) conn.deinit();
+
+    try client.connect();
+    try env.loop.run(.until_done);
+
+    try testing.expect(handler.open_called);
+    try testing.expectEqual(@as(usize, 5), handler.results.items.len);
+    try testing.expectEqualSlices(u8, "small", handler.results.items[0].data);
+    try testing.expectEqual(medium_size, handler.results.items[1].len);
+    try testing.expectEqualSlices(u8, "still pooled", handler.results.items[2].data);
+    try testing.expectEqual(large_size, handler.results.items[3].len);
+    try testing.expectEqualSlices(u8, "still dynamic", handler.results.items[4].data);
+}
+
+/// Send a binary message of `msg_size` bytes (filled with a repeating index
+/// pattern) via the library client and verify the echo matches byte-for-byte. 
+/// If `max_message_size` is non-null it is forwarded to the client config. +fn runBufferTierTest(msg_size: usize, max_message_size: ?usize) !void { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + const msg = try testing.allocator.alloc(u8, msg_size); + defer testing.allocator.free(msg); + for (msg, 0..) |*byte, i| byte.* = @truncate(i); + + var handler: LargeMessageHandler = .{ + .send_data = msg, + .allocator = testing.allocator, + }; + defer handler.deinit(); + + var env: clients.TestEnv = undefined; + try env.start(); + defer env.deinit(); + + var conn: TestLargeClient.Conn = undefined; + var config: TestLargeClient.Config = .{ + .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, ts.port), + }; + if (max_message_size) |mms| { + config.max_message_size = mms; + } + var client = env.initClient(TestLargeClient, &handler, &conn, config); + defer if (handler.open_called) conn.deinit(); + + try client.connect(); + try env.loop.run(.until_done); + + try testing.expect(handler.open_called); + try testing.expectEqual(msg_size, handler.received_len); + const received = handler.received_data orelse return error.NoData; + for (received, 0..) |byte, i| { + try testing.expectEqual(@as(u8, @truncate(i)), byte); + } +} + +/// Client-side handler for large message tests. +/// Sends binary data on open, captures response, then closes. +const LargeMessageHandler = struct { + send_data: []const u8, + received_data: ?[]const u8 = null, + received_len: usize = 0, + open_called: bool = false, + allocator: std.mem.Allocator, + sent_data: ?[]const u8 = null, + + pub fn deinit(self: *LargeMessageHandler) void { + if (self.received_data) |data| { + self.allocator.free(data); + self.received_data = null; + } + if (self.sent_data) |data| { + self.allocator.free(data); + self.sent_data = null; + } + } + + pub fn onOpen(self: *LargeMessageHandler, conn: anytype) void { + self.open_called = true; + const copy = self.allocator.dupe(u8, self.send_data) catch return; + conn.sendBinary(copy) catch { + self.allocator.free(copy); + return; + }; + self.sent_data = copy; + } + + pub fn onMessage(self: *LargeMessageHandler, conn: anytype, message: ws.Message) void { + self.received_data = self.allocator.dupe(u8, message.data) catch null; + self.received_len = message.data.len; + conn.close(.normal, ""); + } + + pub fn onWriteComplete(self: *LargeMessageHandler, _: anytype) void { + if (self.sent_data) |data| { + self.allocator.free(data); + self.sent_data = null; + } + } + + pub fn onClose(self: *LargeMessageHandler, _: anytype) void { + if (self.sent_data) |data| { + self.allocator.free(data); + self.sent_data = null; + } + } +}; + +/// Single client type with large buffer for all buffer tier tests. +const TestLargeClient = ws.Client(LargeMessageHandler, 256 * 1024); + +/// Large sequence client for mixed message sequence tests. 
+const TestLargeSeqClient = ws.Client(clients.SequenceHandler, 128 * 1024); diff --git a/src/rpc/webzockets/e2e_tests/server/close_tests.zig b/src/rpc/webzockets/e2e_tests/server/close_tests.zig new file mode 100644 index 0000000000..6cc9a1eceb --- /dev/null +++ b/src/rpc/webzockets/e2e_tests/server/close_tests.zig @@ -0,0 +1,144 @@ +const std = @import("std"); +const testing = std.testing; + +const servers = @import("../support/test_servers.zig"); +const RawClient = @import("../support/raw_client.zig").RawClient; +const FdLeakDetector = @import("../support/fd_leak.zig").FdLeakDetector; +const helpers = @import("../support/test_helpers.zig"); + +const wait_ms: u64 = 2_000; + +test "e2e close: normal code 1000" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + var client = try RawClient.connect(testing.allocator, ts.port); + defer client.deinit(); + + var msg = "test".*; + try client.write(&msg); + + const response = try client.waitForMessageType(.text, wait_ms); + defer client.done(response); + try testing.expectEqualSlices(u8, "test", response.data); + + try client.close(.{ .code = 1000 }); + try helpers.expectCloseWithCode(&client, 1000); +} + +test "e2e close: custom code in registered range (3000-3999)" { + try testCloseEcho(3500, "custom close"); +} + +test "e2e close: custom code in private range (4000-4999)" { + try testCloseEcho(4999, "private close"); +} + +test "e2e close: server echoes close frame" { + try testCloseEcho(1000, "bye"); +} + +test "e2e close: server rejects invalid close code 0 with 1002" { + var payload = [_]u8{ 0x00, 0x00 }; + try testCloseRejection(&payload, 1002); +} + +test "e2e close: server rejects reserved close code 1004 with 1002" { + var payload = [_]u8{ 0x03, 0xEC }; + try testCloseRejection(&payload, 1002); +} + +test "e2e close: server rejects API-only close code 1005 with 1002" { + var payload = [_]u8{ 0x03, 0xED }; + try testCloseRejection(&payload, 1002); +} + +test "e2e close: server rejects out-of-range close code 5000 with 1002" { + var payload = [_]u8{ 0x13, 0x88 }; + try testCloseRejection(&payload, 1002); +} + +test "e2e close: server rejects invalid UTF-8 in close reason with 1007" { + // 0xFE is never valid in UTF-8 + var payload = [_]u8{ 0x03, 0xE8, 0xFE, 0xFF }; + try testCloseRejection(&payload, 1007); +} + +test "e2e close: server rejects 1-byte close payload with 1002" { + var payload = [_]u8{0x42}; + try testCloseRejection(&payload, 1002); +} + +test "e2e close: server rejects API-only close code 1006 with 1002" { + var payload = [_]u8{ 0x03, 0xEE }; + try testCloseRejection(&payload, 1002); +} + +test "e2e close: server rejects close code 999 (below valid range) with 1002" { + var payload = [_]u8{ 0x03, 0xE7 }; + try testCloseRejection(&payload, 1002); +} + +test "e2e close: server rejects close code 2000 (gap range 1012-2999) with 1002" { + var payload = [_]u8{ 0x07, 0xD0 }; + try testCloseRejection(&payload, 1002); +} + +test "e2e close: server echoes max-length close reason (123 bytes)" { + // Control frame max payload is 125 bytes; 2 for code leaves 123 for reason + var reason: [123]u8 = undefined; + @memset(&reason, 'R'); + try testCloseEcho(1000, &reason); +} + +test "e2e close: server accepts close with no payload" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + var client = try 
RawClient.connect(testing.allocator, ts.port); + defer client.deinit(); + + var empty = [0]u8{}; + try client.writeFrame(.close, &empty); + + const response = try client.waitForCloseFrame(wait_ms); + defer client.done(response); +} + +/// Send a close frame with the given raw payload and assert the server +/// responds with the expected close code. +fn testCloseRejection(close_payload: []u8, expected_code: u16) !void { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + var client = try RawClient.connect(testing.allocator, ts.port); + defer client.deinit(); + + try client.writeFrame(.close, close_payload); + try helpers.expectCloseWithCode(&client, expected_code); +} + +/// Send a close frame with the given code and reason, and assert the server +/// echoes them back. +fn testCloseEcho(code: u16, reason: []const u8) !void { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + var client = try RawClient.connect(testing.allocator, ts.port); + defer client.deinit(); + + try client.close(.{ .code = code, .reason = reason }); + try helpers.expectCloseWithCodeAndReason(&client, code, reason); +} diff --git a/src/rpc/webzockets/e2e_tests/server/echo_tests.zig b/src/rpc/webzockets/e2e_tests/server/echo_tests.zig new file mode 100644 index 0000000000..b70549d7d1 --- /dev/null +++ b/src/rpc/webzockets/e2e_tests/server/echo_tests.zig @@ -0,0 +1,263 @@ +const std = @import("std"); + +const testing = std.testing; +const servers = @import("../support/test_servers.zig"); +const clients = @import("../support/test_clients.zig"); +const FdLeakDetector = @import("../support/fd_leak.zig").FdLeakDetector; + +test "e2e: text echo" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + var handler: clients.EchoTestHandler = .{ + .send_kind = .text, + .send_data = "hello", + .allocator = testing.allocator, + }; + defer handler.deinit(); + + var env: clients.TestEnv = undefined; + try env.start(); + defer env.deinit(); + + var conn: clients.TestEchoClient.Conn = undefined; + var client = env.initClient(clients.TestEchoClient, &handler, &conn, .{ + .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, ts.port), + }); + defer if (handler.open_called) conn.deinit(); + + try client.connect(); + try env.loop.run(.until_done); + + try testing.expect(handler.open_called); + const received_type = handler.received_type orelse return error.NoData; + const received_data = handler.received_data orelse return error.NoData; + try testing.expectEqual(.text, received_type); + try testing.expectEqualSlices(u8, "hello", received_data); +} + +test "e2e: binary echo" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + var handler: clients.EchoTestHandler = .{ + .send_kind = .binary, + .send_data = &[_]u8{ 0x01, 0x02, 0x03, 0xFF }, + .allocator = testing.allocator, + }; + defer handler.deinit(); + + var env: clients.TestEnv = undefined; + try env.start(); + defer env.deinit(); + + var conn: clients.TestEchoClient.Conn = undefined; + var client = env.initClient(clients.TestEchoClient, &handler, &conn, .{ + .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, ts.port), + }); + defer if 
(handler.open_called) conn.deinit(); + + try client.connect(); + try env.loop.run(.until_done); + + try testing.expect(handler.open_called); + const received_type = handler.received_type orelse return error.NoData; + const received_data = handler.received_data orelse return error.NoData; + try testing.expectEqual(.binary, received_type); + try testing.expectEqualSlices(u8, &[_]u8{ 0x01, 0x02, 0x03, 0xFF }, received_data); +} + +test "e2e: ping/pong" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + var handler: clients.EchoTestHandler = .{ + .send_kind = .ping, + .send_data = "ping", + .allocator = testing.allocator, + }; + defer handler.deinit(); + + var env: clients.TestEnv = undefined; + try env.start(); + defer env.deinit(); + + var conn: clients.TestEchoClient.Conn = undefined; + var client = env.initClient(clients.TestEchoClient, &handler, &conn, .{ + .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, ts.port), + }); + defer if (handler.open_called) conn.deinit(); + + try client.connect(); + try env.loop.run(.until_done); + + try testing.expect(handler.open_called); + const received_type = handler.received_type orelse return error.NoData; + const received_data = handler.received_data orelse return error.NoData; + try testing.expectEqual(.pong, received_type); + try testing.expectEqualSlices(u8, "ping", received_data); +} + +test "e2e: close handshake" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + var handler: clients.CloseOnOpenHandler = .{}; + + var env: clients.TestEnv = undefined; + try env.start(); + defer env.deinit(); + + var conn: clients.TestCloseClient.Conn = undefined; + var client = env.initClient(clients.TestCloseClient, &handler, &conn, .{ + .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, ts.port), + }); + defer if (handler.open_called) conn.deinit(); + + try client.connect(); + try env.loop.run(.until_done); + + try testing.expect(handler.open_called); +} + +test "e2e: multiple messages in sequence" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + const messages = [_]clients.SequenceHandler.MsgSpec{ + .{ .data = "one" }, + .{ .data = "two" }, + .{ .data = "three" }, + }; + var handler: clients.SequenceHandler = .{ + .messages = &messages, + .results = std.ArrayList(clients.SequenceHandler.RecvResult).init(testing.allocator), + .allocator = testing.allocator, + }; + defer handler.deinit(); + + var env: clients.TestEnv = undefined; + try env.start(); + defer env.deinit(); + + var conn: clients.TestSequenceClient.Conn = undefined; + var client = env.initClient(clients.TestSequenceClient, &handler, &conn, .{ + .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, ts.port), + }); + defer if (handler.open_called) conn.deinit(); + + try client.connect(); + try env.loop.run(.until_done); + + try testing.expect(handler.open_called); + try testing.expectEqual(@as(usize, 3), handler.results.items.len); + try testing.expectEqualSlices(u8, "one", handler.results.items[0].data); + try testing.expectEqualSlices(u8, "two", handler.results.items[1].data); + try testing.expectEqualSlices(u8, "three", handler.results.items[2].data); +} + +test "e2e: large message (>125 bytes)" { + const fd_check = FdLeakDetector.baseline(); + 
defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + const large_msg = try testing.allocator.alloc(u8, 1000); + defer testing.allocator.free(large_msg); + @memset(large_msg, 'A'); + + var handler: clients.EchoTestHandler = .{ + .send_kind = .binary, + .send_data = large_msg, + .allocator = testing.allocator, + }; + defer handler.deinit(); + + var env: clients.TestEnv = undefined; + try env.start(); + defer env.deinit(); + + var conn: clients.TestEchoClient.Conn = undefined; + var client = env.initClient(clients.TestEchoClient, &handler, &conn, .{ + .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, ts.port), + }); + defer if (handler.open_called) conn.deinit(); + + try client.connect(); + try env.loop.run(.until_done); + + try testing.expect(handler.open_called); + const received_type = handler.received_type orelse return error.NoData; + const received = handler.received_data orelse return error.NoData; + try testing.expectEqual(.binary, received_type); + try testing.expectEqual(@as(usize, 1000), received.len); + for (received) |byte| { + try testing.expectEqual(@as(u8, 'A'), byte); + } +} + +test "e2e: multiple concurrent connections" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + var handler1: clients.EchoTestHandler = .{ + .send_kind = .text, + .send_data = "from_client_1", + .allocator = testing.allocator, + }; + defer handler1.deinit(); + + var handler2: clients.EchoTestHandler = .{ + .send_kind = .text, + .send_data = "from_client_2", + .allocator = testing.allocator, + }; + defer handler2.deinit(); + + var env: clients.TestEnv = undefined; + try env.start(); + defer env.deinit(); + + var conn1: clients.TestEchoClient.Conn = undefined; + var client1 = env.initClient(clients.TestEchoClient, &handler1, &conn1, .{ + .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, ts.port), + }); + defer if (handler1.open_called) conn1.deinit(); + + var conn2: clients.TestEchoClient.Conn = undefined; + var client2 = env.initClient(clients.TestEchoClient, &handler2, &conn2, .{ + .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, ts.port), + }); + defer if (handler2.open_called) conn2.deinit(); + + try client1.connect(); + try client2.connect(); + try env.loop.run(.until_done); + + try testing.expect(handler1.open_called); + const received1 = handler1.received_data orelse return error.NoData; + try testing.expectEqualSlices(u8, "from_client_1", received1); + + try testing.expect(handler2.open_called); + const received2 = handler2.received_data orelse return error.NoData; + try testing.expectEqualSlices(u8, "from_client_2", received2); +} diff --git a/src/rpc/webzockets/e2e_tests/server/fragmentation_tests.zig b/src/rpc/webzockets/e2e_tests/server/fragmentation_tests.zig new file mode 100644 index 0000000000..cb18b9aa43 --- /dev/null +++ b/src/rpc/webzockets/e2e_tests/server/fragmentation_tests.zig @@ -0,0 +1,287 @@ +const std = @import("std"); +const ws = @import("webzockets_lib"); +const testing = std.testing; + +const servers = @import("../support/test_servers.zig"); +const RawClient = @import("../support/raw_client.zig").RawClient; +const FdLeakDetector = @import("../support/fd_leak.zig").FdLeakDetector; + +const wait_ms: u64 = 2_000; + +test "e2e fragment: text message in 2 fragments" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try 
servers.startTestServer(testing.allocator); + defer ts.stop(); + + var client = try RawClient.connect(testing.allocator, ts.port); + defer client.deinit(); + + // Fragment 1: text FIN=0 "Hel" + var frag1 = "Hel".*; + try client.writeFrameEx(TEXT, &frag1, .{ .fin = false }); + + // Fragment 2: continuation FIN=1 "lo" + var frag2 = "lo".*; + try client.writeFrameEx(CONTINUATION, &frag2, .{}); + + const response = try client.waitForMessageType(.text, wait_ms); + defer client.done(response); + try testing.expectEqualSlices(u8, "Hello", response.data); +} + +test "e2e fragment: text message in 6 fragments" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + var client = try RawClient.connect(testing.allocator, ts.port); + defer client.deinit(); + + const parts = [_][]const u8{ "This ", "is ", "a ", "frag", "mented ", "message" }; + const expected = "This is a fragmented message"; + + // First fragment: text FIN=0 + var buf0: [5]u8 = undefined; + @memcpy(&buf0, parts[0]); + try client.writeFrameEx(TEXT, &buf0, .{ .fin = false }); + + // Middle fragments: continuation FIN=0 + inline for (1..5) |i| { + var buf: [parts[i].len]u8 = undefined; + @memcpy(&buf, parts[i]); + try client.writeFrameEx(CONTINUATION, &buf, .{ .fin = false }); + } + + // Final fragment: continuation FIN=1 + var buf_last: [7]u8 = undefined; + @memcpy(&buf_last, parts[5]); + try client.writeFrameEx(CONTINUATION, &buf_last, .{}); + + const response = try client.waitForMessageType(.text, wait_ms); + defer client.done(response); + try testing.expectEqualSlices(u8, expected, response.data); +} + +test "e2e fragment: binary message in 2 fragments" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + var client = try RawClient.connect(testing.allocator, ts.port); + defer client.deinit(); + + // Fragment 1: binary FIN=0 + var frag1 = [_]u8{ 0x01, 0x02, 0x03 }; + try client.writeFrameEx(BINARY, &frag1, .{ .fin = false }); + + // Fragment 2: continuation FIN=1 + var frag2 = [_]u8{ 0x04, 0x05 }; + try client.writeFrameEx(CONTINUATION, &frag2, .{}); + + const response = try client.waitForMessageType(.binary, wait_ms); + defer client.done(response); + try testing.expectEqualSlices(u8, &[_]u8{ 0x01, 0x02, 0x03, 0x04, 0x05 }, response.data); +} + +test "e2e fragment: ping interleaved during fragmentation" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + var client = try RawClient.connect(testing.allocator, ts.port); + defer client.deinit(); + + // Fragment 1: text FIN=0 "Hel" + var frag1 = "Hel".*; + try client.writeFrameEx(TEXT, &frag1, .{ .fin = false }); + + // Interleaved ping + var ping_payload = "ping".*; + try client.writeFrameEx(PING, &ping_payload, .{}); + + // Fragment 2: continuation FIN=1 "lo" + var frag2 = "lo".*; + try client.writeFrameEx(CONTINUATION, &frag2, .{}); + + // Should get the pong first (server dispatches control frames immediately) + const pong = try client.waitForMessageType(.pong, wait_ms); + defer client.done(pong); + try testing.expectEqualSlices(u8, "ping", pong.data); + + // Then the reassembled text message + const response = try client.waitForMessageType(.text, wait_ms); + defer client.done(response); + try testing.expectEqualSlices(u8, "Hello", response.data); +} 
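+
+// For reference, the masked two-fragment "Hello" above looks roughly like
+// this on the wire (RFC 6455 §5.2); a sketch of what writeFrameEx is assumed
+// to emit, with the random 4-byte mask key shown symbolically as k0..k3:
+//
+//   0x01 0x83 k0 k1 k2 k3 ("Hel" XOR mask)   FIN=0 opcode=0x1 text   MASK=1 len=3
+//   0x80 0x82 k0 k1 k2 k3 ("lo"  XOR mask)   FIN=1 opcode=0x0 cont.  MASK=1 len=2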
+ +test "e2e fragment: non-empty first, empty final continuation" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + var client = try RawClient.connect(testing.allocator, ts.port); + defer client.deinit(); + + // Fragment 1: text FIN=0 "data" + var frag1 = "data".*; + try client.writeFrameEx(TEXT, &frag1, .{ .fin = false }); + + // Fragment 2: continuation FIN=1, empty payload + var frag2 = [0]u8{}; + try client.writeFrameEx(CONTINUATION, &frag2, .{}); + + const response = try client.waitForMessageType(.text, wait_ms); + defer client.done(response); + try testing.expectEqualSlices(u8, "data", response.data); +} + +test "e2e fragment: empty first, non-empty final continuation" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + var client = try RawClient.connect(testing.allocator, ts.port); + defer client.deinit(); + + // Fragment 1: text FIN=0, empty payload + var frag1 = [0]u8{}; + try client.writeFrameEx(TEXT, &frag1, .{ .fin = false }); + + // Fragment 2: continuation FIN=1 "data" + var frag2 = "data".*; + try client.writeFrameEx(CONTINUATION, &frag2, .{}); + + const response = try client.waitForMessageType(.text, wait_ms); + defer client.done(response); + try testing.expectEqualSlices(u8, "data", response.data); +} + +test "e2e fragment: all-empty fragments produce empty message" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + var client = try RawClient.connect(testing.allocator, ts.port); + defer client.deinit(); + + // Fragment 1: text FIN=0, empty + var frag1 = [0]u8{}; + try client.writeFrameEx(TEXT, &frag1, .{ .fin = false }); + + // Fragment 2: continuation FIN=1, empty + var frag2 = [0]u8{}; + try client.writeFrameEx(CONTINUATION, &frag2, .{}); + + const response = try client.waitForMessageType(.text, wait_ms); + defer client.done(response); + try testing.expectEqual(@as(usize, 0), response.data.len); +} + +test "e2e fragment: 8KB message across 4 fragments" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + // 8KB response needs a larger read buffer than the default 4096 + var client = try RawClient.connectEx( + testing.allocator, + ts.port, + .{ .read_buf_size = 16 * 1024 }, + ); + defer client.deinit(); + + // 8192 bytes total, split into 4 × 2048-byte fragments + const fragment_size = 2048; + const num_fragments = 4; + const total_size = fragment_size * num_fragments; + + // Build expected payload: each fragment filled with its index byte + var expected: [total_size]u8 = undefined; + for (0..num_fragments) |i| { + @memset(expected[i * fragment_size ..][0..fragment_size], @as(u8, @truncate(i + 'A'))); + } + + // Send fragments + for (0..num_fragments) |i| { + var buf: [fragment_size]u8 = undefined; + @memcpy(&buf, expected[i * fragment_size ..][0..fragment_size]); + + const opcode: u4 = if (i == 0) TEXT else CONTINUATION; + const is_last = (i == num_fragments - 1); + try client.writeFrameEx(opcode, &buf, .{ .fin = is_last }); + } + + // Note: assumes the echo server responds with a single unfragmented frame. + // Fragmented responses would need multiple waits. 
+ const response = try client.waitForMessageType(.text, wait_ms); + defer client.done(response); + try testing.expectEqual(@as(usize, total_size), response.data.len); + try testing.expectEqualSlices(u8, &expected, response.data); +} + +test "e2e fragment: 256KB message across 64 fragments" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + // 256KB response needs a much larger read buffer + var client = try RawClient.connectEx( + testing.allocator, + ts.port, + .{ .read_buf_size = 512 * 1024 }, + ); + defer client.deinit(); + + // 262144 bytes total, split into 64 × 4096-byte fragments + const fragment_size = 4096; + const num_fragments = 64; + const total_size = fragment_size * num_fragments; + + // Build expected payload: repeating pattern (heap-allocated to avoid 256KB on the stack) + const expected = try testing.allocator.alloc(u8, total_size); + defer testing.allocator.free(expected); + for (0..total_size) |i| { + expected[i] = @as(u8, @truncate(i)); + } + + // Send fragments + for (0..num_fragments) |i| { + var buf: [fragment_size]u8 = undefined; + @memcpy(&buf, expected[i * fragment_size ..][0..fragment_size]); + + const opcode: u4 = if (i == 0) BINARY else CONTINUATION; + const is_last = (i == num_fragments - 1); + try client.writeFrameEx(opcode, &buf, .{ .fin = is_last }); + } + + // Note: assumes the echo server responds with a single unfragmented frame. + // Fragmented responses would need multiple waits. + const response = try client.waitForMessageType(.binary, wait_ms); + defer client.done(response); + try testing.expectEqual(@as(usize, total_size), response.data.len); + try testing.expectEqualSlices(u8, expected, response.data); +} + +const Opcode = ws.Opcode; +const TEXT: u4 = @intFromEnum(Opcode.text); +const BINARY: u4 = @intFromEnum(Opcode.binary); +const CONTINUATION: u4 = @intFromEnum(Opcode.continuation); +const PING: u4 = @intFromEnum(Opcode.ping); diff --git a/src/rpc/webzockets/e2e_tests/server/handshake_tests.zig b/src/rpc/webzockets/e2e_tests/server/handshake_tests.zig new file mode 100644 index 0000000000..255419e52d --- /dev/null +++ b/src/rpc/webzockets/e2e_tests/server/handshake_tests.zig @@ -0,0 +1,477 @@ +const std = @import("std"); +const testing = std.testing; +const ws = @import("webzockets_lib"); + +const servers = @import("../support/test_servers.zig"); +const RawClient = @import("../support/raw_client.zig").RawClient; +const FdLeakDetector = @import("../support/fd_leak.zig").FdLeakDetector; +const verifyServerFunctional = @import("../support/test_helpers.zig").verifyServerFunctional; + +const wait_ms: u64 = 2_000; + +test "e2e handshake: malformed HTTP request (garbage bytes)" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + const stream = try rawConnect(ts.port); + defer stream.close(); + + // Send garbage bytes terminated with \r\n\r\n so the server sees a + // "complete" header block and attempts to parse it as HTTP. 
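+    // Without the terminator the server would keep waiting for more header
+    // bytes, and expectClosed below would hit the 2s read timeout and fail
+    // with error.ConnectionNotClosed instead of seeing a rejection.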
+ try stream.writeAll("\x00\x01\x02\x03GARBAGE\xff\xfe\r\n\r\n"); + + try expectClosed(stream); +} + +test "e2e handshake: missing Upgrade header" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + const stream = try rawConnect(ts.port); + defer stream.close(); + + var buf: [512]u8 = undefined; + const request = buildRequest(&buf, .{ .include_upgrade = false }); + try stream.writeAll(request); + + try expectClosed(stream); +} + +test "e2e handshake: missing Sec-WebSocket-Key" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + const stream = try rawConnect(ts.port); + defer stream.close(); + + var buf: [512]u8 = undefined; + const request = buildRequest(&buf, .{ .include_key = false }); + try stream.writeAll(request); + + try expectClosed(stream); +} + +test "e2e handshake: wrong HTTP method (POST)" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + const stream = try rawConnect(ts.port); + defer stream.close(); + + var buf: [512]u8 = undefined; + const request = buildRequest(&buf, .{ .method = "POST" }); + try stream.writeAll(request); + + try expectClosed(stream); +} + +test "e2e handshake: unsupported WebSocket version" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + const stream = try rawConnect(ts.port); + defer stream.close(); + + var buf: [512]u8 = undefined; + const request = buildRequest(&buf, .{ .ws_version = "12" }); + try stream.writeAll(request); + + try expectClosed(stream); +} + +test "e2e handshake: incremental request (chunked with delays)" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + const stream = try rawConnect(ts.port); + defer stream.close(); + + // Build a valid upgrade request, then send it in small chunks with delays + // to exercise the incremental HeadParser path (multiple reads needed). + var buf: [512]u8 = undefined; + const request = buildRequest(&buf, .{}); + + // Send in 3 chunks: first line, middle headers, final headers + \r\n\r\n + const chunk1_end = std.mem.indexOf(u8, request, "\r\n").? + 2; + const chunk2_end = chunk1_end + ((request.len - chunk1_end) / 2); + + try stream.writeAll(request[0..chunk1_end]); + std.time.sleep(10 * std.time.ns_per_ms); + + try stream.writeAll(request[chunk1_end..chunk2_end]); + std.time.sleep(10 * std.time.ns_per_ms); + + try stream.writeAll(request[chunk2_end..]); + + // If the server successfully parsed the incremental request, it will send + // back a 101 Switching Protocols response. 
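+    // Only the status line is asserted below. For a stricter check: with the
+    // fixed key buildRequest sends ("dGhlIHNhbXBsZSBub25jZQ=="), RFC 6455
+    // §1.3 gives Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=.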
+ var response_buf: [512]u8 = undefined; + var total_read: usize = 0; + while (total_read < response_buf.len) { + const n = stream.read(response_buf[total_read..]) catch break; + if (n == 0) break; + total_read += n; + if (std.mem.indexOf(u8, response_buf[0..total_read], "\r\n\r\n") != null) break; + } + + const response = response_buf[0..total_read]; + try testing.expect(std.mem.startsWith(u8, response, "HTTP/1.1 101 Switching Protocols\r\n")); +} + +test "e2e handshake: byte-by-byte request" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + const stream = try rawConnect(ts.port); + defer stream.close(); + + // Send a valid upgrade request one byte at a time to stress-test + // the incremental HeadParser with minimal feed sizes. + var buf: [512]u8 = undefined; + const request = buildRequest(&buf, .{}); + + for (request) |byte| { + try stream.writeAll(&.{byte}); + std.time.sleep(1 * std.time.ns_per_ms); + } + + // Read the 101 response. + var response_buf: [512]u8 = undefined; + var total_read: usize = 0; + while (total_read < response_buf.len) { + const n = stream.read(response_buf[total_read..]) catch break; + if (n == 0) break; + total_read += n; + if (std.mem.indexOf(u8, response_buf[0..total_read], "\r\n\r\n") != null) break; + } + + const response = response_buf[0..total_read]; + try testing.expect(std.mem.startsWith(u8, response, "HTTP/1.1 101 Switching Protocols\r\n")); +} + +test "e2e handshake: bare LF headers (no \\r\\n) doesn't crash server" { + // HeadParser accepts \n\n as end-of-headers. Send a request with \r\n on + // the request line but bare \n on headers, terminated by \n\n. The server + // must reject it gracefully (close the connection) without crashing. + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + const stream = try rawConnect(ts.port); + defer stream.close(); + + // Request line ends with \r\n (so parseRequest finds the request line), + // but all headers use bare \n. The \n\n terminates the head for HeadParser. + try stream.writeAll( + "GET / HTTP/1.1\r\n" ++ + "Host: 127.0.0.1\n" ++ + "Upgrade: websocket\n" ++ + "Connection: Upgrade\n" ++ + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\n" ++ + "Sec-WebSocket-Version: 13\n" ++ + "\n", + ); + + try expectClosed(stream); + + // Verify the server is still functional after handling the malformed request. + std.time.sleep(50 * std.time.ns_per_ms); + try verifyServerFunctional(ts.port); +} + +test "e2e handshake: fully bare LF request (no \\r\\n at all) doesn't crash server" { + // Everything uses bare \n — no \r\n anywhere. HeadParser sees \n\n and + // reports finished, but the headers won't be found without \r\n line endings. + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + const stream = try rawConnect(ts.port); + defer stream.close(); + + try stream.writeAll( + "GET / HTTP/1.1\n" ++ + "Host: 127.0.0.1\n" ++ + "Upgrade: websocket\n" ++ + "Connection: Upgrade\n" ++ + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\n" ++ + "Sec-WebSocket-Version: 13\n" ++ + "\n", + ); + + try expectClosed(stream); + + // Verify the server is still functional after handling the malformed request. 
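+    // (The 50ms sleep gives the server's loop time to finish tearing down
+    // the rejected connection before the health check connects.)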
+ std.time.sleep(50 * std.time.ns_per_ms); + try verifyServerFunctional(ts.port); +} + +test "e2e handshake: partial request then disconnect" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + // Send a partial HTTP request and immediately close the connection + { + const stream = try rawConnect(ts.port); + try stream.writeAll("GET / HTTP/1.1\r\nHost: 127.0.0.1\r\n"); + stream.close(); + } + + // Give the server time to clean up the aborted handshake + std.time.sleep(50 * std.time.ns_per_ms); + + // Verify the server is still functional after the aborted handshake + try verifyServerFunctional(ts.port); +} + +test "e2e handshake: headers exceeding read buffer size" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + const stream = try rawConnect(ts.port); + defer stream.close(); + + // Build a request whose headers exceed the server's read buffer (4096 bytes). + // The \r\n\r\n terminator won't appear within the first 4096 bytes, so the + // server fills its buffer without finding a complete header block and + // transitions to the failed state. + var big_buf: [8192]u8 = undefined; + var fbs = std.io.fixedBufferStream(&big_buf); + const w = fbs.writer(); + w.writeAll("GET / HTTP/1.1\r\n") catch unreachable; + w.writeAll("Host: 127.0.0.1\r\n") catch unreachable; + w.writeAll("X-Padding: ") catch unreachable; + // Fill with padding until we've written well past the server buffer size + const target = servers.default_read_buf_size + 500; + while (fbs.pos < target) { + w.writeByte('A') catch unreachable; + } + w.writeAll("\r\n") catch unreachable; + w.writeAll("Upgrade: websocket\r\n") catch unreachable; + w.writeAll("Connection: Upgrade\r\n") catch unreachable; + w.writeAll("Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n") catch unreachable; + w.writeAll("Sec-WebSocket-Version: 13\r\n") catch unreachable; + w.writeAll("\r\n") catch unreachable; + + try stream.writeAll(big_buf[0..fbs.pos]); + + try expectClosed(stream); +} + +/// Connect raw TCP to the test server with a 2s read timeout. +fn rawConnect(port: u16) !std.net.Stream { + const address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, port); + const stream = try std.net.tcpConnectToAddress(address); + errdefer stream.close(); + const timeout = std.posix.timeval{ .sec = 2, .usec = 0 }; + try std.posix.setsockopt( + stream.handle, + std.posix.SOL.SOCKET, + std.posix.SO.RCVTIMEO, + std.mem.asBytes(&timeout), + ); + return stream; +} + +/// Read from `stream` until the server closes the connection. Returns success +/// if the connection is closed (read returns 0 or a connection error). Returns +/// `error.ConnectionNotClosed` if a 2s read timeout fires first. 
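+/// (The 2s limit is the SO_RCVTIMEO that rawConnect sets; a timed-out read
+/// surfaces as error.WouldBlock.)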
+fn expectClosed(stream: std.net.Stream) !void { + var buf: [4096]u8 = undefined; + while (true) { + const n = stream.read(&buf) catch |err| switch (err) { + error.WouldBlock => return error.ConnectionNotClosed, + else => return, // Connection broken = closed + }; + if (n == 0) return; // Clean FIN + } +} + +test "e2e handshake: onHandshakeFailed fires on connection pool exhaustion" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + var ctx = HandshakeFailContext{}; + + const ts = try HandshakeFailTrackingTestServer.start(testing.allocator, .{ + .address = servers.localhost, + .max_connections = 1, + .handler_context = &ctx, + }); + defer ts.stop(); + + // First connection: fills the single connection slot. + var client1 = try RawClient.connect(testing.allocator, ts.port); + defer client1.deinit(); + + // Second connection: handler.init() succeeds but transitionToConnection + // fails (pool exhausted), so onHandshakeFailed should fire. + const stream2 = try rawConnect(ts.port); + defer stream2.close(); + + // Send a valid upgrade request on the second connection. + var buf: [512]u8 = undefined; + const request = buildRequest(&buf, .{}); + try stream2.writeAll(request); + + // The server should close this connection since the pool is exhausted. + try expectClosed(stream2); + + // Wait (bounded) for the server thread to run the callback. + try ctx.called.timedWait(2 * std.time.ns_per_s); + + // Verify the first connection is still functional (server didn't crash). + try verifyEchoOnClient(&client1); +} + +test "e2e handshake: handler without onHandshakeFailed still works" { + // The default EchoHandler does NOT declare onHandshakeFailed. + // Verify that pool exhaustion doesn't crash — the handler is silently + // dropped (pre-existing behavior for handlers without the callback). + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try LimitedEchoTestServer.start(testing.allocator, .{ + .address = servers.localhost, + .handler_context = {}, + .max_connections = 1, + }); + defer ts.stop(); + + // First connection fills the pool. + var client1 = try RawClient.connect(testing.allocator, ts.port); + defer client1.deinit(); + + // Second connection: init succeeds, pool exhausted, no onHandshakeFailed + // declared — should just close without crashing. + const stream2 = try rawConnect(ts.port); + defer stream2.close(); + + var buf: [512]u8 = undefined; + const request = buildRequest(&buf, .{}); + try stream2.writeAll(request); + + try expectClosed(stream2); + + // Verify the first connection is still functional (server didn't crash). + try verifyEchoOnClient(&client1); +} + +const HandshakeFailContext = struct { + called: std.Thread.ResetEvent = .{}, +}; + +/// Echo handler that tracks `onHandshakeFailed` calls via a caller-provided +/// `HandshakeFailContext` passed through `Config.handler_context`. +const HandshakeFailTrackingHandler = struct { + pub const Context = HandshakeFailContext; + + ctx: *Context, + + pub fn init(_: ws.http.Request, context: *Context) !HandshakeFailTrackingHandler { + return .{ .ctx = context }; + } + + pub fn onHandshakeFailed(self: *HandshakeFailTrackingHandler) void { + // This should only ever be called once per handshake. 
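+        // The pool-exhaustion test above blocks in ctx.called.timedWait()
+        // until this set() runs on the server's loop thread; ResetEvent is
+        // the cross-thread signal here.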
+ if (self.ctx.called.isSet()) @panic("onHandshakeFailed called more than once"); + self.ctx.called.set(); + } + + pub fn onMessage(_: *HandshakeFailTrackingHandler, conn: anytype, message: ws.Message) void { + switch (message.type) { + .text => conn.sendText(@constCast(message.data)) catch return, + .binary => conn.sendBinary(@constCast(message.data)) catch return, + else => {}, + } + } + + pub fn onWriteComplete(_: *HandshakeFailTrackingHandler, _: anytype) void {} + pub fn onClose(_: *HandshakeFailTrackingHandler, _: anytype) void {} +}; + +const HandshakeFailTrackingServer = ws.Server( + HandshakeFailTrackingHandler, + servers.default_read_buf_size, +); +const HandshakeFailTrackingTestServer = servers.ServerRunner(HandshakeFailTrackingServer); + +const LimitedEchoServer = ws.Server( + servers.EchoHandler, + servers.default_read_buf_size, +); +const LimitedEchoTestServer = servers.ServerRunner(LimitedEchoServer); + +/// Send a text message through an existing RawClient and verify the echo. +/// Used to confirm the server is still functional after error scenarios +/// where the client's connection slot is already occupied. +fn verifyEchoOnClient(client: *RawClient) !void { + var msg = "echo check".*; + try client.write(&msg); + + const response = try client.waitForMessageType(.text, wait_ms); + defer client.done(response); + try testing.expectEqualSlices(u8, "echo check", response.data); +} + +/// Options for building a crafted HTTP upgrade request. +const RequestOpts = struct { + method: []const u8 = "GET", + path: []const u8 = "/", + http_version: []const u8 = "HTTP/1.1", + include_host: bool = true, + include_upgrade: bool = true, + include_connection: bool = true, + include_key: bool = true, + ws_version: []const u8 = "13", + include_ws_version: bool = true, +}; + +/// Build an HTTP upgrade request into `buf` with selective header omission. 
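+/// The include_* flags and the overridable method/version fields let each
+/// test drop or alter exactly one element of an otherwise valid handshake.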
+fn buildRequest(buf: []u8, opts: RequestOpts) []const u8 { + var fbs = std.io.fixedBufferStream(buf); + const w = fbs.writer(); + w.print("{s} {s} {s}\r\n", .{ opts.method, opts.path, opts.http_version }) catch unreachable; + if (opts.include_host) w.writeAll("Host: 127.0.0.1\r\n") catch unreachable; + if (opts.include_upgrade) w.writeAll("Upgrade: websocket\r\n") catch unreachable; + if (opts.include_connection) w.writeAll("Connection: Upgrade\r\n") catch unreachable; + if (opts.include_key) { + w.writeAll("Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n") catch unreachable; + } + if (opts.include_ws_version) { + w.print("Sec-WebSocket-Version: {s}\r\n", .{opts.ws_version}) catch unreachable; + } + w.writeAll("\r\n") catch unreachable; + return buf[0..fbs.pos]; +} diff --git a/src/rpc/webzockets/e2e_tests/server/pause_resume_tests.zig b/src/rpc/webzockets/e2e_tests/server/pause_resume_tests.zig new file mode 100644 index 0000000000..f401949170 --- /dev/null +++ b/src/rpc/webzockets/e2e_tests/server/pause_resume_tests.zig @@ -0,0 +1,183 @@ +const std = @import("std"); +const testing = std.testing; + +const servers = @import("../support/test_servers.zig"); +const server_handlers = @import("../support/server_handlers.zig"); +const RawClient = @import("../support/raw_client.zig").RawClient; +const FdLeakDetector = @import("../support/fd_leak.zig").FdLeakDetector; +const helpers = @import("../support/test_helpers.zig"); + +const wait_ms: u64 = 4_000; + +fn expectText(client: *RawClient, expected: []const u8) !void { + const response = try client.waitForMessageType(.text, wait_ms); + defer client.done(response); + try testing.expectEqualSlices(u8, expected, response.data); +} + +test "e2e pause/resume server: sequential processing of buffered burst" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + // Each masked text frame: 6-byte header (2 + 4 mask) + payload. + // "a","b","c","d" = 4 × (6 + 1) = 28 bytes. + var ctx: server_handlers.PauseUntilBufferedEchoHandler.Context = .{ + .resume_threshold = 28, + .expected_messages = 4, + }; + const ts = try servers.startPauseUntilBufferedEchoServer(testing.allocator, &ctx); + defer ts.stop(); + + var client = try RawClient.connect(testing.allocator, ts.port); + defer client.deinit(); + + var m1 = "a".*; + var m2 = "b".*; + var m3 = "c".*; + var m4 = "d".*; + + try client.writeFrame(.text, &m1); + try client.writeFrame(.text, &m2); + try client.writeFrame(.text, &m3); + try client.writeFrame(.text, &m4); + + try expectText(&client, "a"); + try expectText(&client, "b"); + try expectText(&client, "c"); + try expectText(&client, "d"); + + try helpers.expectCloseWithCode(&client, 1000); +} + +test "e2e pause/resume server: pause mid-stream stops dispatch then delivers on resume" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + // 4 × (6 + 1) = 28 bytes threshold. + var mid_ctx: server_handlers.PauseMidStreamEchoHandler.Context = .{ + .resume_threshold = 28, + .expected_messages = 4, + }; + const ts = try servers.startPauseMidStreamEchoServer(testing.allocator, &mid_ctx); + defer ts.stop(); + + var client = try RawClient.connect(testing.allocator, ts.port); + defer client.deinit(); + + var m1 = "a".*; + var m2 = "b".*; + var m3 = "c".*; + var m4 = "d".*; + + try client.writeFrame(.text, &m1); + try client.writeFrame(.text, &m2); + try client.writeFrame(.text, &m3); + try client.writeFrame(.text, &m4); + + // Each echo arrives only after the handler resumes from onWriteComplete. 
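+    // Arrival order is preserved: pausing only defers dispatch of the
+    // already-buffered frames, which is what the a/b/c/d sequence below
+    // asserts.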
+ try expectText(&client, "a"); + try expectText(&client, "b"); + try expectText(&client, "c"); + try expectText(&client, "d"); + + try helpers.expectCloseWithCode(&client, 1000); +} + +test "e2e pause/resume server: close frame while server is paused" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + // "hello" text frame: 6 + 5 = 11 bytes. + // Close frame (code 1000, no reason): 6 + 2 = 8 bytes. + // Total threshold: 19 bytes. + var ctx: server_handlers.PauseUntilBufferedEchoHandler.Context = .{ + .resume_threshold = 19, + .expected_messages = 0, + }; + const ts = try servers.startPauseUntilBufferedEchoServer(testing.allocator, &ctx); + defer ts.stop(); + + var client = try RawClient.connect(testing.allocator, ts.port); + defer client.deinit(); + + var msg = "hello".*; + try client.writeFrame(.text, &msg); + try client.close(.{ .code = 1000 }); + + // Echo should still be delivered before close handshake completes. + try expectText(&client, "hello"); + try helpers.expectCloseWithCode(&client, 1000); +} + +test "e2e pause/resume server: no re-entrant onMessage dispatch" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + // "a"(7) + "b"(7) + "c"(7) + "done"(10) = 31 bytes. + var ctx: server_handlers.ReentrancyDetectHandler.Context = .{ + .resume_threshold = 31, + }; + const ts = try servers.startReentrancyDetectServer(testing.allocator, &ctx); + defer ts.stop(); + + var client = try RawClient.connect(testing.allocator, ts.port); + defer client.deinit(); + + var m1 = "a".*; + var m2 = "b".*; + var m3 = "c".*; + var done = "done".*; + + try client.writeFrame(.text, &m1); + try client.writeFrame(.text, &m2); + try client.writeFrame(.text, &m3); + try client.writeFrame(.text, &done); + + // Without the re-entrancy guard, pauseReads() + resumeReads() inside + // onMessage would recursively dispatch into onMessage, triggering a + // close with policy_violation (1008). With the guard, messages dispatch + // sequentially and "done" triggers a normal close (1000). + try helpers.expectCloseWithCode(&client, 1000); +} + +test "e2e pause/resume server: buffer fills while paused (small read buffer)" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + // 256-byte read buffer. Each masked frame with 20-byte payload = 26 bytes. + // 12 messages × 26 = 312 bytes total — exceeds the 256-byte buffer. + // Threshold of 256 ensures we only resume once the buffer is completely + // full and reads have been stopped (freeSpace() == 0). 
+ const msg_len = 20; + const msg_count = 12; + var ctx: server_handlers.PauseUntilBufferedEchoHandler.Context = .{ + .resume_threshold = 256, + .expected_messages = msg_count, + }; + const ts = try servers.startPauseUntilBufferedEchoSmallBufServer( + testing.allocator, + &ctx, + ); + defer ts.stop(); + + var client = try RawClient.connect(testing.allocator, ts.port); + defer client.deinit(); + + var expected: [msg_count][msg_len]u8 = undefined; + for (0..msg_count) |i| { + const byte = @as(u8, @truncate('A' + i)); + @memset(&expected[i], byte); + + var payload = expected[i]; + try client.writeFrame(.text, &payload); + } + + for (0..msg_count) |i| { + const response = try client.waitForMessageType(.text, wait_ms); + defer client.done(response); + try testing.expectEqual(@as(usize, msg_len), response.data.len); + try testing.expectEqualSlices(u8, &expected[i], response.data); + } + + try helpers.expectCloseWithCode(&client, 1000); +} diff --git a/src/rpc/webzockets/e2e_tests/server/pool_tests.zig b/src/rpc/webzockets/e2e_tests/server/pool_tests.zig new file mode 100644 index 0000000000..90c618a096 --- /dev/null +++ b/src/rpc/webzockets/e2e_tests/server/pool_tests.zig @@ -0,0 +1,83 @@ +const std = @import("std"); +const ws = @import("webzockets_lib"); + +const testing = std.testing; +const servers = @import("../support/test_servers.zig"); +const clients = @import("../support/test_clients.zig"); +const FdLeakDetector = @import("../support/fd_leak.zig").FdLeakDetector; + +test "e2e: connection pool exhaustion" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try startLimitedTestServer(testing.allocator); + defer ts.stop(); + + // With max_connections=1, exactly one of the two clients should get the + // connection slot and complete its echo exchange. The other is rejected + // at the server's handshake-to-connection transition (pool exhausted). + // There is no ordering guarantee for which client gets the slot. + + var handler1: clients.EchoTestHandler = .{ + .send_kind = .text, + .send_data = "from_client_1", + .allocator = testing.allocator, + }; + defer handler1.deinit(); + + var handler2: clients.EchoTestHandler = .{ + .send_kind = .text, + .send_data = "from_client_2", + .allocator = testing.allocator, + }; + defer handler2.deinit(); + + var env: clients.TestEnv = undefined; + try env.start(); + defer env.deinit(); + + var conn1: clients.TestEchoClient.Conn = undefined; + var client1 = env.initClient(clients.TestEchoClient, &handler1, &conn1, .{ + .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, ts.port), + }); + defer if (handler1.open_called) conn1.deinit(); + + var conn2: clients.TestEchoClient.Conn = undefined; + var client2 = env.initClient(clients.TestEchoClient, &handler2, &conn2, .{ + .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, ts.port), + }); + defer if (handler2.open_called) conn2.deinit(); + + try client1.connect(); + try client2.connect(); + try env.loop.run(.until_done); + + const h1_ok = handler1.received_data != null; + const h2_ok = handler2.received_data != null; + + // At least one client must have completed the echo exchange. + try testing.expect(h1_ok or h2_ok); + + if (h1_ok) { + try testing.expectEqualSlices(u8, "from_client_1", handler1.received_data.?); + } + if (h2_ok) { + try testing.expectEqualSlices(u8, "from_client_2", handler2.received_data.?); + } +} + +/// WebSocket server type used for connection-pool exhaustion tests. 
+const LimitedServer = ws.Server( + servers.EchoHandler, + servers.default_read_buf_size, +); +const LimitedTestServer = servers.ServerRunner(LimitedServer); + +fn startLimitedTestServer(allocator: std.mem.Allocator) !*LimitedTestServer { + const address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, 0); + return try LimitedTestServer.start(allocator, .{ + .address = address, + .handler_context = {}, + .max_connections = 1, + }); +} diff --git a/src/rpc/webzockets/e2e_tests/server/protocol_error_tests.zig b/src/rpc/webzockets/e2e_tests/server/protocol_error_tests.zig new file mode 100644 index 0000000000..d5c0cfdec6 --- /dev/null +++ b/src/rpc/webzockets/e2e_tests/server/protocol_error_tests.zig @@ -0,0 +1,166 @@ +const std = @import("std"); +const ws = @import("webzockets_lib"); + +const testing = std.testing; +const servers = @import("../support/test_servers.zig"); +const RawClient = @import("../support/raw_client.zig").RawClient; +const FdLeakDetector = @import("../support/fd_leak.zig").FdLeakDetector; +const expectCloseWithCode = @import("../support/test_helpers.zig").expectCloseWithCode; + +test "e2e protocol error: unmasked client frame → 1002" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + var client = try RawClient.connect(testing.allocator, ts.port); + defer client.deinit(); + + // Send a text frame without the mask bit set — violates RFC 6455 §5.1 + var payload = "hello".*; + try client.writeFrameEx(@intFromEnum(ws.Opcode.text), &payload, .{ .mask = false }); + + try expectCloseWithCode(&client, 1002); +} + +test "e2e protocol error: reserved opcode 0x3 → 1002" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + var client = try RawClient.connect(testing.allocator, ts.port); + defer client.deinit(); + + // Opcode 0x3 is reserved for future non-control frames — RFC 6455 §5.2 + var payload = "test".*; + try client.writeFrameEx(0x3, &payload, .{}); + + try expectCloseWithCode(&client, 1002); +} + +test "e2e protocol error: RSV1 set without extension → 1002" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + var client = try RawClient.connect(testing.allocator, ts.port); + defer client.deinit(); + + // RSV1 must be 0 unless an extension is negotiated — RFC 6455 §5.2 + var payload = "test".*; + try client.writeFrameEx(@intFromEnum(ws.Opcode.text), &payload, .{ .rsv1 = true }); + + try expectCloseWithCode(&client, 1002); +} + +test "e2e protocol error: fragmented ping (FIN=0) → 1002" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + var client = try RawClient.connect(testing.allocator, ts.port); + defer client.deinit(); + + // Control frames must not be fragmented — RFC 6455 §5.5 + var payload = "ping".*; + try client.writeFrameEx(@intFromEnum(ws.Opcode.ping), &payload, .{ .fin = false }); + + try expectCloseWithCode(&client, 1002); +} + +test "e2e protocol error: ping with 126-byte payload → 1002" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + var client = try RawClient.connect(testing.allocator, 
ts.port); + defer client.deinit(); + + // Control frame payloads must be <= 125 bytes — RFC 6455 §5.5 + var payload: [126]u8 = undefined; + @memset(&payload, 'X'); + try client.writeFrameEx(@intFromEnum(ws.Opcode.ping), &payload, .{}); + + try expectCloseWithCode(&client, 1002); +} + +test "e2e protocol error: unexpected continuation frame → 1002" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + var client = try RawClient.connect(testing.allocator, ts.port); + defer client.deinit(); + + // Continuation without a preceding non-FIN data frame — RFC 6455 §5.4 + var payload = "data".*; + try client.writeFrameEx(@intFromEnum(ws.Opcode.continuation), &payload, .{}); + + try expectCloseWithCode(&client, 1002); +} + +test "e2e protocol error: new data frame during fragmentation → 1002" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + var client = try RawClient.connect(testing.allocator, ts.port); + defer client.deinit(); + + // Start a fragmented message (text, FIN=0) + var frag1 = "hello".*; + try client.writeFrameEx(@intFromEnum(ws.Opcode.text), &frag1, .{ .fin = false }); + + // Send another text frame instead of continuation — RFC 6455 §5.4 + var frag2 = "world".*; + try client.writeFrameEx(@intFromEnum(ws.Opcode.text), &frag2, .{}); + + try expectCloseWithCode(&client, 1002); +} + +test "e2e protocol error: message exceeding max_message_size → 1009" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try startSmallMaxTestServer(testing.allocator); + defer ts.stop(); + + var client = try RawClient.connect(testing.allocator, ts.port); + defer client.deinit(); + + // Send a message larger than max_message_size (128 bytes) + var payload: [256]u8 = undefined; + @memset(&payload, 'A'); + try client.writeFrameEx(@intFromEnum(ws.Opcode.text), &payload, .{}); + + try expectCloseWithCode(&client, 1009); +} + +/// Server type with small max_message_size for testing oversized messages. 
+const SmallMaxServer = ws.Server( + servers.EchoHandler, + servers.default_read_buf_size, +); +const SmallMaxTestServer = servers.ServerRunner(SmallMaxServer); + +fn startSmallMaxTestServer(allocator: std.mem.Allocator) !*SmallMaxTestServer { + const address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, 0); + return try SmallMaxTestServer.start(allocator, .{ + .address = address, + .handler_context = {}, + .max_message_size = 128, + }); +} diff --git a/src/rpc/webzockets/e2e_tests/server/rejection_tests.zig b/src/rpc/webzockets/e2e_tests/server/rejection_tests.zig new file mode 100644 index 0000000000..6796c177ff --- /dev/null +++ b/src/rpc/webzockets/e2e_tests/server/rejection_tests.zig @@ -0,0 +1,105 @@ +const std = @import("std"); +const ws = @import("webzockets_lib"); + +const testing = std.testing; +const servers = @import("../support/test_servers.zig"); +const clients = @import("../support/test_clients.zig"); +const FdLeakDetector = @import("../support/fd_leak.zig").FdLeakDetector; + +test "e2e: handler rejects connection" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try startRejectingTestServer(testing.allocator); + defer ts.stop(); + + var handler: clients.NoOpHandler = .{}; + + var env: clients.TestEnv = undefined; + try env.start(); + defer env.deinit(); + + var conn: clients.TestNoOpClient.Conn = undefined; + var client = env.initClient(clients.TestNoOpClient, &handler, &conn, .{ + .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, ts.port), + .path = "/reject-me", + }); + + try client.connect(); + try env.loop.run(.until_done); + + try testing.expect(!handler.open_called); +} + +test "e2e: handler accepts valid path" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try startRejectingTestServer(testing.allocator); + defer ts.stop(); + + var handler: clients.EchoTestHandler = .{ + .send_kind = .text, + .send_data = "test", + .allocator = testing.allocator, + }; + defer handler.deinit(); + + var env: clients.TestEnv = undefined; + try env.start(); + defer env.deinit(); + + var conn: clients.TestEchoClient.Conn = undefined; + var client = env.initClient(clients.TestEchoClient, &handler, &conn, .{ + .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, ts.port), + .path = "/valid-path", + }); + defer if (handler.open_called) conn.deinit(); + + try client.connect(); + try env.loop.run(.until_done); + + try testing.expect(handler.open_called); + const received_data = handler.received_data orelse return error.NoData; + try testing.expectEqualSlices(u8, "test", received_data); +} + +/// Server-side handler that rejects connections to paths starting with "/reject". 
+const RejectingHandler = struct { + pub const Context = void; + + inner: servers.EchoHandler, + + pub fn init(request: ws.http.Request, _: void) !RejectingHandler { + if (std.mem.startsWith(u8, request.path, "/reject")) { + return error.ConnectionRejected; + } + return .{ .inner = try servers.EchoHandler.init(request, {}) }; + } + + pub fn onMessage(self: *RejectingHandler, conn: anytype, message: ws.Message) void { + self.inner.onMessage(conn, message); + } + + pub fn onWriteComplete(self: *RejectingHandler, conn: anytype) void { + self.inner.onWriteComplete(conn); + } + + pub fn onClose(self: *RejectingHandler, conn: anytype) void { + self.inner.onClose(conn); + } +}; + +const RejectingServer = ws.Server( + RejectingHandler, + servers.default_read_buf_size, +); +const RejectingTestServer = servers.ServerRunner(RejectingServer); + +fn startRejectingTestServer(allocator: std.mem.Allocator) !*RejectingTestServer { + const address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, 0); + return try RejectingTestServer.start(allocator, .{ + .address = address, + .handler_context = {}, + }); +} diff --git a/src/rpc/webzockets/e2e_tests/server/stress_tests.zig b/src/rpc/webzockets/e2e_tests/server/stress_tests.zig new file mode 100644 index 0000000000..6b9da7930a --- /dev/null +++ b/src/rpc/webzockets/e2e_tests/server/stress_tests.zig @@ -0,0 +1,483 @@ +const std = @import("std"); +const ws = @import("webzockets_lib"); +const xev = @import("xev"); + +const testing = std.testing; +const servers = @import("../support/test_servers.zig"); +const clients = @import("../support/test_clients.zig"); +const FdLeakDetector = @import("../support/fd_leak.zig").FdLeakDetector; +const verifyServerFunctional = @import("../support/test_helpers.zig").verifyServerFunctional; + +test "e2e/stress: rapid connect/disconnect" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + const num_clients = 20; + + var envs: [num_clients]clients.TestEnv = undefined; + var handlers: [num_clients]clients.CloseOnOpenHandler = undefined; + var conns: [num_clients]clients.TestCloseClient.Conn = undefined; + var client_objs: [num_clients]clients.TestCloseClient = undefined; + + // Start all clients + for (0..num_clients) |i| { + envs[i] = undefined; + try envs[i].start(); + handlers[i] = .{}; + client_objs[i] = envs[i].initClient(clients.TestCloseClient, &handlers[i], &conns[i], .{ + .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, ts.port), + }); + try client_objs[i].connect(); + } + + // Run all event loops + for (0..num_clients) |i| { + try envs[i].loop.run(.until_done); + } + + // Cleanup runs even if assertions fail (avoids additional fd leak detector noise) + defer for (0..num_clients) |i| { + conns[i].deinit(); + envs[i].deinit(); + }; + + for (0..num_clients) |i| { + try testing.expect(handlers[i].open_called); + } + + // Verify server is still healthy + try verifyServerFunctional(ts.port); +} + +test "e2e/stress: many concurrent echo clients" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + const num_loops = 4; + const clients_per_loop = 5; + + const LoopContext = struct { + env: clients.TestEnv, + handlers: [clients_per_loop]clients.SequenceHandler, + conns: [clients_per_loop]clients.TestSequenceClient.Conn, + client_objs: [clients_per_loop]clients.TestSequenceClient, + msg_specs: 
[clients_per_loop][]clients.SequenceHandler.MsgSpec, + }; + + var contexts: [num_loops]LoopContext = undefined; + var threads: [num_loops]std.Thread = undefined; + + for (0..num_loops) |loop_idx| { + contexts[loop_idx].env = undefined; + try contexts[loop_idx].env.start(); + + for (0..clients_per_loop) |client_idx| { + const msg_specs = try testing.allocator.alloc(clients.SequenceHandler.MsgSpec, 2); + msg_specs[0] = .{ .data = try std.fmt.allocPrint( + testing.allocator, + "loop{d}-client{d}-msg0", + .{ loop_idx, client_idx }, + ) }; + msg_specs[1] = .{ .data = try std.fmt.allocPrint( + testing.allocator, + "loop{d}-client{d}-msg1", + .{ loop_idx, client_idx }, + ) }; + contexts[loop_idx].msg_specs[client_idx] = msg_specs; + + contexts[loop_idx].handlers[client_idx] = .{ + .messages = msg_specs, + .results = std.ArrayList( + clients.SequenceHandler.RecvResult, + ).init(testing.allocator), + .allocator = testing.allocator, + }; + + contexts[loop_idx].client_objs[client_idx] = contexts[loop_idx].env.initClient( + clients.TestSequenceClient, + &contexts[loop_idx].handlers[client_idx], + &contexts[loop_idx].conns[client_idx], + .{ .address = std.net.Address.initIp4( + .{ 127, 0, 0, 1 }, + ts.port, + ) }, + ); + try contexts[loop_idx].client_objs[client_idx].connect(); + } + + threads[loop_idx] = try std.Thread.spawn( + .{}, + runLoopOnThread, + .{&contexts[loop_idx].env.loop}, + ); + } + + // Join all threads + for (0..num_loops) |loop_idx| { + threads[loop_idx].join(); + } + + // Cleanup runs even if assertions fail (avoids additional fd leak detector noise) + defer for (0..num_loops) |loop_idx| { + for (0..clients_per_loop) |client_idx| { + contexts[loop_idx].handlers[client_idx].deinit(); + contexts[loop_idx].conns[client_idx].deinit(); + const msg_specs = contexts[loop_idx].msg_specs[client_idx]; + for (msg_specs) |spec| { + testing.allocator.free(spec.data); + } + testing.allocator.free(msg_specs); + } + contexts[loop_idx].env.deinit(); + }; + + for (0..num_loops) |loop_idx| { + for (0..clients_per_loop) |client_idx| { + const handler = &contexts[loop_idx].handlers[client_idx]; + const msg_specs = contexts[loop_idx].msg_specs[client_idx]; + + try testing.expect(handler.open_called); + try testing.expectEqual(@as(usize, 2), handler.results.items.len); + + try testing.expectEqualSlices(u8, msg_specs[0].data, handler.results.items[0].data); + try testing.expectEqualSlices(u8, msg_specs[1].data, handler.results.items[1].data); + } + } +} + +test "e2e/stress: rapid message burst" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + const num_messages = 100; + + // Build message specs + var spec_bufs: [num_messages][10]u8 = undefined; + var specs: [num_messages]clients.SequenceHandler.MsgSpec = undefined; + var spec_lens: [num_messages]usize = undefined; + + for (0..num_messages) |i| { + const result = std.fmt.bufPrint(&spec_bufs[i], "msg-{d}", .{i}) catch unreachable; + spec_lens[i] = result.len; + specs[i] = .{ .data = spec_bufs[i][0..spec_lens[i]] }; + } + + var handler: clients.SequenceHandler = .{ + .messages = &specs, + .results = std.ArrayList(clients.SequenceHandler.RecvResult).init(testing.allocator), + .allocator = testing.allocator, + }; + defer handler.deinit(); + + var env: clients.TestEnv = undefined; + try env.start(); + defer env.deinit(); + + var conn: clients.TestSequenceClient.Conn = undefined; + var client = env.initClient(clients.TestSequenceClient, &handler, 
&conn, .{ + .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, ts.port), + }); + defer if (handler.open_called) conn.deinit(); + + try client.connect(); + try env.loop.run(.until_done); + + try testing.expect(handler.open_called); + try testing.expectEqual(@as(usize, num_messages), handler.results.items.len); + + for (0..num_messages) |i| { + const expected = std.fmt.bufPrint(&spec_bufs[i], "msg-{d}", .{i}) catch unreachable; + try testing.expectEqualSlices(u8, expected, handler.results.items[i].data); + } +} + +test "e2e/stress: mixed operations under load" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startTestServer(testing.allocator); + defer ts.stop(); + + const num_loops = 4; + + // Each loop runs: 1 text echo, 1 binary echo, 1 ping, 1 close-on-open + const MixedLoopContext = struct { + env: clients.TestEnv, + + text_handler: clients.EchoTestHandler, + text_conn: clients.TestEchoClient.Conn, + text_client: clients.TestEchoClient, + + binary_handler: clients.EchoTestHandler, + binary_conn: clients.TestEchoClient.Conn, + binary_client: clients.TestEchoClient, + + ping_handler: clients.EchoTestHandler, + ping_conn: clients.TestEchoClient.Conn, + ping_client: clients.TestEchoClient, + + close_handler: clients.CloseOnOpenHandler, + close_conn: clients.TestCloseClient.Conn, + close_client: clients.TestCloseClient, + }; + + var contexts: [num_loops]MixedLoopContext = undefined; + var threads: [num_loops]std.Thread = undefined; + + for (0..num_loops) |i| { + contexts[i].env = undefined; + try contexts[i].env.start(); + + const opts = ws.Client(clients.EchoTestHandler, 4096).Config{ + .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, ts.port), + }; + + // Text echo client + contexts[i].text_handler = .{ + .send_kind = .text, + .send_data = "text-echo", + .allocator = testing.allocator, + }; + contexts[i].text_client = contexts[i].env.initClient( + clients.TestEchoClient, + &contexts[i].text_handler, + &contexts[i].text_conn, + opts, + ); + try contexts[i].text_client.connect(); + + // Binary echo client + contexts[i].binary_handler = .{ + .send_kind = .binary, + .send_data = &[_]u8{ 0xDE, 0xAD, 0xBE, 0xEF }, + .allocator = testing.allocator, + }; + contexts[i].binary_client = contexts[i].env.initClient( + clients.TestEchoClient, + &contexts[i].binary_handler, + &contexts[i].binary_conn, + opts, + ); + try contexts[i].binary_client.connect(); + + // Ping client + contexts[i].ping_handler = .{ + .send_kind = .ping, + .send_data = "ping-data", + .allocator = testing.allocator, + }; + contexts[i].ping_client = contexts[i].env.initClient( + clients.TestEchoClient, + &contexts[i].ping_handler, + &contexts[i].ping_conn, + opts, + ); + try contexts[i].ping_client.connect(); + + // Close-on-open client + contexts[i].close_handler = .{}; + contexts[i].close_client = contexts[i].env.initClient( + clients.TestCloseClient, + &contexts[i].close_handler, + &contexts[i].close_conn, + .{ .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, ts.port) }, + ); + try contexts[i].close_client.connect(); + + threads[i] = try std.Thread.spawn(.{}, runLoopOnThread, .{&contexts[i].env.loop}); + } + + // Join all threads + for (0..num_loops) |i| { + threads[i].join(); + } + + // Cleanup runs even if assertions fail (avoids additional fd leak detector noise) + defer for (0..num_loops) |i| { + contexts[i].text_handler.deinit(); + contexts[i].text_conn.deinit(); + contexts[i].binary_handler.deinit(); + contexts[i].binary_conn.deinit(); + 
contexts[i].ping_handler.deinit();
+        contexts[i].ping_conn.deinit();
+        contexts[i].close_conn.deinit();
+        contexts[i].env.deinit();
+    };
+
+    for (0..num_loops) |i| {
+        // Text echo
+        try testing.expect(contexts[i].text_handler.open_called);
+        const text_type = contexts[i].text_handler.received_type orelse return error.NoData;
+        const text_data = contexts[i].text_handler.received_data orelse return error.NoData;
+        try testing.expectEqual(.text, text_type);
+        try testing.expectEqualSlices(u8, "text-echo", text_data);
+
+        // Binary echo
+        try testing.expect(contexts[i].binary_handler.open_called);
+        const bin_type = contexts[i].binary_handler.received_type orelse return error.NoData;
+        const bin_data = contexts[i].binary_handler.received_data orelse return error.NoData;
+        try testing.expectEqual(.binary, bin_type);
+        try testing.expectEqualSlices(u8, &[_]u8{ 0xDE, 0xAD, 0xBE, 0xEF }, bin_data);
+
+        // Ping/pong
+        try testing.expect(contexts[i].ping_handler.open_called);
+        const ping_type = contexts[i].ping_handler.received_type orelse return error.NoData;
+        const ping_data = contexts[i].ping_handler.received_data orelse return error.NoData;
+        try testing.expectEqual(.pong, ping_type);
+        try testing.expectEqualSlices(u8, "ping-data", ping_data);
+
+        // Close-on-open
+        try testing.expect(contexts[i].close_handler.open_called);
+    }
+
+    // Verify server still healthy
+    try verifyServerFunctional(ts.port);
+}
+
+test "e2e/stress: randomized concurrent echo" {
+    const fd_check = FdLeakDetector.baseline();
+    defer fd_check.assertNoLeaks();
+
+    const ts = try servers.startTestServer(testing.allocator);
+    defer ts.stop();
+
+    // Seed from crypto random for non-deterministic runs; log seed for reproducibility
+    var seed_bytes: [8]u8 = undefined;
+    std.crypto.random.bytes(&seed_bytes);
+    const seed = std.mem.readInt(u64, &seed_bytes, .little);
+    std.debug.print("\n[randomized concurrent echo] seed={d}\n", .{seed});
+
+    var prng = std.Random.DefaultPrng.init(seed);
+    const random = prng.random();
+
+    // Randomize dimensions
+    const num_loops = random.intRangeAtMost(usize, 1, 4);
+    const max_clients_per_loop = 5;
+    const max_messages_per_client = 4;
+    const max_payload_len = 256;
+
+    const allocator = testing.allocator;
+
+    const LoopState = struct {
+        env: clients.TestEnv,
+        num_clients: usize,
+        handlers: [max_clients_per_loop]clients.SequenceHandler,
+        conns: [max_clients_per_loop]clients.TestSequenceClient.Conn,
+        client_objs: [max_clients_per_loop]clients.TestSequenceClient,
+        num_messages: [max_clients_per_loop]usize,
+        msg_specs: [max_clients_per_loop][]clients.SequenceHandler.MsgSpec,
+        payload_bufs: [max_clients_per_loop][max_messages_per_client][]u8,
+    };
+
+    const loop_states = try allocator.alloc(LoopState, num_loops);
+    defer allocator.free(loop_states);
+
+    const threads = try allocator.alloc(std.Thread, num_loops);
+    defer allocator.free(threads);
+
+    for (0..num_loops) |loop_idx| {
+        loop_states[loop_idx].env = undefined;
+        try loop_states[loop_idx].env.start();
+
+        const nc = random.intRangeAtMost(usize, 1, max_clients_per_loop);
+        loop_states[loop_idx].num_clients = nc;
+
+        for (0..nc) |client_idx| {
+            const nm = random.intRangeAtMost(usize, 1, max_messages_per_client);
+            loop_states[loop_idx].num_messages[client_idx] = nm;
+
+            const msg_specs = try allocator.alloc(clients.SequenceHandler.MsgSpec, nm);
+            loop_states[loop_idx].msg_specs[client_idx] = msg_specs;
+
+            for (0..nm) |msg_idx| {
+                const payload_len = random.intRangeAtMost(usize, 1, max_payload_len);
+                const payload = try
allocator.alloc(u8, payload_len); + // Fill with random printable bytes to avoid UTF-8 issues with text frames + for (payload) |*b| { + b.* = random.intRangeAtMost(u8, 0x20, 0x7E); + } + loop_states[loop_idx].payload_bufs[client_idx][msg_idx] = payload; + msg_specs[msg_idx] = .{ .data = payload }; + } + + loop_states[loop_idx].handlers[client_idx] = .{ + .messages = msg_specs, + .results = std.ArrayList( + clients.SequenceHandler.RecvResult, + ).init(allocator), + .allocator = allocator, + }; + + loop_states[loop_idx].client_objs[client_idx] = loop_states[loop_idx].env.initClient( + clients.TestSequenceClient, + &loop_states[loop_idx].handlers[client_idx], + &loop_states[loop_idx].conns[client_idx], + .{ .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, ts.port) }, + ); + try loop_states[loop_idx].client_objs[client_idx].connect(); + } + + threads[loop_idx] = try std.Thread.spawn( + .{}, + runLoopOnThread, + .{&loop_states[loop_idx].env.loop}, + ); + } + + // Join all threads + for (0..num_loops) |loop_idx| { + threads[loop_idx].join(); + } + + // Cleanup runs even if assertions fail + defer for (0..num_loops) |loop_idx| { + const nc = loop_states[loop_idx].num_clients; + for (0..nc) |client_idx| { + loop_states[loop_idx].handlers[client_idx].deinit(); + if (loop_states[loop_idx].handlers[client_idx].open_called) { + loop_states[loop_idx].conns[client_idx].deinit(); + } + const nm = loop_states[loop_idx].num_messages[client_idx]; + for (0..nm) |msg_idx| { + allocator.free(loop_states[loop_idx].payload_bufs[client_idx][msg_idx]); + } + allocator.free(loop_states[loop_idx].msg_specs[client_idx]); + } + loop_states[loop_idx].env.deinit(); + }; + + for (0..num_loops) |loop_idx| { + const nc = loop_states[loop_idx].num_clients; + for (0..nc) |client_idx| { + const handler = &loop_states[loop_idx].handlers[client_idx]; + const nm = loop_states[loop_idx].num_messages[client_idx]; + const msg_specs = loop_states[loop_idx].msg_specs[client_idx]; + + try testing.expect(handler.open_called); + try testing.expectEqual(nm, handler.results.items.len); + + for (0..nm) |msg_idx| { + try testing.expectEqualSlices( + u8, + msg_specs[msg_idx].data, + handler.results.items[msg_idx].data, + ); + } + } + } +} + +fn runLoopOnThread(loop: *xev.Loop) void { + loop.run(.until_done) catch |err| { + std.debug.panic("event loop failed on thread: {}", .{err}); + }; +} diff --git a/src/rpc/webzockets/e2e_tests/server/tests.zig b/src/rpc/webzockets/e2e_tests/server/tests.zig new file mode 100644 index 0000000000..0cad289115 --- /dev/null +++ b/src/rpc/webzockets/e2e_tests/server/tests.zig @@ -0,0 +1,24 @@ +comptime { + // Handshake & connection setup + _ = @import("handshake_tests.zig"); + _ = @import("rejection_tests.zig"); + + // Protocol compliance + _ = @import("close_tests.zig"); + _ = @import("protocol_error_tests.zig"); + _ = @import("fragmentation_tests.zig"); + + // Messaging + _ = @import("echo_tests.zig"); + _ = @import("pause_resume_tests.zig"); + _ = @import("buffer_tier_tests.zig"); + + // Resources + _ = @import("pool_tests.zig"); + + // Stress/load + _ = @import("stress_tests.zig"); + + // Timeouts + _ = @import("timeout_tests.zig"); +} diff --git a/src/rpc/webzockets/e2e_tests/server/timeout_tests.zig b/src/rpc/webzockets/e2e_tests/server/timeout_tests.zig new file mode 100644 index 0000000000..00281a2325 --- /dev/null +++ b/src/rpc/webzockets/e2e_tests/server/timeout_tests.zig @@ -0,0 +1,189 @@ +const std = @import("std"); +const testing = std.testing; + +const servers = 
@import("../support/test_servers.zig"); +const RawClient = @import("../support/raw_client.zig").RawClient; +const FdLeakDetector = @import("../support/fd_leak.zig").FdLeakDetector; + +const poll_read_timeout_ms: u32 = 100; +const close_deadline_ms: u64 = 2_000; + +test "close timeout: server disconnects when peer ignores close response" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + // Server echoes first message then closes; close timeout is 200ms + const ts = try servers.startCloseAfterFirstMessageServer(testing.allocator, 200); + defer ts.stop(); + + var client = try RawClient.connectEx(testing.allocator, ts.port, .{ + .read_timeout_ms = poll_read_timeout_ms, + }); + defer client.deinit(); + + // Send a message to trigger echo + server close + var msg = "hello".*; + try client.write(&msg); + + // Read the echo + const echo = try client.waitForMessageType(.text, close_deadline_ms); + try testing.expectEqualSlices(u8, "hello", echo.data); + client.done(echo); + + // Read the close frame from the server + const close_msg = try client.waitForCloseFrame(close_deadline_ms); + try testing.expectEqual(.close, close_msg.type); + client.done(close_msg); + + // Do NOT echo the close frame — just wait for the server to force disconnect. + try client.waitForClosedNoData(close_deadline_ms); +} + +test "idle timeout: server closes idle connection" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startEchoServerWithTimeouts(testing.allocator, 200, 200); + defer ts.stop(); + + var client = try RawClient.connectEx(testing.allocator, ts.port, .{ + .read_timeout_ms = poll_read_timeout_ms, + }); + defer client.deinit(); + + // Do nothing after handshake — wait for the server to send a close frame + const close_msg = try client.waitForCloseFrame(close_deadline_ms); + defer client.done(close_msg); + try testing.expectEqual(.close, close_msg.type); + + // Verify the close code is 1001 (going_away) + try testing.expect(close_msg.data.len >= 2); + const code = std.mem.readInt(u16, close_msg.data[0..2], .big); + try testing.expectEqual(@as(u16, 1001), code); + + // Echo the close frame to complete the handshake cleanly + try client.close(.{ .code = 1001 }); + + // Verify TCP connection is closed (distinguish close from transient timeouts). + try client.waitForClosedNoData(close_deadline_ms); +} + +test "idle timeout: activity resets timer" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + // Worst-case close arrives 2 × idle_timeout after last message. 
+ const ts = try servers.startEchoServerWithTimeouts(testing.allocator, 200, 5000); + defer ts.stop(); + + var client = try RawClient.connectEx(testing.allocator, ts.port, .{ + .read_timeout_ms = poll_read_timeout_ms, + }); + defer client.deinit(); + + // Send messages at 50ms intervals for ~500ms, well past idle timeout + var i: usize = 0; + while (i < 10) : (i += 1) { + std.time.sleep(50 * std.time.ns_per_ms); + var msg = "ping".*; + try client.write(&msg); + + const echo = try client.waitForMessageType(.text, close_deadline_ms); + try testing.expectEqualSlices(u8, "ping", echo.data); + client.done(echo); + } + + // Wait for idle timeout close + const close_msg = try client.waitForCloseFrame(close_deadline_ms); + defer client.done(close_msg); + try testing.expectEqual(.close, close_msg.type); + + // Verify the close code is 1001 (going_away) + try testing.expect(close_msg.data.len >= 2); + const code = std.mem.readInt(u16, close_msg.data[0..2], .big); + try testing.expectEqual(@as(u16, 1001), code); + + // Echo close to complete handshake + try client.close(.{ .code = 1001 }); +} + +test "idle timeout into close timeout" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startEchoServerWithTimeouts(testing.allocator, 200, 200); + defer ts.stop(); + + var client = try RawClient.connectEx(testing.allocator, ts.port, .{ + .read_timeout_ms = poll_read_timeout_ms, + }); + defer client.deinit(); + + // Do nothing — wait for idle timeout close frame + const close_msg = try client.waitForCloseFrame(close_deadline_ms); + defer client.done(close_msg); + try testing.expectEqual(.close, close_msg.type); + + // Do NOT echo the close frame — let the close timeout fire. + try client.waitForClosedNoData(close_deadline_ms); +} + +test "close in onOpen with idle timeout configured" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + // Server calls close() in onOpen. idle_timeout_ms is configured, which + // previously would overwrite the close timer started by close(). The + // close timeout (200ms) should apply, not the idle timeout (5000ms). + const ts = try servers.startCloseOnOpenServerWithTimeouts(testing.allocator, 5000, 200); + defer ts.stop(); + + var client = try RawClient.connectEx(testing.allocator, ts.port, .{ + .read_timeout_ms = poll_read_timeout_ms, + }); + defer client.deinit(); + + // Server closes immediately on open — read the close frame + const close_msg = try client.waitForCloseFrame(close_deadline_ms); + defer client.done(close_msg); + try testing.expectEqual(.close, close_msg.type); + try testing.expect(close_msg.data.len >= 2); + const code = std.mem.readInt(u16, close_msg.data[0..2], .big); + try testing.expectEqual(@as(u16, 1000), code); + + // Do NOT echo the close frame — let the close timeout fire. + // If idle timeout (5000ms) were used by mistake, this bounded wait would fail. 
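+    // (close_deadline_ms is 2s: past the 200ms close timeout, well short of 5000ms.)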
+ try client.waitForClosedNoData(close_deadline_ms); +} + +test "normal close still works with timeouts enabled" { + const fd_check = FdLeakDetector.baseline(); + defer fd_check.assertNoLeaks(); + + const ts = try servers.startEchoServerWithTimeouts(testing.allocator, 500, 500); + defer ts.stop(); + + var client = try RawClient.connectEx(testing.allocator, ts.port, .{ + .read_timeout_ms = poll_read_timeout_ms, + }); + defer client.deinit(); + + // Send a message to verify connection works + var msg = "test".*; + try client.write(&msg); + + const echo = try client.waitForMessageType(.text, close_deadline_ms); + try testing.expectEqualSlices(u8, "test", echo.data); + client.done(echo); + + // Initiate a normal close handshake (peer-initiated) + try client.close(.{ .code = 1000, .reason = "goodbye" }); + + // Read the server's close echo + const close_msg = try client.waitForCloseFrame(close_deadline_ms); + defer client.done(close_msg); + try testing.expectEqual(.close, close_msg.type); + try testing.expect(close_msg.data.len >= 2); + const code = std.mem.readInt(u16, close_msg.data[0..2], .big); + try testing.expectEqual(@as(u16, 1000), code); +} diff --git a/src/rpc/webzockets/e2e_tests/support/client_handlers.zig b/src/rpc/webzockets/e2e_tests/support/client_handlers.zig new file mode 100644 index 0000000000..caf4322ac7 --- /dev/null +++ b/src/rpc/webzockets/e2e_tests/support/client_handlers.zig @@ -0,0 +1,587 @@ +const std = @import("std"); +const ws = @import("webzockets_lib"); + +/// Client-side handler that closes immediately on open. +/// Useful for testing close handshake and connection pool exhaustion. +pub const CloseOnOpenHandler = struct { + open_called: bool = false, + close_called: bool = false, + + pub fn onOpen(self: *CloseOnOpenHandler, conn: anytype) void { + self.open_called = true; + conn.close(.normal, ""); + } + + pub fn onMessage(_: *CloseOnOpenHandler, _: anytype, _: ws.Message) void {} + pub fn onWriteComplete(_: *CloseOnOpenHandler, _: anytype) void {} + + pub fn onClose(self: *CloseOnOpenHandler, _: anytype) void { + self.close_called = true; + } +}; + +/// Client-side handler that tracks open_called but takes no action. +/// Useful for testing rejection scenarios where onOpen should not be called. +pub const NoOpHandler = struct { + open_called: bool = false, + + pub fn onOpen(self: *NoOpHandler, _: anytype) void { + self.open_called = true; + } + + pub fn onMessage(_: *NoOpHandler, _: anytype, _: ws.Message) void {} + pub fn onWriteComplete(_: *NoOpHandler, _: anytype) void {} + pub fn onClose(_: *NoOpHandler, _: anytype) void {} +}; + +/// Client-side handler that waits for the server to close. +/// Tracks that onClose was called and that the client never initiated the close. +/// Used by client close_tests.zig for server-initiated close tests. +pub const ServerCloseHandler = struct { + close_called: bool = false, + open_called: bool = false, + + pub fn onOpen(self: *ServerCloseHandler, _: anytype) void { + self.open_called = true; + } + + pub fn onMessage(_: *ServerCloseHandler, _: anytype, _: ws.Message) void {} + pub fn onWriteComplete(_: *ServerCloseHandler, _: anytype) void {} + + pub fn onClose(self: *ServerCloseHandler, _: anytype) void { + self.close_called = true; + } +}; + +/// Client-side handler that tracks pong reception and then closes. +/// Used by client ping_pong_tests.zig to verify the onPong callback fires. 
+pub const PongTrackingHandler = struct { + pong_received: bool = false, + pong_data: ?[]const u8 = null, + open_called: bool = false, + allocator: std.mem.Allocator, + + pub fn deinit(self: *PongTrackingHandler) void { + if (self.pong_data) |data| { + self.allocator.free(data); + self.pong_data = null; + } + } + + pub fn onOpen(self: *PongTrackingHandler, _: anytype) void { + self.open_called = true; + } + + pub fn onMessage(_: *PongTrackingHandler, _: anytype, _: ws.Message) void {} + pub fn onWriteComplete(_: *PongTrackingHandler, _: anytype) void {} + + pub fn onPong(self: *PongTrackingHandler, conn: anytype, data: []const u8) void { + self.pong_received = true; + self.pong_data = self.allocator.dupe(u8, data) catch null; + conn.close(.normal, ""); + } + + pub fn onClose(_: *PongTrackingHandler, _: anytype) void {} +}; + +/// Client-side handler that tracks whether onClose was called. +/// Does not send or receive data. Used for max_message_size tests where the +/// server sends an oversized message and the client is expected to close with 1009. +pub const MaxMessageHandler = struct { + open_called: bool = false, + close_called: bool = false, + message_received: bool = false, + + pub fn onOpen(self: *MaxMessageHandler, _: anytype) void { + self.open_called = true; + } + + pub fn onMessage(self: *MaxMessageHandler, _: anytype, _: ws.Message) void { + self.message_received = true; + } + + pub fn onWriteComplete(_: *MaxMessageHandler, _: anytype) void {} + + pub fn onClose(self: *MaxMessageHandler, _: anytype) void { + self.close_called = true; + } +}; + +/// Client-side handler that tracks socket-level close for connection failure tests. +/// Implements onSocketClose to detect TCP/handshake failures. +pub const ConnectFailHandler = struct { + open_called: bool = false, + socket_close_called: bool = false, + + pub fn onOpen(self: *ConnectFailHandler, _: anytype) void { + self.open_called = true; + } + + pub fn onMessage(_: *ConnectFailHandler, _: anytype, _: ws.Message) void {} + pub fn onWriteComplete(_: *ConnectFailHandler, _: anytype) void {} + pub fn onClose(_: *ConnectFailHandler, _: anytype) void {} + + pub fn onSocketClose(self: *ConnectFailHandler) void { + self.socket_close_called = true; + } +}; + +/// Client-side handler that explicitly responds to pings via onPing. +/// Used to test that declaring onPing disables auto-pong and the handler +/// can manually send pong. +pub const ExplicitPongHandler = struct { + open_called: bool = false, + close_called: bool = false, + ping_received: bool = false, + + pub fn onOpen(self: *ExplicitPongHandler, _: anytype) void { + self.open_called = true; + } + + pub fn onPing(self: *ExplicitPongHandler, conn: anytype, data: []const u8) void { + self.ping_received = true; + conn.sendPong(data) catch |err| std.debug.panic("sendPong failed: {}", .{err}); + } + + pub fn onMessage(_: *ExplicitPongHandler, _: anytype, _: ws.Message) void {} + pub fn onWriteComplete(_: *ExplicitPongHandler, _: anytype) void {} + + pub fn onClose(self: *ExplicitPongHandler, _: anytype) void { + self.close_called = true; + } +}; + +/// Common client-side handler for e2e tests. +/// +/// Behavior: +/// - Optionally sends one message/ping on open. +/// - Captures the first message/pong. +/// - Initiates a normal close. +/// +/// Owns any allocations made for sending/capturing and provides `deinit()` +/// so tests can clean up consistently. 
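+///
+/// A minimal setup sketch (field values are illustrative; the surrounding
+/// client wiring lives in test_clients.zig):
+/// ```zig
+/// var handler: SendOnceHandler = .{
+///     .send_kind = .text,
+///     .send_data = "hello",
+///     .allocator = std.testing.allocator,
+/// };
+/// defer handler.deinit();
+/// // ...connect, run the loop, then inspect handler.received_type/.received_data
+/// ```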
+pub const SendOnceHandler = struct { + pub const SendKind = enum { + none, + text, + binary, + ping, + }; + + /// What to send on open. + send_kind: SendKind = .none, + + /// Payload to send (required unless `send_kind == .none`). + send_data: ?[]const u8 = null, + + /// Captured data from the first received message/pong. + received_data: ?[]const u8 = null, + received_type: ?ws.Message.Type = null, + + open_called: bool = false, + + allocator: std.mem.Allocator, + + /// Owned copy used for text/binary sends. + sent_data: ?[]const u8 = null, + + pub fn deinit(self: *SendOnceHandler) void { + if (self.received_data) |data| { + self.allocator.free(data); + self.received_data = null; + } + if (self.sent_data) |data| { + self.allocator.free(data); + self.sent_data = null; + } + } + + pub fn onOpen(self: *SendOnceHandler, conn: anytype) void { + self.open_called = true; + + const kind = self.send_kind; + if (kind == .none) return; + const data = self.send_data orelse return; + + switch (kind) { + .ping => { + // sendPing copies the payload into an internal queue (no onWriteComplete), + // so no allocation or lifetime management is needed. + conn.sendPing(data) catch return; + }, + .text => { + const copy = self.allocator.dupe(u8, data) catch return; + conn.sendText(copy) catch { + self.allocator.free(copy); + return; + }; + self.sent_data = copy; + }, + .binary => { + const copy = self.allocator.dupe(u8, data) catch return; + conn.sendBinary(copy) catch { + self.allocator.free(copy); + return; + }; + self.sent_data = copy; + }, + .none => {}, + } + } + + pub fn onMessage(self: *SendOnceHandler, conn: anytype, message: ws.Message) void { + self.received_data = self.allocator.dupe(u8, message.data) catch null; + self.received_type = message.type; + conn.close(.normal, ""); + } + + pub fn onPong(self: *SendOnceHandler, conn: anytype, data: []const u8) void { + self.received_data = self.allocator.dupe(u8, data) catch null; + self.received_type = .pong; + conn.close(.normal, ""); + } + + pub fn onWriteComplete(self: *SendOnceHandler, _: anytype) void { + if (self.sent_data) |data| { + self.allocator.free(data); + self.sent_data = null; + } + } + + pub fn onClose(self: *SendOnceHandler, _: anytype) void { + if (self.sent_data) |data| { + self.allocator.free(data); + self.sent_data = null; + } + } +}; + +/// Client-side handler that sends multiple messages sequentially. +/// Sends the next message only after the previous write completes and echo is received. +/// Supports mixed text/binary messages via MsgSpec. +pub const SequenceHandler = struct { + /// Message specification for sending. + pub const MsgSpec = struct { + data: []const u8, + is_binary: bool = false, + }; + + /// Captured result for a received message. 
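+    /// `data` is an owned copy; `deinit()` frees every captured result.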
+ pub const RecvResult = struct { + data: []const u8, + len: usize, + }; + + messages: []const MsgSpec, + send_index: usize = 0, + recv_index: usize = 0, + results: std.ArrayList(RecvResult), + allocator: std.mem.Allocator, + sent_data: ?[]const u8 = null, + open_called: bool = false, + + pub fn deinit(self: *SequenceHandler) void { + for (self.results.items) |item| { + self.allocator.free(item.data); + } + self.results.deinit(); + if (self.sent_data) |data| { + self.allocator.free(data); + self.sent_data = null; + } + } + + pub fn onOpen(self: *SequenceHandler, conn: anytype) void { + self.open_called = true; + self.maybeSendNext(conn); + } + + fn maybeSendNext(self: *SequenceHandler, conn: anytype) void { + if (self.sent_data != null) return; + if (self.recv_index < self.send_index) return; + if (self.send_index >= self.messages.len) return; + + const spec = self.messages[self.send_index]; + const copy = self.allocator.dupe(u8, spec.data) catch return; + if (spec.is_binary) { + conn.sendBinary(copy) catch { + self.allocator.free(copy); + return; + }; + } else { + conn.sendText(copy) catch { + self.allocator.free(copy); + return; + }; + } + self.sent_data = copy; + self.send_index += 1; + } + + pub fn onMessage(self: *SequenceHandler, conn: anytype, message: ws.Message) void { + const copy = self.allocator.dupe(u8, message.data) catch return; + self.results.append(.{ .data = copy, .len = message.data.len }) catch { + self.allocator.free(copy); + return; + }; + self.recv_index += 1; + if (self.recv_index >= self.messages.len) { + conn.close(.normal, ""); + } else { + self.maybeSendNext(conn); + } + } + + pub fn onWriteComplete(self: *SequenceHandler, conn: anytype) void { + if (self.sent_data) |data| { + self.allocator.free(data); + self.sent_data = null; + } + self.maybeSendNext(conn); + } + + pub fn onClose(self: *SequenceHandler, _: anytype) void { + if (self.sent_data) |data| { + self.allocator.free(data); + self.sent_data = null; + } + } +}; + +/// Client-side handler that pauses reads on open and waits for enough data +/// to accumulate in the read buffer before resuming. Uses `onBytesRead` to +/// observe raw TCP data arrival while paused. This makes burst-processing +/// tests deterministic by ensuring all expected messages are buffered before +/// dispatch begins. +pub const PauseUntilBufferedClientHandler = struct { + pub const RecvResult = struct { + data: []const u8, + }; + + allocator: std.mem.Allocator, + /// Number of messages expected before closing. 0 = don't auto-close. + expected_messages: usize = 0, + /// Minimum bytes that must be buffered before reads are resumed. 
+ resume_threshold: usize = 0, + results: std.ArrayList(RecvResult), + open_called: bool = false, + close_called: bool = false, + + pub fn deinit(self: *PauseUntilBufferedClientHandler) void { + for (self.results.items) |item| { + self.allocator.free(item.data); + } + self.results.deinit(); + } + + pub fn onOpen(self: *PauseUntilBufferedClientHandler, conn: anytype) void { + self.open_called = true; + conn.pauseReads(); + } + + pub fn onBytesRead(self: *PauseUntilBufferedClientHandler, conn: anytype, _: usize) void { + if (conn.peekBufferedBytes().len >= self.resume_threshold) { + conn.resumeReads(); + } + } + + pub fn onMessage( + self: *PauseUntilBufferedClientHandler, + conn: anytype, + message: ws.Message, + ) void { + switch (message.type) { + .text, .binary => {}, + else => return, + } + + const copy = self.allocator.dupe(u8, message.data) catch return; + self.results.append(.{ .data = copy }) catch { + self.allocator.free(copy); + return; + }; + + if (self.expected_messages != 0 and self.results.items.len >= self.expected_messages) { + conn.close(.normal, "done"); + } + } + + pub fn onWriteComplete(_: *PauseUntilBufferedClientHandler, _: anytype) void {} + + pub fn onClose(self: *PauseUntilBufferedClientHandler, _: anytype) void { + self.close_called = true; + } +}; + +/// Client-side handler that pauses on open, waits for a byte threshold, then +/// resumes. After the initial resume, each onMessage pauses reads, records the +/// message, sends an echo/ack, and resumes in onWriteComplete. This tests the +/// processMessages loop breaking when read_paused is set mid-loop, and +/// re-entering from onWriteComplete via resumeReads. +pub const PauseMidStreamClientHandler = struct { + pub const RecvResult = struct { + data: []const u8, + }; + + allocator: std.mem.Allocator, + expected_messages: usize = 0, + resume_threshold: usize = 0, + results: std.ArrayList(RecvResult), + sent_data: ?[]u8 = null, + initial_resumed: bool = false, + open_called: bool = false, + close_called: bool = false, + + pub fn deinit(self: *PauseMidStreamClientHandler) void { + for (self.results.items) |item| { + self.allocator.free(item.data); + } + self.results.deinit(); + + if (self.sent_data) |data| { + self.allocator.free(data); + self.sent_data = null; + } + } + + pub fn onOpen(self: *PauseMidStreamClientHandler, conn: anytype) void { + self.open_called = true; + conn.pauseReads(); + } + + pub fn onBytesRead(self: *PauseMidStreamClientHandler, conn: anytype, _: usize) void { + if (!self.initial_resumed and conn.peekBufferedBytes().len >= self.resume_threshold) { + self.initial_resumed = true; + conn.resumeReads(); + } + } + + pub fn onMessage(self: *PauseMidStreamClientHandler, conn: anytype, message: ws.Message) void { + switch (message.type) { + .text, .binary => {}, + else => return, + } + + conn.pauseReads(); + + const copy = self.allocator.dupe(u8, message.data) catch { + conn.resumeReads(); + return; + }; + self.results.append(.{ .data = copy }) catch { + self.allocator.free(copy); + conn.resumeReads(); + return; + }; + + // Send an ack to trigger onWriteComplete where we resume reads. 
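+        // The ack is an owned copy tracked in sent_data; it is freed in
+        // onWriteComplete (or in onClose if the write never completes).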
+ const ack = self.allocator.dupe(u8, message.data) catch { + conn.resumeReads(); + return; + }; + + conn.sendText(ack) catch { + self.allocator.free(ack); + conn.resumeReads(); + return; + }; + + self.sent_data = ack; + } + + pub fn onWriteComplete(self: *PauseMidStreamClientHandler, conn: anytype) void { + if (self.sent_data) |data| { + self.allocator.free(data); + self.sent_data = null; + } + + if (self.expected_messages != 0 and self.results.items.len >= self.expected_messages) { + conn.close(.normal, "done"); + return; + } + + conn.resumeReads(); + } + + pub fn onClose(self: *PauseMidStreamClientHandler, _: anytype) void { + self.close_called = true; + if (self.sent_data) |data| { + self.allocator.free(data); + self.sent_data = null; + } + } +}; + +/// Client-side handler that detects re-entrant onMessage dispatch. +/// +/// Pauses reads in onOpen and waits (via onBytesRead + peekBufferedBytes) +/// until all expected messages are buffered. Once resumed, all messages +/// dispatch synchronously. In each onMessage the handler calls +/// pauseReads() then resumeReads() — without the re-entrancy guard this +/// would recursively dispatch the next buffered message while still inside +/// onMessage. Closes with policy_violation if re-entrancy is detected. +pub const ReentrancyDetectClientHandler = struct { + pub const RecvResult = struct { + data: []const u8, + }; + + allocator: std.mem.Allocator, + /// Minimum bytes that must be buffered before reads are resumed. + resume_threshold: usize = 0, + results: std.ArrayList(RecvResult), + in_on_message: bool = false, + reentrant_detected: bool = false, + open_called: bool = false, + close_called: bool = false, + + pub fn deinit(self: *ReentrancyDetectClientHandler) void { + for (self.results.items) |item| { + self.allocator.free(item.data); + } + self.results.deinit(); + } + + pub fn onOpen(self: *ReentrancyDetectClientHandler, conn: anytype) void { + self.open_called = true; + conn.pauseReads(); + } + + pub fn onBytesRead(self: *ReentrancyDetectClientHandler, conn: anytype, _: usize) void { + if (conn.peekBufferedBytes().len >= self.resume_threshold) { + conn.resumeReads(); + } + } + + pub fn onMessage(self: *ReentrancyDetectClientHandler, conn: anytype, message: ws.Message) void { + switch (message.type) { + .text, .binary => {}, + else => return, + } + + if (self.in_on_message) { + self.reentrant_detected = true; + conn.close(.policy_violation, "reentrant"); + return; + } + self.in_on_message = true; + defer self.in_on_message = false; + + const copy = self.allocator.dupe(u8, message.data) catch return; + self.results.append(.{ .data = copy }) catch { + self.allocator.free(copy); + return; + }; + + // Exercise the re-entrancy guard: without it, resumeReads() would + // recursively call processMessages() and dispatch the next buffered + // message before we return. + conn.pauseReads(); + conn.resumeReads(); + } + + pub fn onWriteComplete(_: *ReentrancyDetectClientHandler, _: anytype) void {} + + pub fn onClose(self: *ReentrancyDetectClientHandler, _: anytype) void { + self.close_called = true; + } +}; diff --git a/src/rpc/webzockets/e2e_tests/support/fd_leak.zig b/src/rpc/webzockets/e2e_tests/support/fd_leak.zig new file mode 100644 index 0000000000..04b4349630 --- /dev/null +++ b/src/rpc/webzockets/e2e_tests/support/fd_leak.zig @@ -0,0 +1,132 @@ +const std = @import("std"); +const builtin = @import("builtin"); +const native_os = builtin.os.tag; + +/// Utility to detect file descriptor leaks in tests. 
+/// +/// Snapshots the open FD count at baseline and asserts no new FDs are leaked +/// after cleanup. FD counts are process-wide, so tests using this should be +/// run serially to avoid false positives from concurrent tests. +/// +/// Usage: +/// ```zig +/// const fd_check = FdLeakDetector.baseline(); +/// defer fd_check.assertNoLeaks(); +/// // ... test body ... +/// ``` +pub const FdLeakDetector = struct { + baseline_count: usize, + + /// Snapshot current open FD count. Call as first line of test. + pub fn baseline() FdLeakDetector { + return .{ .baseline_count = countOpenFds() }; + } + + /// Assert FD count matches baseline. Panics with count delta on failure. + pub fn assertNoLeaks(self: *const FdLeakDetector) void { + const current = countOpenFds(); + if (current != self.baseline_count) { + const cur: isize = @intCast(current); + const base: isize = @intCast(self.baseline_count); + const delta: isize = cur - base; + std.debug.panic( + "FD leak detected: {d} more FDs open " ++ + "than at baseline (baseline={d}, current={d})", + .{ delta, base, current }, + ); + } + } +}; + +fn countOpenFds() usize { + if (native_os == .macos) { + return countOpenFdsDarwin(); + } else if (native_os == .linux) { + return countOpenFdsLinux(); + } else { + @compileError("FD leak detection not supported on this platform"); + } +} + +/// macOS `proc_pidinfo` flavor to list open file descriptors. +const PROC_PIDLISTFDS: c_int = 1; + +const proc_fdinfo = extern struct { + proc_fd: i32, + proc_fdtype: u32, +}; + +extern "c" fn proc_pidinfo( + pid: c_int, + flavor: c_int, + arg: u64, + buffer: ?*anyopaque, + buffersize: c_int, +) c_int; + +fn countOpenFdsDarwin() usize { + // Two-call pattern: first call with null buffer returns the FD table capacity + // (in bytes). Second call with a real buffer returns the actual bytes written + // for open FDs only. + const pid = std.posix.system.getpid(); + const buf_size = proc_pidinfo(pid, PROC_PIDLISTFDS, 0, null, 0); + if (buf_size <= 0) { + std.debug.panic("proc_pidinfo(PROC_PIDLISTFDS) sizing call failed", .{}); + } + + const buf = std.heap.c_allocator.alloc(u8, @intCast(buf_size)) catch { + std.debug.panic("Failed to allocate buffer for proc_pidinfo", .{}); + }; + defer std.heap.c_allocator.free(buf); + + const actual_bytes = proc_pidinfo(pid, PROC_PIDLISTFDS, 0, buf.ptr, buf_size); + if (actual_bytes <= 0) { + std.debug.panic("proc_pidinfo(PROC_PIDLISTFDS) list call failed", .{}); + } + return @as(usize, @intCast(actual_bytes)) / @sizeOf(proc_fdinfo); +} + +fn countOpenFdsLinux() usize { + var count: usize = 0; + var dir = std.fs.openDirAbsolute("/proc/self/fd", .{ .iterate = true }) catch { + std.debug.panic("Failed to open /proc/self/fd for FD counting", .{}); + }; + defer dir.close(); + + var iter = dir.iterate(); + while (iter.next() catch null) |_| { + count += 1; + } + // Subtract 1 for the directory FD itself (opened above). + return count -| 1; +} + +test "FdLeakDetector: no leak when no FDs opened" { + const detector = FdLeakDetector.baseline(); + // No FDs opened — should not panic. + detector.assertNoLeaks(); +} + +test "FdLeakDetector: detects leaked FD" { + const detector = FdLeakDetector.baseline(); + + // Open a file to leak an FD. + const leaked_fd = std.posix.open("/dev/null", .{}, 0) catch return; + // Don't close it — simulate a leak. + + const current = countOpenFds(); + // Verify our counting works: current should be > baseline. + std.debug.assert(current > detector.baseline_count); + + // Clean up so we don't actually leak. 
+ std.posix.close(leaked_fd); + + // Now it should pass. + detector.assertNoLeaks(); +} + +test "countOpenFds returns reasonable value" { + const count = countOpenFds(); + // A running process should have at least stdin, stdout, stderr. + std.debug.assert(count >= 3); +} diff --git a/src/rpc/webzockets/e2e_tests/support/raw_client.zig b/src/rpc/webzockets/e2e_tests/support/raw_client.zig new file mode 100644 index 0000000000..8692808558 --- /dev/null +++ b/src/rpc/webzockets/e2e_tests/support/raw_client.zig @@ -0,0 +1,370 @@ +const std = @import("std"); +const ws = @import("webzockets_lib"); + +const frame = ws.frame; +const mask_mod = ws.mask; +const http = ws.http; +const Opcode = ws.Opcode; +const Message = ws.Message; + +/// Blocking, frame-level WebSocket client for testing close handshake behavior +/// against the webzockets server. Connects via TCP, performs HTTP upgrade, then +/// provides frame-level send/receive with client masking. +pub const RawClient = struct { + stream: std.net.Stream, + allocator: std.mem.Allocator, + read_buf: []u8, + read_pos: usize, // start of unconsumed data + read_end: usize, // end of valid data in read_buf + + pub const default_read_buf_size: usize = 4096; + + pub const ConnectOpts = struct { + read_buf_size: usize = default_read_buf_size, + read_timeout_ms: u32 = 2000, + }; + + // ==================================================================== + // Connect + Lifecycle + // ==================================================================== + + /// Blocking connect and WebSocket handshake to 127.0.0.1:port. + /// Uses default 4096-byte read buffer. + pub fn connect(allocator: std.mem.Allocator, port: u16) !RawClient { + return connectEx(allocator, port, .{}); + } + + /// Blocking connect with configurable options (e.g. larger read buffer + /// for tests that receive large echoed messages). + pub fn connectEx(allocator: std.mem.Allocator, port: u16, opts: ConnectOpts) !RawClient { + const address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, port); + const stream = try std.net.tcpConnectToAddress(address); + errdefer stream.close(); + + // Set SO_RCVTIMEO for blocking reads (defaults to 2s) to reduce scheduler-jitter flakes. + const timeout = std.posix.timeval{ + .sec = @intCast(opts.read_timeout_ms / 1000), + .usec = @intCast((opts.read_timeout_ms % 1000) * 1000), + }; + try std.posix.setsockopt( + stream.handle, + std.posix.SOL.SOCKET, + std.posix.SO.RCVTIMEO, + std.mem.asBytes(&timeout), + ); + + // Generate WebSocket key and build upgrade request + var key_buf: [24]u8 = undefined; + var raw_key: [16]u8 = undefined; + std.crypto.random.bytes(&raw_key); + const key = http.encodeKey(&key_buf, &raw_key); + + var request_buf: [512]u8 = undefined; + const request = try http.writeRequest(&request_buf, address, "/", key); + try stream.writeAll(request); + + // Allocate read buffer + const read_buf = try allocator.alloc(u8, opts.read_buf_size); + errdefer allocator.free(read_buf); + + // Blocking read loop until handshake response is complete + var total_read: usize = 0; + var head_parser: http.HeadParser = .{}; + + while (true) { + if (total_read >= read_buf.len) return error.HandshakeResponseTooLarge; + const n = try stream.read(read_buf[total_read..]); + if (n == 0) return error.ConnectionClosed; + const old_pos = total_read; + total_read += n; + + // Feed only new bytes to the incremental head parser. 
+ const consumed = head_parser.feed(read_buf[old_pos..total_read]); + if (head_parser.state != .finished) continue; + + // Headers complete — header_len is total bytes consumed across all feeds. + const header_len = old_pos + consumed; + + http.validateResponse(read_buf[0..header_len], key) catch + return error.HandshakeFailed; + return RawClient{ + .stream = stream, + .allocator = allocator, + .read_buf = read_buf, + .read_pos = header_len, + .read_end = total_read, + }; + } + } + + /// Close the underlying TCP connection and free the read buffer. + pub fn deinit(self: *RawClient) void { + self.stream.close(); + self.allocator.free(self.read_buf); + } + + // ==================================================================== + // Frame-level send + // ==================================================================== + + /// Options for fine-grained frame control in `writeFrameEx`. + pub const FrameOpts = struct { + fin: bool = true, + mask: bool = true, + rsv1: bool = false, + rsv2: bool = false, + rsv3: bool = false, + }; + + /// Send a single WebSocket frame with full control over header bits. + /// `opcode` is a raw u4, allowing reserved/invalid opcodes for protocol tests. + /// When `opts.mask` is true (default), payload is masked in-place — caller + /// must pass mutable `[]u8`. When false, payload is sent as-is. + pub fn writeFrameEx(self: *RawClient, opcode: u4, payload: []u8, opts: FrameOpts) !void { + var header_buf: [14]u8 = undefined; + var header_len: usize = 0; + + // Byte 0: FIN, RSV1-3, opcode + var byte0: u8 = 0; + if (opts.fin) byte0 |= 0x80; + if (opts.rsv1) byte0 |= 0x40; + if (opts.rsv2) byte0 |= 0x20; + if (opts.rsv3) byte0 |= 0x10; + byte0 |= @as(u8, opcode); + header_buf[0] = byte0; + + // Byte 1+: MASK bit, payload length + const mask_bit: u8 = if (opts.mask) 0x80 else 0x00; + if (payload.len <= 125) { + header_buf[1] = mask_bit | @as(u8, @truncate(payload.len)); + header_len = 2; + } else if (payload.len <= 65535) { + header_buf[1] = mask_bit | 126; + std.mem.writeInt(u16, header_buf[2..4], @intCast(payload.len), .big); + header_len = 4; + } else { + header_buf[1] = mask_bit | 127; + std.mem.writeInt(u64, header_buf[2..10], payload.len, .big); + header_len = 10; + } + + // Mask key (only when masking) + if (opts.mask) { + var mask_key: [4]u8 = undefined; + std.crypto.random.bytes(&mask_key); + @memcpy(header_buf[header_len..][0..4], &mask_key); + header_len += 4; + + // Mask payload in-place + mask_mod.mask(mask_key, payload); + } + + // Write header then payload + try self.stream.writeAll(header_buf[0..header_len]); + try self.stream.writeAll(payload); + } + + /// Send a single well-formed WebSocket frame with client masking (FIN=1). + /// Payload is masked in-place — caller must pass mutable `[]u8`. + pub fn writeFrame(self: *RawClient, opcode: Opcode, payload: []u8) !void { + try self.writeFrameEx(@intFromEnum(opcode), payload, .{}); + } + + // ==================================================================== + // Frame-level receive + // ==================================================================== + + pub const ReadResult = union(enum) { + message: Message, + timeout, + closed, + }; + + /// Blocking read of a single WebSocket frame with explicit timeout/close status. + /// Returned Message.data is heap-allocated; free via `done()`. 
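+    /// A minimal polling sketch:
+    /// ```zig
+    /// switch (try client.readResult()) {
+    ///     .message => |msg| {
+    ///         defer client.done(msg);
+    ///         // inspect msg.type / msg.data
+    ///     },
+    ///     .timeout => {}, // SO_RCVTIMEO expired; poll again
+    ///     .closed => {}, // peer disconnected
+    /// }
+    /// ```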
+ pub fn readResult(self: *RawClient) !ReadResult { + while (true) { + const available = self.read_buf[self.read_pos..self.read_end]; + + // Try to parse a header from available data + const header = frame.parseHeader(available) catch |err| switch (err) { + error.InsufficientData => { + switch (try self.fillBufferResult()) { + .ok => continue, + .timeout => return .timeout, + .closed => return .closed, + } + }, + else => return err, + }; + + // Check if we have the full frame (header + payload) + const payload_len = @as(usize, @intCast(header.payload_len)); + const total_frame_len = @as(usize, @intCast(header.totalLen())); + if (available.len < total_frame_len) { + switch (try self.fillBufferResult()) { + .ok => continue, + .timeout => return .timeout, + .closed => return .closed, + } + } + + // Validate frame + try header.validate(); + try header.validateClientBound(); + + // Extract payload + const payload_start = self.read_pos + header.header_len; + const payload_end = payload_start + payload_len; + const payload = self.read_buf[payload_start..payload_end]; + + // Map opcode to message type + const msg_type: Message.Type = switch (header.opcode) { + .text => .text, + .binary => .binary, + .close => .close, + .ping => .ping, + .pong => .pong, + .continuation => .text, // shouldn't happen in tests + }; + + // Dupe payload (caller-owned) + const data = try self.allocator.dupe(u8, payload); + + // Advance past consumed frame + self.read_pos += total_frame_len; + + // Compact: shift remaining bytes to front + if (self.read_pos > 0) { + const remaining = self.read_end - self.read_pos; + if (remaining > 0) { + const src = self.read_buf[self.read_pos..self.read_end]; + std.mem.copyForwards(u8, self.read_buf[0..remaining], src); + } + self.read_end = remaining; + self.read_pos = 0; + } + + return .{ .message = Message{ .type = msg_type, .data = data } }; + } + } + + /// Free the heap-allocated data from a received Message. + pub fn done(self: *RawClient, msg: Message) void { + self.allocator.free(msg.data); + } + + // ==================================================================== + // Sugar methods + // ==================================================================== + + /// Send a text frame. + pub fn write(self: *RawClient, data: []u8) !void { + try self.writeFrame(.text, data); + } + + pub const CloseOpts = struct { + code: u16, + reason: []const u8 = "", + }; + + /// Send a close frame with status code and optional reason. + /// Reason is truncated to 123 bytes (control frame payload max 125, minus 2 for code). + pub fn close(self: *RawClient, opts: CloseOpts) !void { + var buf: [125]u8 = undefined; + std.mem.writeInt(u16, buf[0..2], opts.code, .big); + const reason_len = @min(opts.reason.len, 123); + if (reason_len > 0) { + @memcpy(buf[2..][0..reason_len], opts.reason[0..reason_len]); + } + try self.writeFrame(.close, buf[0 .. 2 + reason_len]); + } + + /// Polls `readResult()` until any frame arrives or `deadline_ms` elapses. + /// Returns an owned message that must be freed via `done()`. + pub fn waitForMessage(self: *RawClient, deadline_ms: u64) !Message { + var timer = try std.time.Timer.start(); + const deadline_ns = deadline_ms * std.time.ns_per_ms; + + while (timer.read() < deadline_ns) { + switch (try self.readResult()) { + .timeout => continue, + .closed => return error.NoResponse, + .message => |msg| return msg, + } + } + return error.NoResponse; + } + + /// Polls until a frame of `expected_type` arrives or `deadline_ms` elapses. 
+ /// Fails if a different frame type is received first. + pub fn waitForMessageType( + self: *RawClient, + expected_type: Message.Type, + deadline_ms: u64, + ) !Message { + const msg = try self.waitForMessage(deadline_ms); + if (msg.type != expected_type) { + self.done(msg); + return error.UnexpectedData; + } + return msg; + } + + /// Polls `readResult()` until a close frame arrives or `deadline_ms` elapses. + /// Returns an owned message that must be freed via `done()`. + pub fn waitForCloseFrame(self: *RawClient, deadline_ms: u64) !Message { + return self.waitForMessageType(.close, deadline_ms); + } + + /// Polls `readResult()` until peer disconnect or `deadline_ms` elapses. + /// Fails if any frame is received before disconnect. + pub fn waitForClosedNoData(self: *RawClient, deadline_ms: u64) !void { + var timer = try std.time.Timer.start(); + const deadline_ns = deadline_ms * std.time.ns_per_ms; + + while (timer.read() < deadline_ns) { + switch (try self.readResult()) { + .closed => return, + .timeout => continue, + .message => |msg| { + self.done(msg); + return error.UnexpectedData; + }, + } + } + return error.ExpectedDisconnect; + } + + // ==================================================================== + // Internal + // ==================================================================== + + const FillResult = enum { ok, timeout, closed }; + + /// Try to read more data into the buffer. + fn fillBufferResult(self: *RawClient) !FillResult { + if (self.read_end >= self.read_buf.len) { + // Compact first if possible + if (self.read_pos > 0) { + const remaining = self.read_end - self.read_pos; + if (remaining > 0) { + const src = self.read_buf[self.read_pos..self.read_end]; + std.mem.copyForwards(u8, self.read_buf[0..remaining], src); + } + self.read_end = remaining; + self.read_pos = 0; + } else { + return error.BufferFull; + } + } + + const n = self.stream.read(self.read_buf[self.read_end..]) catch |err| switch (err) { + error.WouldBlock => return .timeout, + else => return err, + }; + if (n == 0) return .closed; + self.read_end += n; + return .ok; + } +}; diff --git a/src/rpc/webzockets/e2e_tests/support/server_handlers.zig b/src/rpc/webzockets/e2e_tests/support/server_handlers.zig new file mode 100644 index 0000000000..8dcba56448 --- /dev/null +++ b/src/rpc/webzockets/e2e_tests/support/server_handlers.zig @@ -0,0 +1,576 @@ +const std = @import("std"); +const ws = @import("webzockets_lib"); + +/// Server-side handler that closes the connection immediately on open. +/// Used by client close_tests.zig to test server-initiated close. +pub const CloseOnOpenHandler = struct { + pub const Context = void; + + pub fn init(_: ws.http.Request, _: void) !CloseOnOpenHandler { + return .{}; + } + + pub fn onOpen(_: *CloseOnOpenHandler, conn: anytype) void { + conn.close(.normal, "server closing"); + } + + pub fn onMessage(_: *CloseOnOpenHandler, _: anytype, _: ws.Message) void {} + pub fn onWriteComplete(_: *CloseOnOpenHandler, _: anytype) void {} + pub fn onClose(_: *CloseOnOpenHandler, _: anytype) void {} +}; + +/// Server-side handler that sends a ping immediately on open, then closes +/// after receiving the client's pong response. +/// Used by client ping_pong_tests.zig to verify automatic pong response. 
+pub const PingOnOpenHandler = struct { + pub const Context = void; + + pub fn init(_: ws.http.Request, _: void) !PingOnOpenHandler { + return .{}; + } + + pub fn onOpen(_: *PingOnOpenHandler, conn: anytype) void { + conn.sendPing("hello") catch return; + } + + pub fn onMessage(_: *PingOnOpenHandler, _: anytype, _: ws.Message) void {} + pub fn onWriteComplete(_: *PingOnOpenHandler, _: anytype) void {} + + pub fn onPong(_: *PingOnOpenHandler, conn: anytype, _: []const u8) void { + conn.close(.normal, ""); + } + + pub fn onClose(_: *PingOnOpenHandler, _: anytype) void {} +}; + +/// Server-side handler that sends an unsolicited pong immediately on open. +/// Used by client ping_pong_tests.zig to verify client onPong callback. +/// Does not initiate close — the test client is expected to close after +/// receiving the pong (see PongTrackingHandler). +pub const PongOnOpenHandler = struct { + pub const Context = void; + + pub fn init(_: ws.http.Request, _: void) !PongOnOpenHandler { + return .{}; + } + + pub fn onOpen(_: *PongOnOpenHandler, conn: anytype) void { + conn.sendPong("hello") catch return; + } + + pub fn onMessage(_: *PongOnOpenHandler, _: anytype, _: ws.Message) void {} + pub fn onWriteComplete(_: *PongOnOpenHandler, _: anytype) void {} + pub fn onClose(_: *PongOnOpenHandler, _: anytype) void {} +}; + +/// Server-side handler that sends a message larger than the client's expected +/// max_message_size on open. Used by client max_message_tests.zig. +pub const SendOversizedOnOpenHandler = struct { + pub const Context = void; + + sent_data: ?[]const u8 = null, + + pub fn init(_: ws.http.Request, _: void) !SendOversizedOnOpenHandler { + return .{}; + } + + pub fn onOpen(self: *SendOversizedOnOpenHandler, conn: anytype) void { + // Send a 2048-byte message (client will be configured with max_message_size = 1024) + const payload = conn.allocator.alloc(u8, 2048) catch return; + @memset(payload, 'X'); + conn.sendBinary(payload) catch { + conn.allocator.free(payload); + return; + }; + self.sent_data = payload; + } + + pub fn onMessage(_: *SendOversizedOnOpenHandler, _: anytype, _: ws.Message) void {} + + pub fn onWriteComplete(self: *SendOversizedOnOpenHandler, conn: anytype) void { + if (self.sent_data) |data| { + conn.allocator.free(data); + self.sent_data = null; + } + } + + pub fn onClose(self: *SendOversizedOnOpenHandler, conn: anytype) void { + if (self.sent_data) |data| { + conn.allocator.free(data); + self.sent_data = null; + } + } +}; + +/// Server-side handler that echoes the first message, then closes the connection. +/// Used by server timeout_tests.zig to test close timeout behavior. 
+pub const CloseAfterFirstMessageHandler = struct { + pub const Context = void; + + got_first: bool = false, + sent_data: ?[]const u8 = null, + + pub fn init(_: ws.http.Request, _: void) !CloseAfterFirstMessageHandler { + return .{}; + } + + pub fn onMessage(self: *CloseAfterFirstMessageHandler, conn: anytype, message: ws.Message) void { + if (self.got_first) return; + self.got_first = true; + + switch (message.type) { + .text, .binary => { + const copy = conn.allocator.dupe(u8, message.data) catch return; + if (message.type == .text) { + conn.sendText(copy) catch { + conn.allocator.free(copy); + return; + }; + } else { + conn.sendBinary(copy) catch { + conn.allocator.free(copy); + return; + }; + } + self.sent_data = copy; + }, + else => {}, + } + } + + pub fn onWriteComplete(self: *CloseAfterFirstMessageHandler, conn: anytype) void { + if (self.sent_data) |data| { + conn.allocator.free(data); + self.sent_data = null; + } + // Close after the echo completes + conn.close(.normal, "closing after first message"); + } + + pub fn onClose(self: *CloseAfterFirstMessageHandler, conn: anytype) void { + if (self.sent_data) |data| { + conn.allocator.free(data); + self.sent_data = null; + } + } +}; + +/// Server-side handler that always rejects connections by returning an error +/// from init. Used by client connection_tests.zig. +pub const RejectOnInitHandler = struct { + pub const Context = void; + + pub fn init(_: ws.http.Request, _: void) !RejectOnInitHandler { + return error.ConnectionRejected; + } + + pub fn onMessage(_: *RejectOnInitHandler, _: anytype, _: ws.Message) void {} + pub fn onWriteComplete(_: *RejectOnInitHandler, _: anytype) void {} + pub fn onClose(_: *RejectOnInitHandler, _: anytype) void {} +}; + +/// Server-side handler that echoes any text/binary message back to the client. +/// +/// Copies inbound data into owned allocations and queues writes so only one +/// write is in flight at a time. 
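+/// The queue exists because only one data write may be outstanding per
+/// connection: drainQueue submits only while `sent_data == null`, and
+/// onWriteComplete re-drains once the in-flight buffer is freed.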
+pub const EchoHandler = struct { + pub const Context = void; + + const PendingMessage = struct { + data: []const u8, + is_text: bool, + next: ?*PendingMessage = null, + }; + + sent_data: ?[]const u8 = null, + queue_head: ?*PendingMessage = null, + queue_tail: ?*PendingMessage = null, + + pub fn init(_: ws.http.Request, _: void) !EchoHandler { + return .{}; + } + + pub fn onMessage(self: *EchoHandler, conn: anytype, message: ws.Message) void { + switch (message.type) { + .text, .binary => { + const copy = conn.allocator.dupe(u8, message.data) catch return; + const msg = conn.allocator.create(PendingMessage) catch { + conn.allocator.free(copy); + return; + }; + msg.* = .{ + .data = copy, + .is_text = message.type == .text, + }; + if (self.queue_tail) |tail| { + tail.next = msg; + } else { + self.queue_head = msg; + } + self.queue_tail = msg; + self.drainQueue(conn); + }, + else => {}, + } + } + + fn drainQueue(self: *EchoHandler, conn: anytype) void { + while (self.queue_head) |msg| { + if (self.sent_data != null) return; + self.queue_head = msg.next; + if (self.queue_head == null) self.queue_tail = null; + + if (msg.is_text) { + conn.sendText(msg.data) catch { + conn.allocator.free(msg.data); + conn.allocator.destroy(msg); + continue; + }; + } else { + conn.sendBinary(msg.data) catch { + conn.allocator.free(msg.data); + conn.allocator.destroy(msg); + continue; + }; + } + + self.sent_data = msg.data; + conn.allocator.destroy(msg); + return; + } + } + + pub fn onWriteComplete(self: *EchoHandler, conn: anytype) void { + if (self.sent_data) |data| { + conn.allocator.free(data); + self.sent_data = null; + } + self.drainQueue(conn); + } + + pub fn onClose(self: *EchoHandler, conn: anytype) void { + if (self.sent_data) |data| { + conn.allocator.free(data); + } + while (self.queue_head) |msg| { + self.queue_head = msg.next; + conn.allocator.free(msg.data); + conn.allocator.destroy(msg); + } + self.queue_tail = null; + } +}; + +/// Server-side handler that pauses reads on open and waits (via onBytesRead + +/// peekBufferedBytes) until a configurable byte threshold is reached before +/// resuming. Echoes each text/binary message back to the client (one write +/// at a time). Optionally closes after a configured number of messages. 
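+///
+/// Context sketch (values illustrative; the context is passed to `init` by
+/// pointer via the server setup in test_servers.zig):
+/// ```zig
+/// var ctx: PauseUntilBufferedEchoHandler.Context = .{
+///     .resume_threshold = 1024,
+///     .expected_messages = 8,
+/// };
+/// ```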
+pub const PauseUntilBufferedEchoHandler = struct { + pub const Context = struct { + resume_threshold: usize, + expected_messages: usize = 0, + }; + + const PendingMessage = struct { + data: []const u8, + is_text: bool, + next: ?*PendingMessage = null, + }; + + resume_threshold: usize, + expected_messages: usize, + send_count: usize = 0, + sent_data: ?[]const u8 = null, + queue_head: ?*PendingMessage = null, + queue_tail: ?*PendingMessage = null, + + pub fn init(_: ws.http.Request, ctx: *Context) !PauseUntilBufferedEchoHandler { + return .{ + .resume_threshold = ctx.resume_threshold, + .expected_messages = ctx.expected_messages, + }; + } + + pub fn onOpen(_: *PauseUntilBufferedEchoHandler, conn: anytype) void { + conn.pauseReads(); + } + + pub fn onBytesRead(self: *PauseUntilBufferedEchoHandler, conn: anytype, _: usize) void { + if (conn.peekBufferedBytes().len >= self.resume_threshold) { + conn.resumeReads(); + } + } + + pub fn onMessage(self: *PauseUntilBufferedEchoHandler, conn: anytype, message: ws.Message) void { + switch (message.type) { + .text, .binary => {}, + else => return, + } + + const copy = conn.allocator.dupe(u8, message.data) catch return; + const msg = conn.allocator.create(PendingMessage) catch { + conn.allocator.free(copy); + return; + }; + msg.* = .{ .data = copy, .is_text = message.type == .text }; + if (self.queue_tail) |tail| { + tail.next = msg; + } else { + self.queue_head = msg; + } + self.queue_tail = msg; + self.drainQueue(conn); + } + + fn drainQueue(self: *PauseUntilBufferedEchoHandler, conn: anytype) void { + while (self.queue_head) |msg| { + if (self.sent_data != null) return; + self.queue_head = msg.next; + if (self.queue_head == null) self.queue_tail = null; + + if (msg.is_text) { + conn.sendText(msg.data) catch { + conn.allocator.free(msg.data); + conn.allocator.destroy(msg); + continue; + }; + } else { + conn.sendBinary(msg.data) catch { + conn.allocator.free(msg.data); + conn.allocator.destroy(msg); + continue; + }; + } + + self.sent_data = msg.data; + conn.allocator.destroy(msg); + return; + } + } + + pub fn onWriteComplete(self: *PauseUntilBufferedEchoHandler, conn: anytype) void { + if (self.sent_data) |data| { + conn.allocator.free(data); + self.sent_data = null; + } + self.send_count += 1; + if (self.expected_messages != 0 and self.send_count >= self.expected_messages) { + conn.close(.normal, "done"); + return; + } + self.drainQueue(conn); + } + + pub fn onClose(self: *PauseUntilBufferedEchoHandler, conn: anytype) void { + if (self.sent_data) |data| { + conn.allocator.free(data); + } + while (self.queue_head) |msg| { + self.queue_head = msg.next; + conn.allocator.free(msg.data); + conn.allocator.destroy(msg); + } + self.queue_tail = null; + } +}; + +/// Server-side handler that pauses on open, waits for a byte threshold, then +/// resumes. After the initial resume, each onMessage pauses reads, echoes the +/// message, and resumes in onWriteComplete. This tests the processMessages +/// loop breaking when read_paused is set mid-loop, and re-entering from +/// onWriteComplete via resumeReads. 
+pub const PauseMidStreamEchoHandler = struct { + pub const Context = struct { + resume_threshold: usize, + expected_messages: usize, + }; + + resume_threshold: usize, + expected_messages: usize, + initial_resumed: bool = false, + recv_count: usize = 0, + sent_data: ?[]const u8 = null, + + pub fn init(_: ws.http.Request, ctx: *Context) !PauseMidStreamEchoHandler { + return .{ + .resume_threshold = ctx.resume_threshold, + .expected_messages = ctx.expected_messages, + }; + } + + pub fn onOpen(_: *PauseMidStreamEchoHandler, conn: anytype) void { + conn.pauseReads(); + } + + pub fn onBytesRead(self: *PauseMidStreamEchoHandler, conn: anytype, _: usize) void { + if (!self.initial_resumed and conn.peekBufferedBytes().len >= self.resume_threshold) { + self.initial_resumed = true; + conn.resumeReads(); + } + } + + pub fn onMessage(self: *PauseMidStreamEchoHandler, conn: anytype, message: ws.Message) void { + switch (message.type) { + .text, .binary => {}, + else => return, + } + + self.recv_count += 1; + conn.pauseReads(); + + const copy = conn.allocator.dupe(u8, message.data) catch { + conn.resumeReads(); + return; + }; + + if (message.type == .text) { + conn.sendText(copy) catch { + conn.allocator.free(copy); + conn.resumeReads(); + return; + }; + } else { + conn.sendBinary(copy) catch { + conn.allocator.free(copy); + conn.resumeReads(); + return; + }; + } + + self.sent_data = copy; + } + + pub fn onWriteComplete(self: *PauseMidStreamEchoHandler, conn: anytype) void { + if (self.sent_data) |data| { + conn.allocator.free(data); + self.sent_data = null; + } + if (self.expected_messages != 0 and self.recv_count >= self.expected_messages) { + conn.close(.normal, "done"); + return; + } + conn.resumeReads(); + } + + pub fn onClose(self: *PauseMidStreamEchoHandler, conn: anytype) void { + if (self.sent_data) |data| { + conn.allocator.free(data); + self.sent_data = null; + } + } +}; + +/// Server-side handler that sends configured messages on open (one per write +/// completion) and then closes. +pub const SendMessagesOnOpenHandler = struct { + pub const Context = struct { + messages: []const []const u8, + close_reason: []const u8 = "done", + }; + + messages: []const []const u8, + close_reason: []const u8, + next_index: usize = 0, + sent_data: ?[]const u8 = null, + + pub fn init(_: ws.http.Request, ctx: *Context) !SendMessagesOnOpenHandler { + return .{ + .messages = ctx.messages, + .close_reason = ctx.close_reason, + }; + } + + pub fn onOpen(self: *SendMessagesOnOpenHandler, conn: anytype) void { + self.sendNext(conn); + } + + fn sendNext(self: *SendMessagesOnOpenHandler, conn: anytype) void { + if (self.sent_data != null) return; + + if (self.next_index >= self.messages.len) { + conn.close(.normal, self.close_reason); + return; + } + + const copy = conn.allocator.dupe(u8, self.messages[self.next_index]) catch return; + conn.sendText(copy) catch { + conn.allocator.free(copy); + return; + }; + self.sent_data = copy; + self.next_index += 1; + } + + pub fn onMessage(_: *SendMessagesOnOpenHandler, _: anytype, _: ws.Message) void {} + + pub fn onWriteComplete(self: *SendMessagesOnOpenHandler, conn: anytype) void { + if (self.sent_data) |data| { + conn.allocator.free(data); + self.sent_data = null; + } + self.sendNext(conn); + } + + pub fn onClose(self: *SendMessagesOnOpenHandler, conn: anytype) void { + if (self.sent_data) |data| { + conn.allocator.free(data); + self.sent_data = null; + } + } +}; + +/// Server-side handler that detects re-entrant onMessage dispatch. 
+/// +/// Pauses reads on open, waits for a byte threshold via onBytesRead + +/// peekBufferedBytes, then resumes. All buffered messages dispatch +/// synchronously. In each onMessage the handler calls pauseReads() then +/// resumeReads() — without the re-entrancy guard in processMessages this +/// would recursively dispatch the next buffered message while still inside +/// onMessage. Closes with policy_violation if re-entrancy is detected, or +/// normal when it receives a "done" sentinel. +pub const ReentrancyDetectHandler = struct { + pub const Context = struct { + resume_threshold: usize, + }; + + resume_threshold: usize, + in_on_message: bool = false, + + pub fn init(_: ws.http.Request, ctx: *Context) !ReentrancyDetectHandler { + return .{ + .resume_threshold = ctx.resume_threshold, + }; + } + + pub fn onOpen(_: *ReentrancyDetectHandler, conn: anytype) void { + conn.pauseReads(); + } + + pub fn onBytesRead(self: *ReentrancyDetectHandler, conn: anytype, _: usize) void { + if (conn.peekBufferedBytes().len >= self.resume_threshold) { + conn.resumeReads(); + } + } + + pub fn onMessage(self: *ReentrancyDetectHandler, conn: anytype, message: ws.Message) void { + if (message.type != .text and message.type != .binary) return; + + if (self.in_on_message) { + conn.close(.policy_violation, "reentrant"); + return; + } + self.in_on_message = true; + defer self.in_on_message = false; + + conn.pauseReads(); + conn.resumeReads(); + + if (std.mem.eql(u8, message.data, "done")) { + conn.close(.normal, "ok"); + } + } + + pub fn onWriteComplete(_: *ReentrancyDetectHandler, _: anytype) void {} + + pub fn onClose(_: *ReentrancyDetectHandler, _: anytype) void {} +}; diff --git a/src/rpc/webzockets/e2e_tests/support/server_runner.zig b/src/rpc/webzockets/e2e_tests/support/server_runner.zig new file mode 100644 index 0000000000..7d4b85358f --- /dev/null +++ b/src/rpc/webzockets/e2e_tests/support/server_runner.zig @@ -0,0 +1,117 @@ +const std = @import("std"); +const xev = @import("xev"); + +/// Generic in-process WebSocket server runner for e2e tests. +/// +/// Responsibilities: +/// - Creates a dedicated `xev.ThreadPool` and `xev.Loop`. +/// - Initializes the server with caller-provided init options. +/// - Binds to an ephemeral port (via init options) and exposes it as `port`. +/// - Runs the server loop on a dedicated thread. +/// - Provides `stop()` to shutdown server + loop + thread pool. 
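+///
+/// A minimal usage sketch (`EchoServer` here stands in for any `ws.Server`
+/// instantiation meeting the requirements below, e.g. the one defined in
+/// `test_servers.zig`):
+///
+/// ```zig
+/// const Runner = ServerRunner(EchoServer);
+/// const runner = try Runner.start(allocator, .{
+///     .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, 0), // ephemeral port
+///     .handler_context = {},
+/// });
+/// defer runner.stop();
+/// // runner.port holds the port the OS assigned to the listener.
+/// ```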
+/// +/// The `ServerType` must provide: +/// - `pub fn init(allocator: Allocator, loop: *xev.Loop, opts: anytype) !ServerType` +/// - `pub fn deinit(self: *ServerType) void` +/// - `pub fn accept(self: *ServerType) void` +/// - `pub fn shutdown(self: *ServerType, timeout_ms: u32, Ctx: type, ctx: ?*Ctx, cb: anytype) void` +/// - field `listen_socket.fd` +pub fn ServerRunner(comptime ServerType: type) type { + return struct { + const ServerRunnerSelf = @This(); + + loop: *xev.Loop, + thread_pool: xev.ThreadPool, + server: *ServerType, + stop_notifier: xev.Async, + stop_completion: xev.Completion, + thread: std.Thread, + allocator: std.mem.Allocator, + port: u16, + + fn runLoop(loop: *xev.Loop) void { + loop.run(.until_done) catch {}; + } + + pub fn start(allocator: std.mem.Allocator, config: ServerType.Config) !*ServerRunnerSelf { + const self = try allocator.create(ServerRunnerSelf); + errdefer allocator.destroy(self); + + self.allocator = allocator; + + self.thread_pool = xev.ThreadPool.init(.{}); + errdefer { + self.thread_pool.shutdown(); + self.thread_pool.deinit(); + } + + self.loop = try allocator.create(xev.Loop); + errdefer allocator.destroy(self.loop); + self.loop.* = try xev.Loop.init(.{ .thread_pool = &self.thread_pool }); + errdefer self.loop.deinit(); + + self.server = try allocator.create(ServerType); + errdefer allocator.destroy(self.server); + self.server.* = try ServerType.init(allocator, self.loop, config); + errdefer self.server.deinit(); + + self.port = getAssignedPortFromFd(self.server.listen_socket.fd); + + self.stop_notifier = try xev.Async.init(); + errdefer self.stop_notifier.deinit(); + self.stop_completion = .{}; + self.stop_notifier.wait( + self.loop, + &self.stop_completion, + ServerRunnerSelf, + self, + stopCallback, + ); + + self.server.accept(); + self.thread = try std.Thread.spawn(.{}, runLoop, .{self.loop}); + return self; + } + + fn stopCallback( + self_opt: ?*ServerRunnerSelf, + _: *xev.Loop, + _: *xev.Completion, + _: xev.Async.WaitError!void, + ) xev.CallbackAction { + if (self_opt) |self| { + self.server.shutdown(5000, ServerRunnerSelf, self, onShutdownComplete); + } + return .disarm; + } + + fn onShutdownComplete(self_opt: ?*ServerRunnerSelf, _: ServerType.ShutdownResult) void { + if (self_opt) |self| { + self.loop.stop(); + } + } + + pub fn stop(self: *ServerRunnerSelf) void { + self.stop_notifier.notify() catch {}; + self.thread.join(); + + self.server.deinit(); + self.stop_notifier.deinit(); + self.loop.deinit(); + self.thread_pool.shutdown(); + self.thread_pool.deinit(); + + self.allocator.destroy(self.server); + self.allocator.destroy(self.loop); + self.allocator.destroy(self); + } + }; +} + +pub fn getAssignedPortFromFd(fd: std.posix.fd_t) u16 { + var addr: std.posix.sockaddr.storage = undefined; + var addr_len: std.posix.socklen_t = @sizeOf(@TypeOf(addr)); + std.posix.getsockname(fd, @ptrCast(&addr), &addr_len) catch return 0; + const sa4: *const std.posix.sockaddr.in = @ptrCast(@alignCast(&addr)); + return std.mem.bigToNative(u16, sa4.port); +} diff --git a/src/rpc/webzockets/e2e_tests/support/test_clients.zig b/src/rpc/webzockets/e2e_tests/support/test_clients.zig new file mode 100644 index 0000000000..60bd02e4eb --- /dev/null +++ b/src/rpc/webzockets/e2e_tests/support/test_clients.zig @@ -0,0 +1,146 @@ +const std = @import("std"); +const xev = @import("xev"); +const ws = @import("webzockets_lib"); + +const client_handlers = @import("client_handlers.zig"); + +/// Client-side handler used by basic e2e tests. 
+///
+/// See `support/client_handlers.zig`.
+pub const EchoTestHandler = client_handlers.SendOnceHandler;
+pub const EchoSendKind = client_handlers.SendOnceHandler.SendKind;
+
+/// Client-side handler that closes immediately on open.
+pub const CloseOnOpenHandler = client_handlers.CloseOnOpenHandler;
+
+/// Client-side handler that tracks open_called but takes no action.
+pub const NoOpHandler = client_handlers.NoOpHandler;
+
+/// Sequence handler for multi-message tests.
+pub const SequenceHandler = client_handlers.SequenceHandler;
+
+/// Client-side handler that waits for server-initiated close.
+pub const ServerCloseHandler = client_handlers.ServerCloseHandler;
+
+/// Client-side handler that tracks pong reception.
+pub const PongTrackingHandler = client_handlers.PongTrackingHandler;
+
+/// Client-side handler for max_message_size tests.
+pub const MaxMessageHandler = client_handlers.MaxMessageHandler;
+
+/// Client-side handler that detects connection/handshake failures.
+pub const ConnectFailHandler = client_handlers.ConnectFailHandler;
+
+/// Client-side handler that explicitly manages pong responses via onPing.
+pub const ExplicitPongHandler = client_handlers.ExplicitPongHandler;
+
+/// Client-side handler that pauses reads on open and waits for enough data
+/// to buffer before resuming. Used for deterministic burst-processing tests.
+pub const PauseUntilBufferedClientHandler = client_handlers.PauseUntilBufferedClientHandler;
+
+/// Client-side handler that pauses on open, threshold-resumes, then pauses
+/// per-message with resume in onWriteComplete.
+pub const PauseMidStreamClientHandler = client_handlers.PauseMidStreamClientHandler;
+
+/// Client-side handler that detects re-entrant onMessage dispatch via
+/// pauseReads/resumeReads while messages are buffered.
+pub const ReentrancyDetectClientHandler = client_handlers.ReentrancyDetectClientHandler;
+
+/// WebSocket client type paired with `EchoTestHandler`.
+pub const TestEchoClient = ws.Client(EchoTestHandler, 4096);
+
+/// WebSocket client type for close-on-open tests.
+pub const TestCloseClient = ws.Client(CloseOnOpenHandler, 4096);
+
+/// WebSocket client type for rejection tests.
+pub const TestNoOpClient = ws.Client(NoOpHandler, 4096);
+
+/// WebSocket client type for sequence tests.
+pub const TestSequenceClient = ws.Client(SequenceHandler, 4096);
+
+/// WebSocket client type for server-initiated close tests.
+pub const TestServerCloseClient = ws.Client(ServerCloseHandler, 4096);
+
+/// WebSocket client type for pong tracking tests.
+pub const TestPongTrackingClient = ws.Client(PongTrackingHandler, 4096);
+
+/// WebSocket client type for max_message_size tests (small read buffer).
+pub const TestMaxMessageClient = ws.Client(MaxMessageHandler, 4096);
+
+/// WebSocket client type for connection failure tests.
+pub const TestConnectFailClient = ws.Client(ConnectFailHandler, 4096);
+
+/// WebSocket client type for explicit pong handler tests.
+pub const TestExplicitPongClient = ws.Client(ExplicitPongHandler, 4096);
+
+/// WebSocket client type for deterministic pause-until-buffered tests.
+pub const TestPauseUntilBufferedClient = ws.Client(PauseUntilBufferedClientHandler, 4096);
+
+/// WebSocket client type for deterministic pause-until-buffered tests with small read buffer.
+pub const TestPauseUntilBufferedSmallBufClient = ws.Client(PauseUntilBufferedClientHandler, 256);
+
+/// WebSocket client type for pause-mid-stream tests.
+pub const TestPauseMidStreamClient = ws.Client(PauseMidStreamClientHandler, 4096); + +/// WebSocket client type for re-entrancy detection tests. +pub const TestReentrancyDetectClient = ws.Client(ReentrancyDetectClientHandler, 4096); + +/// Bundles the thread pool, event loop, and CSPRNG needed by +/// client-side e2e tests. Uses pointer-stable init via `start()` on an +/// existing instance. Must not be moved or copied after `start()`. +/// +/// ```zig +/// var env: clients.TestEnv = undefined; +/// try env.start(); +/// defer env.deinit(); +/// +/// var client = env.initClient(clients.TestEchoClient, &handler, &conn, .{ +/// .address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, port), +/// }); +/// ``` +pub const TestEnv = struct { + tp: xev.ThreadPool, + loop: xev.Loop, + csprng: ws.ClientMaskPRNG, + + /// Initialize in place. The loop stores a pointer to `self.tp`, so the + /// struct must not be moved or copied after `start()` returns. + pub fn start(self: *TestEnv) !void { + self.tp = xev.ThreadPool.init(.{}); + errdefer { + self.tp.shutdown(); + self.tp.deinit(); + } + + self.loop = try xev.Loop.init(.{ .thread_pool = &self.tp }); + errdefer self.loop.deinit(); + + var seed: [ws.ClientMaskPRNG.secret_seed_length]u8 = undefined; + std.crypto.random.bytes(&seed); + self.csprng = ws.ClientMaskPRNG.init(seed); + } + + /// Create a client wired to this env's loop, buffer pool, and CSPRNG. + pub fn initClient( + self: *TestEnv, + comptime ClientType: type, + handler: anytype, + conn: anytype, + config: ClientType.Config, + ) ClientType { + return ClientType.init( + std.testing.allocator, + &self.loop, + handler, + conn, + &self.csprng, + config, + ); + } + + pub fn deinit(self: *TestEnv) void { + self.loop.deinit(); + self.tp.shutdown(); + self.tp.deinit(); + } +}; diff --git a/src/rpc/webzockets/e2e_tests/support/test_helpers.zig b/src/rpc/webzockets/e2e_tests/support/test_helpers.zig new file mode 100644 index 0000000000..163fd57d8e --- /dev/null +++ b/src/rpc/webzockets/e2e_tests/support/test_helpers.zig @@ -0,0 +1,47 @@ +const std = @import("std"); +const testing = std.testing; +const RawClient = @import("raw_client.zig").RawClient; + +const default_wait_ms: u64 = 2_000; + +/// Connect to the server, complete a WebSocket handshake, send a text message, +/// verify the echo, and close. Used to confirm the server is still functional +/// after stress or error scenarios. +pub fn verifyServerFunctional(port: u16) !void { + var client = try RawClient.connect(testing.allocator, port); + defer client.deinit(); + + var msg = "hi".*; + try client.write(&msg); + + const response = try client.waitForMessageType(.text, default_wait_ms); + defer client.done(response); + try testing.expectEqualSlices(u8, "hi", response.data); + + try client.close(.{ .code = 1000 }); + const close_resp = try client.waitForCloseFrame(default_wait_ms); + client.done(close_resp); +} + +/// Read a close frame from the client and assert the close code matches. +pub fn expectCloseWithCode(client: *RawClient, expected_code: u16) !void { + const response = try client.waitForCloseFrame(default_wait_ms); + defer client.done(response); + try testing.expect(response.data.len >= 2); + const code = std.mem.readInt(u16, response.data[0..2], .big); + try testing.expectEqual(expected_code, code); +} + +/// Read a close frame and assert both the close code and reason match. 
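+///
+/// Sketch, assuming `port` points at a test server that closes with code 1000
+/// and reason "done" (e.g. a `SendMessagesOnOpenHandler` server):
+///
+/// ```zig
+/// var client = try RawClient.connect(testing.allocator, port);
+/// defer client.deinit();
+/// try expectCloseWithCodeAndReason(&client, 1000, "done");
+/// ```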
+pub fn expectCloseWithCodeAndReason( + client: *RawClient, + expected_code: u16, + expected_reason: []const u8, +) !void { + const response = try client.waitForCloseFrame(default_wait_ms); + defer client.done(response); + try testing.expect(response.data.len >= 2); + const code = std.mem.readInt(u16, response.data[0..2], .big); + try testing.expectEqual(expected_code, code); + try testing.expectEqualSlices(u8, expected_reason, response.data[2..]); +} diff --git a/src/rpc/webzockets/e2e_tests/support/test_servers.zig b/src/rpc/webzockets/e2e_tests/support/test_servers.zig new file mode 100644 index 0000000000..7ef7931033 --- /dev/null +++ b/src/rpc/webzockets/e2e_tests/support/test_servers.zig @@ -0,0 +1,244 @@ +const std = @import("std"); +const ws = @import("webzockets_lib"); + +const server_runner = @import("server_runner.zig"); +const server_handlers = @import("server_handlers.zig"); + +/// Localhost ephemeral-port address used by all test servers. +pub const localhost = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, 0); + +/// Default read buffer size used by the test servers. +pub const default_read_buf_size: usize = 4096; + +/// Server-side handler that echoes any text/binary message back to the client. +pub const EchoHandler = server_handlers.EchoHandler; + +/// Generic server runner for creating custom test servers. +pub const ServerRunner = server_runner.ServerRunner; + +/// WebSocket server type used by most e2e tests (echo server). +pub const EchoServer = ws.Server(EchoHandler, default_read_buf_size); + +/// In-process echo server runner used by most e2e tests. +pub const TestServer = server_runner.ServerRunner(EchoServer); + +pub fn startTestServer(allocator: std.mem.Allocator) !*TestServer { + return try TestServer.start(allocator, .{ .address = localhost, .handler_context = {} }); +} + +/// Server that closes every connection immediately on open. +pub const CloseOnOpenServer = ws.Server(server_handlers.CloseOnOpenHandler, default_read_buf_size); +pub const CloseOnOpenTestServer = server_runner.ServerRunner(CloseOnOpenServer); + +pub fn startCloseOnOpenServer(allocator: std.mem.Allocator) !*CloseOnOpenTestServer { + return try CloseOnOpenTestServer.start(allocator, .{ + .address = localhost, + .handler_context = {}, + }); +} + +/// Server that sends a ping to every connection immediately on open. +pub const PingOnOpenServer = ws.Server(server_handlers.PingOnOpenHandler, default_read_buf_size); +pub const PingOnOpenTestServer = server_runner.ServerRunner(PingOnOpenServer); + +pub fn startPingOnOpenServer(allocator: std.mem.Allocator) !*PingOnOpenTestServer { + return try PingOnOpenTestServer.start(allocator, .{ + .address = localhost, + .handler_context = {}, + }); +} + +/// Server that sends an unsolicited pong to every connection immediately on open. +pub const PongOnOpenServer = ws.Server(server_handlers.PongOnOpenHandler, default_read_buf_size); +pub const PongOnOpenTestServer = server_runner.ServerRunner(PongOnOpenServer); + +pub fn startPongOnOpenServer(allocator: std.mem.Allocator) !*PongOnOpenTestServer { + return try PongOnOpenTestServer.start(allocator, .{ + .address = localhost, + .handler_context = {}, + }); +} + +/// Server that sends a 2048-byte message to every connection on open. 
+pub const OversizedServer = ws.Server( + server_handlers.SendOversizedOnOpenHandler, + default_read_buf_size, +); +pub const OversizedTestServer = server_runner.ServerRunner(OversizedServer); + +pub fn startOversizedServer(allocator: std.mem.Allocator) !*OversizedTestServer { + return try OversizedTestServer.start(allocator, .{ + .address = localhost, + .handler_context = {}, + }); +} + +/// Server that rejects every connection at init time (HTTP 403). +pub const RejectServer = ws.Server(server_handlers.RejectOnInitHandler, default_read_buf_size); +pub const RejectTestServer = server_runner.ServerRunner(RejectServer); + +pub fn startRejectServer(allocator: std.mem.Allocator) !*RejectTestServer { + return try RejectTestServer.start(allocator, .{ + .address = localhost, + .handler_context = {}, + }); +} + +/// CloseOnOpen server with custom timeout configuration for timeout tests. +pub fn startCloseOnOpenServerWithTimeouts( + allocator: std.mem.Allocator, + idle_timeout_ms: ?u32, + close_timeout_ms: u32, +) !*CloseOnOpenTestServer { + return try CloseOnOpenTestServer.start(allocator, .{ + .address = localhost, + .handler_context = {}, + .idle_timeout_ms = idle_timeout_ms, + .close_timeout_ms = close_timeout_ms, + }); +} + +/// Echo server with custom timeout configuration for timeout tests. +pub fn startEchoServerWithTimeouts( + allocator: std.mem.Allocator, + idle_timeout_ms: ?u32, + close_timeout_ms: u32, +) !*TestServer { + return try TestServer.start(allocator, .{ + .address = localhost, + .handler_context = {}, + .idle_timeout_ms = idle_timeout_ms, + .close_timeout_ms = close_timeout_ms, + }); +} + +/// Server that echoes the first message then closes the connection. +pub const CloseAfterFirstMessageServer = ws.Server( + server_handlers.CloseAfterFirstMessageHandler, + default_read_buf_size, +); +pub const CloseAfterFirstMessageTestServer = server_runner.ServerRunner( + CloseAfterFirstMessageServer, +); + +pub fn startCloseAfterFirstMessageServer( + allocator: std.mem.Allocator, + close_timeout_ms: u32, +) !*CloseAfterFirstMessageTestServer { + return try CloseAfterFirstMessageTestServer.start(allocator, .{ + .address = localhost, + .handler_context = {}, + .close_timeout_ms = close_timeout_ms, + }); +} + +/// Server that pauses on open, waits for byte threshold, resumes, echoes messages. +pub const PauseUntilBufferedEchoServer = ws.Server( + server_handlers.PauseUntilBufferedEchoHandler, + default_read_buf_size, +); +pub const PauseUntilBufferedEchoTestServer = + server_runner.ServerRunner(PauseUntilBufferedEchoServer); + +/// PauseUntilBufferedEcho variant with a small 256-byte read buffer. 
+pub const PauseUntilBufferedEchoSmallBufServer = ws.Server( + server_handlers.PauseUntilBufferedEchoHandler, + 256, +); +pub const PauseUntilBufferedEchoSmallBufTestServer = + server_runner.ServerRunner(PauseUntilBufferedEchoSmallBufServer); + +pub fn startPauseUntilBufferedEchoServer( + allocator: std.mem.Allocator, + ctx: *server_handlers.PauseUntilBufferedEchoHandler.Context, +) !*PauseUntilBufferedEchoTestServer { + return try PauseUntilBufferedEchoTestServer.start(allocator, .{ + .address = localhost, + .handler_context = ctx, + }); +} + +pub fn startPauseUntilBufferedEchoSmallBufServer( + allocator: std.mem.Allocator, + ctx: *server_handlers.PauseUntilBufferedEchoHandler.Context, +) !*PauseUntilBufferedEchoSmallBufTestServer { + return try PauseUntilBufferedEchoSmallBufTestServer.start(allocator, .{ + .address = localhost, + .handler_context = ctx, + }); +} + +/// Server that pauses on open, threshold resumes, then pauses per-message +/// with echo and resume in onWriteComplete. +pub const PauseMidStreamEchoServer = ws.Server( + server_handlers.PauseMidStreamEchoHandler, + default_read_buf_size, +); +pub const PauseMidStreamEchoTestServer = server_runner.ServerRunner(PauseMidStreamEchoServer); + +pub fn startPauseMidStreamEchoServer( + allocator: std.mem.Allocator, + ctx: *server_handlers.PauseMidStreamEchoHandler.Context, +) !*PauseMidStreamEchoTestServer { + return try PauseMidStreamEchoTestServer.start(allocator, .{ + .address = localhost, + .handler_context = ctx, + }); +} + +/// Server that sends configured messages on open, then closes. +pub const SendMessagesOnOpenServer = ws.Server( + server_handlers.SendMessagesOnOpenHandler, + default_read_buf_size, +); +pub const SendMessagesOnOpenTestServer = server_runner.ServerRunner(SendMessagesOnOpenServer); + +pub fn startSendMessagesOnOpenServer( + allocator: std.mem.Allocator, + ctx: *server_handlers.SendMessagesOnOpenHandler.Context, +) !*SendMessagesOnOpenTestServer { + return try SendMessagesOnOpenTestServer.start(allocator, .{ + .address = localhost, + .handler_context = ctx, + }); +} + +// 12 messages of 20 bytes each for small-buffer tests. +pub const small_buf_msg_len = 20; +pub const small_buf_msg_count = 12; + +fn makeSmallBufBufs() [small_buf_msg_count][small_buf_msg_len]u8 { + var bufs: [small_buf_msg_count][small_buf_msg_len]u8 = undefined; + for (0..small_buf_msg_count) |i| { + @memset(&bufs[i], @as(u8, @truncate('A' + i))); + } + return bufs; +} + +const small_buf_bufs = makeSmallBufBufs(); +pub const small_buf_slices = makeSmallBufSlices(); + +fn makeSmallBufSlices() [small_buf_msg_count][]const u8 { + var slices: [small_buf_msg_count][]const u8 = undefined; + for (0..small_buf_msg_count) |i| { + slices[i] = &small_buf_bufs[i]; + } + return slices; +} + +/// Server that detects re-entrant onMessage dispatch via pauseReads/resumeReads. 
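+///
+/// The handler context is read by `Handler.init` on every new connection, so
+/// the pointer passed to these `start*` helpers must outlive the server:
+///
+/// ```zig
+/// var ctx: server_handlers.ReentrancyDetectHandler.Context = .{
+///     .resume_threshold = 64,
+/// };
+/// const server = try startReentrancyDetectServer(testing.allocator, &ctx);
+/// defer server.stop();
+/// ```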
+pub const ReentrancyDetectServer = ws.Server( + server_handlers.ReentrancyDetectHandler, + default_read_buf_size, +); +pub const ReentrancyDetectTestServer = server_runner.ServerRunner(ReentrancyDetectServer); + +pub fn startReentrancyDetectServer( + allocator: std.mem.Allocator, + ctx: *server_handlers.ReentrancyDetectHandler.Context, +) !*ReentrancyDetectTestServer { + return try ReentrancyDetectTestServer.start(allocator, .{ + .address = localhost, + .handler_context = ctx, + }); +} diff --git a/src/rpc/webzockets/e2e_tests/tests.zig b/src/rpc/webzockets/e2e_tests/tests.zig new file mode 100644 index 0000000000..868ed81d6d --- /dev/null +++ b/src/rpc/webzockets/e2e_tests/tests.zig @@ -0,0 +1,11 @@ +const std = @import("std"); + +test { + // Uncomment to see logs during tests + // std.testing.log_level = .debug; +} + +comptime { + _ = @import("server/tests.zig"); + _ = @import("client/tests.zig"); +} diff --git a/src/rpc/webzockets/examples/echo_server.zig b/src/rpc/webzockets/examples/echo_server.zig new file mode 100644 index 0000000000..3f65e9e070 --- /dev/null +++ b/src/rpc/webzockets/examples/echo_server.zig @@ -0,0 +1,150 @@ +const std = @import("std"); +const xev = @import("xev"); +const ws = @import("webzockets_lib"); + +/// Default read buffer size for the echo server. +const default_read_buf_size: usize = 4096; + +/// Echo handler: sends back every text/binary message it receives. +/// Messages are copied into owned allocations and queued so that +/// back-to-back arrivals are echoed in order (only one write can +/// be in flight at a time). +/// +/// Note: We allocate a copy for every message for simplicity. Both received +/// message data (internal read buffers) and sent data (zero-copy writes) have +/// transient lifetimes, so copies are needed to safely hold onto the data. +/// +/// Uses an intrusive singly-linked list for O(1) queue operations. +const EchoHandler = struct { + pub const Context = void; + + const PendingMessage = struct { + data: []const u8, + is_text: bool, + next: ?*PendingMessage = null, + }; + + /// Data currently being written (freed in onWriteComplete). + sent_data: ?[]const u8 = null, + /// Head of pending message queue. + queue_head: ?*PendingMessage = null, + /// Tail of pending message queue (for O(1) append). 
+ queue_tail: ?*PendingMessage = null, + + pub fn init(_: ws.http.Request, _: void) !EchoHandler { + return .{}; + } + + pub fn onMessage(self: *EchoHandler, conn: *EchoServer.Conn, message: ws.Message) void { + switch (message.type) { + .text, .binary => { + std.debug.print("Received {s} ({d} bytes): {s}\n", .{ + @tagName(message.type), + message.data.len, + message.data, + }); + const copy = conn.allocator.dupe(u8, message.data) catch return; + const msg = conn.allocator.create(PendingMessage) catch { + conn.allocator.free(copy); + return; + }; + msg.* = .{ + .data = copy, + .is_text = message.type == .text, + }; + + if (self.queue_tail) |tail| { + tail.next = msg; + } else { + self.queue_head = msg; + } + self.queue_tail = msg; + self.drainQueue(conn); + }, + else => {}, + } + } + + fn drainQueue(self: *EchoHandler, conn: *EchoServer.Conn) void { + while (self.queue_head) |msg| { + if (self.sent_data != null) return; + + self.queue_head = msg.next; + if (self.queue_head == null) self.queue_tail = null; + + if (msg.is_text) { + conn.sendText(msg.data) catch { + conn.allocator.free(msg.data); + conn.allocator.destroy(msg); + continue; + }; + } else { + conn.sendBinary(msg.data) catch { + conn.allocator.free(msg.data); + conn.allocator.destroy(msg); + continue; + }; + } + + self.sent_data = msg.data; + conn.allocator.destroy(msg); + return; + } + } + + pub fn onWriteComplete(self: *EchoHandler, conn: *EchoServer.Conn) void { + if (self.sent_data) |data| { + std.debug.print("Sent ({d} bytes): {s}\n", .{ data.len, data }); + conn.allocator.free(data); + self.sent_data = null; + } + self.drainQueue(conn); + } + + pub fn onClose(self: *EchoHandler, conn: *EchoServer.Conn) void { + if (self.sent_data) |data| { + conn.allocator.free(data); + } + while (self.queue_head) |msg| { + self.queue_head = msg.next; + conn.allocator.free(msg.data); + conn.allocator.destroy(msg); + } + self.queue_tail = null; + } +}; + +const EchoServer = ws.Server(EchoHandler, default_read_buf_size); + +pub fn main() !void { + const address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, 8080); + + var thread_pool = xev.ThreadPool.init(.{}); + defer thread_pool.deinit(); + defer thread_pool.shutdown(); + + var loop = try xev.Loop.init(.{ .thread_pool = &thread_pool }); + defer loop.deinit(); + + var server = try EchoServer.init( + std.heap.page_allocator, + &loop, + .{ + .address = address, + .handler_context = {}, + .tcp_accept_backlog = 128, + .max_message_size = 16 * 1024 * 1024, + .initial_handshake_pool_size = 16, + .initial_connection_pool_size = 64, + .max_handshakes = null, + .max_connections = null, + .idle_timeout_ms = null, + .close_timeout_ms = 5_000, + }, + ); + defer server.deinit(); + + server.accept(); + std.debug.print("WebSocket echo server listening on ws://127.0.0.1:8080\n", .{}); + try loop.run(.until_done); +} diff --git a/src/rpc/webzockets/examples/simple_client.zig b/src/rpc/webzockets/examples/simple_client.zig new file mode 100644 index 0000000000..911186fa15 --- /dev/null +++ b/src/rpc/webzockets/examples/simple_client.zig @@ -0,0 +1,125 @@ +const std = @import("std"); +const xev = @import("xev"); +const ws = @import("webzockets_lib"); + +/// Simple client that sends a message, waits for a response, and disconnects. +/// +/// Note: We allocate copies for messages for simplicity. Received message data +/// points into internal read buffers with transient lifetimes, so copies are +/// needed to safely hold onto the data beyond the callback. 
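+///
+/// Client sends also mask the payload in place per RFC 6455, so a sent buffer
+/// is scrambled once `sendText` returns; log or inspect it before sending
+/// (see the note in `onOpen` below).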
+const ClientHandler = struct { + allocator: std.mem.Allocator, + payload: []const u8, + received: ?[]const u8 = null, + sent_copy: ?[]const u8 = null, + + pub fn onOpen(self: *ClientHandler, conn: *SimpleClient.Conn) void { + const copy = self.allocator.dupe(u8, self.payload) catch return; + // Print before sending — sendText masks the buffer in-place. + std.debug.print("Sent ({d} bytes): {s}\n", .{ copy.len, copy }); + conn.sendText(copy) catch { + self.allocator.free(copy); + return; + }; + self.sent_copy = copy; + } + + pub fn onMessage(self: *ClientHandler, conn: *SimpleClient.Conn, message: ws.Message) void { + std.debug.print("Received {s} ({d} bytes): {s}\n", .{ + @tagName(message.type), + message.data.len, + message.data, + }); + self.received = self.allocator.dupe(u8, message.data) catch null; + conn.close(.normal, ""); + } + + pub fn onWriteComplete(self: *ClientHandler, _: *SimpleClient.Conn) void { + if (self.sent_copy) |buf| { + std.debug.print("Write complete ({d} masked bytes): 0x{}\n", .{ + buf.len, + std.fmt.fmtSliceHexLower(buf), + }); + self.allocator.free(buf); + self.sent_copy = null; + } + } + + pub fn onClose(self: *ClientHandler, _: *SimpleClient.Conn) void { + if (self.sent_copy) |buf| { + self.allocator.free(buf); + self.sent_copy = null; + } + } +}; + +const SimpleClient = ws.Client(ClientHandler, 4096); + +fn usage(exe_name: []const u8) void { + std.debug.print( + "Usage: {s} [ip] [port] [message]\n" ++ + "Defaults: ip=127.0.0.1 port=8080 message=hello\n", + .{exe_name}, + ); +} + +pub fn main() !void { + var gpa = std.heap.GeneralPurposeAllocator(.{}){}; + defer _ = gpa.deinit(); + + const allocator = gpa.allocator(); + + const args = try std.process.argsAlloc(allocator); + defer std.process.argsFree(allocator, args); + + const exe_name = if (args.len > 0) args[0] else "echo-client"; + if (args.len > 1 and (std.mem.eql(u8, args[1], "-h") or std.mem.eql(u8, args[1], "--help"))) { + usage(exe_name); + return; + } + + const ip_str = if (args.len > 1) args[1] else "127.0.0.1"; + const port: u16 = if (args.len > 2) std.fmt.parseInt(u16, args[2], 10) catch { + usage(exe_name); + return error.InvalidPort; + } else 8080; + const address = std.net.Address.parseIp4(ip_str, port) catch { + std.debug.print("Invalid IPv4 address: {s}\n", .{ip_str}); + usage(exe_name); + return error.InvalidAddress; + }; + const msg = if (args.len > 3) args[3] else "hello"; + + var handler: ClientHandler = .{ .allocator = allocator, .payload = msg }; + + var thread_pool = xev.ThreadPool.init(.{}); + defer thread_pool.deinit(); + defer thread_pool.shutdown(); + + var loop = try xev.Loop.init(.{ .thread_pool = &thread_pool }); + defer loop.deinit(); + + var seed: [ws.ClientMaskPRNG.secret_seed_length]u8 = undefined; + std.crypto.random.bytes(&seed); + var csprng = ws.ClientMaskPRNG.init(seed); + + var conn: SimpleClient.Conn = undefined; + var client = SimpleClient.init(allocator, &loop, &handler, &conn, &csprng, .{ + .address = address, + .path = "/", + .max_message_size = 16 * 1024 * 1024, + .close_timeout_ms = 5_000, + }); + + try client.connect(); + try loop.run(.until_done); + + if (handler.received) |data| { + defer allocator.free(data); + std.debug.print("Response: {s}\n", .{data}); + } else { + std.debug.print("No response received\n", .{}); + } + + conn.deinit(); +} diff --git a/src/rpc/webzockets/src/client/client.zig b/src/rpc/webzockets/src/client/client.zig new file mode 100644 index 0000000000..403218055d --- /dev/null +++ b/src/rpc/webzockets/src/client/client.zig @@ -0,0 
+1,282 @@ +const std = @import("std"); +const xev = @import("xev"); + +const types = @import("../types.zig"); +const client_connection = @import("connection.zig"); +const client_handshake = @import("handshake.zig"); + +/// Transient WebSocket client that connects TCP, performs the handshake, and +/// initializes a caller-provided `ClientConnection` (allowing external +/// pooling/management). After the handshake completes, the client struct can +/// be discarded — the connection is self-contained. +/// +/// NOTE: this type does not allocate, the allocator and buffer pool +/// is just passed through to the connection for its later use. +/// +/// Comptime parameters: +/// - `Handler`: User handler for protocol events. +/// - `read_buf_size`: Size of the embedded per-connection read buffer. Also +/// defines the maximum size for the handshake response. +/// +/// See `ClientConnection` for the `Handler` interface and callback semantics. +/// +/// ## Connection sequence +/// +/// ``` +/// Caller Client xev Loop Server +/// │ │ │ │ +/// │ connect() │ │ │ +/// ├────────────────────>│ TCP.init() │ │ +/// │ │ TCP.connect() │ │ +/// │ ├───────────────────────>│──── SYN ───────────>│ +/// │ │ │<──── SYN-ACK ───────│ +/// │ │ onConnectComplete │──── ACK ───────────>│ +/// │ │<───────────────────────│ │ +/// │ │ │ │ +/// │ │ Handshake.init() │ │ +/// │ │ Handshake.start() │ │ +/// │ │ ┌─────────────────┐ │ │ +/// │ │ │ generate ws key │ │ │ +/// │ │ │ build HTTP req │ │ │ +/// │ │ └─────────────────┘ │ │ +/// │ │ TCP.write(upgrade req)│ │ +/// │ ├───────────────────────>│──── HTTP Upgrade ──>│ +/// │ │ onWriteComplete │ │ +/// │ │<───────────────────────│ │ +/// │ │ (loop if partial write) │ +/// │ │ │ │ +/// │ │ TCP.read() │ │ +/// │ ├───────────────────────>│<──── HTTP 101 ──────│ +/// │ │ onReadCallback │ │ +/// │ │<───────────────────────│ │ +/// │ │ (loop if headers incomplete) │ +/// │ │ │ │ +/// │ │ processResponse() │ │ +/// │ │ ┌─────────────────┐ │ │ +/// │ │ │ validate 101 │ │ │ +/// │ │ │ check accept key│ │ │ +/// │ │ └─────────────────┘ │ │ +/// │ │ │ │ +/// │ │ ┌── onSuccess ──────────────────────────┐ │ +/// │ │ │ conn.init() — init ClientConnection │ │ +/// │ │ │ conn.start() — begin WebSocket I/O │ │ +/// │ │ └───────────────────────────────────────┘ │ +/// │ │ │ │ +/// │ ┌──────────────────────────────────────┐ │ │ +/// │ │ Client struct can now be discarded. │ │ │ +/// │ │ ClientConnection is self-contained. │ │ │ +/// │ └──────────────────────────────────────┘ │ │ +/// │ │ │ │ +/// │ conn.send() │ │ │ +/// ├────────────────────>│ TCP.write() │ │ +/// │ ├───────────────────────>│──── WS frame ──────>│ +/// │ │ │ │ +/// │ │ │<──── WS frame ──────│ +/// │ │ onReadCallback │ │ +/// │ handler.onMessage()│<───────────────────────│ │ +/// │<────────────────────│ │ │ +/// ``` +/// +/// On failure at any stage, the socket is closed and `Handler.onSocketClose` +/// is called (if declared). +pub fn Client(comptime Handler: type, comptime read_buf_size: usize) type { + return struct { + // -- Caller-provided state -- + conn: *Conn, + handler: *Handler, + allocator: std.mem.Allocator, + config: Config, + loop: *xev.Loop, + + /// Randomness source used for the client handshake (`Sec-WebSocket-Key`) + /// and for per-frame mask key generation (via `ClientConnection`). + /// + /// Must be pointer-stable and outlive the `ClientConnection`. The PRNG + /// is mutated and is not thread-safe; only use it from the `loop.run()` + /// thread and do not share it across loops/threads. 
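+        ///
+        /// Typical setup (as in examples/simple_client.zig, via the public
+        /// `ws` namespace):
+        ///
+        /// ```zig
+        /// var seed: [ws.ClientMaskPRNG.secret_seed_length]u8 = undefined;
+        /// std.crypto.random.bytes(&seed);
+        /// var csprng = ws.ClientMaskPRNG.init(seed);
+        /// ```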
+ csprng: *types.ClientMaskPRNG, + + // -- Transient handshake state -- + hs: Handshake, + connect_completion: xev.Completion, + socket: xev.TCP, + + const ClientSelf = @This(); + + const log = std.log.scoped(.client); + + const Handshake = client_handshake.ClientHandshake(ClientSelf); + + /// The per-connection WebSocket type exposed to handler callbacks. + pub const Conn = client_connection.ClientConnection(Handler, read_buf_size); + + /// Configuration for a WebSocket client connection. + pub const Config = struct { + /// Server address to connect to (IPv4 or IPv6). Also used + /// (formatted as IP:port) for the HTTP `Host` header during + /// the WebSocket handshake. DNS resolution is not performed. + address: std.net.Address, + /// HTTP request path for the WebSocket upgrade (e.g. "/ws"). + path: []const u8 = "/", + /// Maximum total size in bytes of a reassembled message + /// (across all fragments). Messages exceeding this limit + /// cause the connection to be closed with a message too big error. + max_message_size: usize = 16 * 1024 * 1024, + /// Close handshake timeout in ms. Force disconnect if peer doesn't + /// respond to our close frame within this duration. Default: 5000. + close_timeout_ms: u32 = 5_000, + }; + + pub fn init( + allocator: std.mem.Allocator, + loop: *xev.Loop, + handler: *Handler, + conn: *Conn, + csprng: *types.ClientMaskPRNG, + config: Config, + ) ClientSelf { + // The kqueue and epoll backends require a thread pool to be set on + // the loop, otherwise they cannot perform socket close operations. + if (comptime @hasField(xev.Loop, "thread_pool")) { + std.debug.assert(loop.thread_pool != null); + } + + return ClientSelf{ + .allocator = allocator, + .loop = loop, + .handler = handler, + .conn = conn, + .config = config, + .csprng = csprng, + .hs = undefined, + .connect_completion = .{}, + .socket = undefined, + }; + } + + /// Start connecting to the WebSocket server. 
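+        /// Asynchronous: this only submits the TCP connect to the loop and
+        /// returns; the handshake and all subsequent callbacks (`onOpen`,
+        /// failure paths) run from the `loop.run()` thread.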
+ pub fn connect(self: *ClientSelf) !void { + log.debug("connect: address={}, path={s}", .{ self.config.address, self.config.path }); + + self.socket = try xev.TCP.init(self.config.address); + + self.socket.connect( + self.loop, + &self.connect_completion, + self.config.address, + ClientSelf, + self, + onConnectComplete, + ); + log.debug("connect: connect submitted to loop", .{}); + } + + fn onConnectComplete( + self_opt: ?*ClientSelf, + _: *xev.Loop, + _: *xev.Completion, + _: xev.TCP, + result: xev.ConnectError!void, + ) xev.CallbackAction { + const self = self_opt orelse { + log.debug("onConnectComplete: self is null", .{}); + return .disarm; + }; + + result catch |err| { + log.debug("onConnectComplete: TCP connect failed: {s}", .{@errorName(err)}); + // TCP connect failed — close and notify handler + self.socket.close( + self.loop, + &self.connect_completion, + ClientSelf, + self, + onFailClose, + ); + return .disarm; + }; + + log.debug("onConnectComplete: TCP connected, starting handshake", .{}); + // TCP connected — start WebSocket handshake using connection's read buffer + // Safe: connection isn't started yet, so read_buf is unused + self.hs = Handshake.init( + self.socket, + self.loop, + &self.conn.read_buf, + self, + self.csprng, + ); + self.hs.start(self.config.address, self.config.path); + + return .disarm; + } + + fn onFailClose( + self_opt: ?*ClientSelf, + _: *xev.Loop, + _: *xev.Completion, + _: xev.TCP, + result: xev.CloseError!void, + ) xev.CallbackAction { + result catch |err| log.debug("onFailClose: close failed: {}", .{err}); + const self = self_opt orelse return .disarm; + log.debug("onFailClose: notifying handler", .{}); + if (comptime @hasDecl(Handler, "onSocketClose")) { + self.handler.onSocketClose(); + } + return .disarm; + } + + // -- Handshake context callbacks -- + + pub fn onSuccess(self: *ClientSelf, hs: *Handshake) void { + const leftover = hs.read_pos - hs.header_len; + log.debug("onSuccess: handshake completed, leftover={d} bytes", .{leftover}); + + // Initialize the caller's connection in-place. + // Critical: init() does NOT touch read_buf — the handshake read its + // HTTP response into conn.read_buf (borrowed). The leftover bytes are + // still there at an offset. Point the reader at them in-place. 
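+            //
+            //   conn.read_buf: [ HTTP 101 response | leftover frame bytes | free ]
+            //                  0 ........... header_len .......... read_pos
+            //
+            // conn.start(header_len, read_pos) below makes the reader consume
+            // the leftover region without copying.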
+ self.conn.init( + hs.socket, + self.loop, + self.handler, + self.allocator, + .{ + .max_message_size = self.config.max_message_size, + .close_timeout_ms = self.config.close_timeout_ms, + }, + self.csprng, + ); + self.conn.start(hs.header_len, hs.read_pos); + } + + pub fn onError(self: *ClientSelf, hs: *Handshake) void { + log.debug("onError: handshake failed, state={s}, write_pos={d}, " ++ + "request_len={d}, read_pos={d}", .{ + @tagName(hs.state), + hs.write_pos, + hs.request_len, + hs.read_pos, + }); + if (hs.read_pos > 0) { + const preview_len = @min(hs.read_pos, 256); + log.debug("onError: server response ({d} bytes): {s}", .{ + hs.read_pos, + hs.read_buf[0..preview_len], + }); + } else if (hs.request_len == 0) { + log.debug("onError: writeRequest failed (request never built)", .{}); + } else if (hs.write_pos < hs.request_len) { + log.debug("onError: write failed at {d}/{d} bytes", .{ + hs.write_pos, + hs.request_len, + }); + } else { + log.debug("onError: write succeeded but read got error/EOF", .{}); + } + // Close the socket, then notify handler via onSocketClose + hs.closeSocketWithCallback(ClientSelf, self, onFailClose); + } + }; +} diff --git a/src/rpc/webzockets/src/client/connection.zig b/src/rpc/webzockets/src/client/connection.zig new file mode 100644 index 0000000000..7638e69a30 --- /dev/null +++ b/src/rpc/webzockets/src/client/connection.zig @@ -0,0 +1,914 @@ +const std = @import("std"); +const xev = @import("xev"); + +const types = @import("../types.zig"); +const frame = @import("../frame.zig"); +const mask_mod = @import("../mask.zig"); +const Reader = @import("../reader.zig").Reader(.client); +const ControlQueue = @import("../control_queue.zig").ControlQueue; + +/// Per-connection WebSocket client protocol handler wrapping libxev I/O. +/// +/// Handles the WebSocket protocol phase (after HTTP upgrade). Use +/// `ClientHandshake` to perform the upgrade first, then `init()` + +/// `start()` to begin the protocol phase. +/// +/// Comptime parameters: +/// - `Handler`: User handler for protocol events (see below). +/// - `inline_buf_size`: Size of the embedded per-connection read buffer. +/// +/// **Required handler methods** (comptime-enforced): +/// - `onMessage(*Handler, *Self, types.Message)` — complete message received. +/// - `onWriteComplete(*Handler, *Self)` — send finished; caller buffer may be +/// freed. Also called on disconnect if a data write was in-flight/pending. +/// - `onClose(*Handler, *Self)` — connection torn down (exactly once). +/// +/// **Optional handler methods** (detected via `@hasDecl`): +/// - `onOpen(*Handler, *Self)` — protocol phase started, ready to send. +/// - `onPing(*Handler, *Self, []const u8)` — ping received. When declared, +/// the handler is responsible for sending pong via `sendPong()`, which +/// enqueues each pong individually into the control queue. When absent, +/// the library auto-pongs using "latest wins" semantics: if multiple +/// pings arrive before a pong can be sent, only the most recent ping's +/// payload is used (permitted by RFC 6455 §5.5.3). Implement `onPing` +/// if you need to respond to every ping. +/// - `onPong(*Handler, *Self, []const u8)` — pong received. +/// - `onSocketClose(*Handler)` — TCP socket closed (final cleanup, no +/// connection pointer). +/// - `onBytesRead(*Handler, *Self, usize)` — raw TCP data received. +/// Fires on every read completion regardless of whether reads are paused. +/// `usize` parameter is the number of bytes received. 
Combine with +/// `peekBufferedBytes()` to inspect raw data as it arrives. +/// +/// `sendText`/`sendBinary` mask the caller's buffer in-place per RFC 6455. +/// The caller must not free or reuse the buffer until `onWriteComplete` fires. +/// `sendPing`/`sendPong` copy the payload into an internal queue and do not +/// fire `onWriteComplete`; the caller may free the buffer immediately. +/// +/// All data slices passed to read callbacks point into internal buffers and are +/// only valid for the duration of that callback. +pub fn ClientConnection( + comptime Handler: type, + comptime inline_buf_size: usize, +) type { + comptime { + if (!@hasDecl(Handler, "onMessage")) + @compileError("Handler must declare an onMessage method"); + if (!@hasDecl(Handler, "onWriteComplete")) + @compileError("Handler must declare an onWriteComplete method"); + if (!@hasDecl(Handler, "onClose")) + @compileError("Handler must declare an onClose method"); + } + + return struct { + // -- Owned read buffer -- + read_buf: [inline_buf_size]u8, + + // -- Core state -- + state: types.ConnectionState, + socket: xev.TCP, + loop: *xev.Loop, + allocator: std.mem.Allocator, + handler: *Handler, + config: Config, + /// When true all message bytes received are buffered (up to current reader capacity) + /// but onMessage/onPing/onPong/close-frame handling is deferred until reads + /// are resumed. + read_paused: bool, + /// Re-entrancy guard: true while processMessages is on the call stack. + in_process_messages: bool, + + // -- Tiered read buffer management -- + reader: Reader, + + // -- Completions -- + read_completion: xev.Completion, + close_completion: xev.Completion, + write_completion: xev.Completion, + cancel_completion: xev.Completion, + + // -- Write state -- + write: WriteState, + + // -- Close-handshake timer state -- + timer: xev.Timer, + close_timer_completion: xev.Completion, + close_timer_cancel_completion: xev.Completion, + + // -- PRNG for mask key generation (same loop thread, must outlive connection) -- + csprng: *types.ClientMaskPRNG, + + const ConnectionSelf = @This(); + + const log = std.log.scoped(.client_connection); + + /// Comptime flags for optional handler capabilities. + const has = struct { + const on_open = @hasDecl(Handler, "onOpen"); + const on_ping = @hasDecl(Handler, "onPing"); + const on_pong = @hasDecl(Handler, "onPong"); + const on_socket_close = @hasDecl(Handler, "onSocketClose"); + const on_bytes_read = @hasDecl(Handler, "onBytesRead"); + }; + + pub const Config = struct { + /// Maximum total size of a reassembled fragmented message. + max_message_size: usize = 16 * 1024 * 1024, + /// Close handshake timeout in ms. Force disconnect if peer doesn't + /// respond to our close frame within this duration. Default: 5000. + close_timeout_ms: u32 = 5_000, + }; + + /// Holds pending auto-pong payload when the handler does not declare + /// `onPing` (the library auto-responds). When `onPing` is declared this + /// is a zero-sized struct, adding no bytes to WriteState. + const AutoPongState = if (has.on_ping) + // Handler manages pong responses; no auto-pong state needed. + struct {} + else + struct { + /// Whether an auto-pong is pending (latest ping payload wins). + pending: bool = false, + /// Payload for the pending auto-pong. + data: [125]u8 = undefined, + /// Length of the pending auto-pong payload. + len: u8 = 0, + }; + + /// All write-path state, grouped for clarity. + const WriteState = struct { + /// What is currently being written to the socket. 
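+            /// State machine: `.idle` -> `.data(header)` -> `.data(payload)`
+            /// -> `.idle`, or `.idle` -> `.control` -> `.idle`. A data write
+            /// requested while a control frame is in flight stays deferred in
+            /// `header_buf`/`payload` until the control write finishes.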
+ const InFlight = union(enum) { + idle, + /// Two-phase data write: header bytes first, then masked payload. + data: struct { + phase: enum { header, payload }, + offset: usize, + }, + /// Single-phase control frame write from control_buf. + control: struct { + offset: usize, + }, + }; + + /// Current in-flight write operation state. + in_flight: InFlight = .idle, + + // Data write buffers + /// Header buffer for current or pending data write (10 + 4 mask key). + header_buf: [14]u8 = undefined, + /// Length of actual header in header_buf. + header_len: usize = 0, + /// Payload for current or pending data write (caller owned slice, data masked in place). + payload: []const u8 = &.{}, + + // Control frame buffers (127 payload + 4 mask key) + /// Control frame buffer for in-flight control writes. + control_buf: [131]u8 = undefined, + /// Length of actual control frame in control_buf. + control_len: usize = 0, + /// Queue for pending control frames (close, ping, pong). + control_queue: ControlQueue = ControlQueue.init(), + /// Auto-pong state (only present when Handler lacks onPing). + auto_pong: AutoPongState = .{}, + /// Whether to disconnect after sending a close frame. + disconnect_after_close: bool = false, + /// Whether the peer initiated the close (used to detect duplicate close frames). + peer_initiated_close: bool = false, + }; + + // ==================================================================== + // Lifecycle + // ==================================================================== + + /// Initialize the connection in-place. Sets all runtime fields except + /// `read_buf` (which may contain handshake leftover data that must be + /// preserved). Reader is initialized with `&self.read_buf`. + pub fn init( + self: *ConnectionSelf, + socket: xev.TCP, + loop: *xev.Loop, + handler: *Handler, + allocator: std.mem.Allocator, + config: Config, + csprng: *types.ClientMaskPRNG, + ) void { + self.state = .open; + self.socket = socket; + self.loop = loop; + self.allocator = allocator; + self.handler = handler; + self.config = config; + self.reader = Reader.init( + &self.read_buf, + allocator, + config.max_message_size, + ); + self.read_completion = .{}; + self.close_completion = .{}; + self.write_completion = .{}; + self.cancel_completion = .{}; + self.write = .{}; + self.timer = xev.Timer.init() catch unreachable; + self.close_timer_completion = .{}; + self.close_timer_cancel_completion = .{}; + self.read_paused = false; + self.in_process_messages = false; + self.csprng = csprng; + } + + /// Release any resources allocated by the connection (reader buffers). + pub fn deinit(self: *ConnectionSelf) void { + self.reader.deinit(); + self.timer.deinit(); + } + + /// Begin the WebSocket protocol phase. Sets the reader position to + /// account for leftover handshake bytes already in the read buffer, + /// calls `onOpen` (if defined), then processes any leftover data and + /// starts reading new data. + pub fn start(self: *ConnectionSelf, data_start: usize, data_end: usize) void { + std.debug.assert(data_start <= data_end); + std.debug.assert(data_end <= self.reader.buf.len); + self.reader.start = data_start; + self.reader.pos = data_end; + + if (comptime has.on_open) { + self.handler.onOpen(self); + } + // Process any leftover handshake data, and start reading + self.processMessages(); + } + + // ==================================================================== + // Public send API + // ==================================================================== + + /// Send a text message. 
The payload is masked in place; keep the buffer + /// alive and unmodified until `onWriteComplete` fires. Only one data + /// write can be in flight; returns `error.WriteBusy` if another send is + /// pending. Queue additional sends and retry from `onWriteComplete`. + pub fn sendText(self: *ConnectionSelf, data: []u8) !void { + if (self.state != .open) return error.InvalidState; + try self.startDataWrite(.text, data); + } + + /// Send a binary message. The payload is masked in place; keep the buffer + /// alive and unmodified until `onWriteComplete` fires. Only one data + /// write can be in flight; returns `error.WriteBusy` if another send is + /// pending. Queue additional sends and retry from `onWriteComplete`. + pub fn sendBinary(self: *ConnectionSelf, data: []u8) !void { + if (self.state != .open) return error.InvalidState; + try self.startDataWrite(.binary, data); + } + + /// Send a ping frame. The payload is copied into an internal control queue, + /// so the caller can free the buffer immediately after this returns. + /// No `onWriteComplete` callback fires for ping sends. + /// Returns `error.ControlFrameTooBig` if payload exceeds 125 bytes, + /// or `error.QueueFull` if the control queue has insufficient space. + pub fn sendPing(self: *ConnectionSelf, data: []const u8) !void { + if (self.state != .open) return error.InvalidState; + if (data.len > 125) return error.ControlFrameTooBig; + try self.enqueueAndFlush(.ping, data); + } + + /// Send a pong frame. The payload is copied into an internal control queue, + /// so the caller can free the buffer immediately after this returns. + /// No `onWriteComplete` callback fires for pong sends. + /// Returns `error.ControlFrameTooBig` if payload exceeds 125 bytes, + /// or `error.QueueFull` if the control queue has insufficient space. + pub fn sendPong(self: *ConnectionSelf, data: []const u8) !void { + if (self.state != .open) return error.InvalidState; + if (data.len > 125) return error.ControlFrameTooBig; + try self.enqueueAndFlush(.pong, data); + } + + /// Initiate a close handshake with the given status code and optional reason. + /// The connection transitions to `.closing` and waits for the peer's close + /// response. Arms a close-handshake timer that force-disconnects if the peer + /// doesn't respond within `close_timeout_ms`. + /// The reason is silently truncated to 123 bytes (the maximum allowed by RFC 6455 + /// after the 2-byte close code in a 125-byte control frame payload). + pub fn close(self: *ConnectionSelf, code: types.CloseCode, reason: []const u8) void { + if (self.state != .open) return; + self.state = .closing; + + // Build close payload: 2-byte big-endian status code + reason text + var payload: [125]u8 = undefined; + payload[0..2].* = code.payloadBytes(); + const reason_len = @min(reason.len, 123); // control frame payload max 125 + @memcpy(payload[2..][0..reason_len], reason[0..reason_len]); + const total_len: u8 = @intCast(2 + reason_len); + + self.enqueueClose(payload[0..total_len]); + + // Start close-handshake deadline timer. + if (self.close_timer_completion.state() != .active) { + self.timer.run( + self.loop, + &self.close_timer_completion, + self.config.close_timeout_ms, + ConnectionSelf, + self, + onCloseTimerCallback, + ); + } + } + + /// Pause frame dispatch. While paused, onMessage/onPing/onPong/close-frame + /// handling stops until `resumeReads()` is called. TCP reads continue until + /// read buffer is full, but will not grow the buffer while paused. 
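+        ///
+        /// Threshold-based flow control, sketched after the e2e test handlers
+        /// (`Handler` and `resume_threshold` are the user type's own names):
+        ///
+        /// ```zig
+        /// pub fn onOpen(_: *Handler, conn: anytype) void {
+        ///     conn.pauseReads(); // buffer frames instead of dispatching
+        /// }
+        /// pub fn onBytesRead(self: *Handler, conn: anytype, _: usize) void {
+        ///     if (conn.peekBufferedBytes().len >= self.resume_threshold)
+        ///         conn.resumeReads(); // drains buffered frames synchronously
+        /// }
+        /// ```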
+ pub fn pauseReads(self: *ConnectionSelf) void { + self.read_paused = true; + } + + /// Resume frame dispatch and drain any already-buffered frames, this will + /// cause onMessage/onPing/onPong/close-frame handling to resume. + pub fn resumeReads(self: *ConnectionSelf) void { + if (self.state == .closed or !self.read_paused) return; + self.read_paused = false; + self.processMessages(); + } + + /// Peek at the raw bytes currently buffered in the reader (received + /// from TCP but not yet consumed as websocket frames). The returned + /// slice points into an internal buffer and may be invalidated as soon + /// as the xev loop ticks again. + pub fn peekBufferedBytes(self: *ConnectionSelf) []const u8 { + return self.reader.buf[self.reader.start..self.reader.pos]; + } + + // ==================================================================== + // Read path + // ==================================================================== + + /// Arm another socket read when legal; while paused this continues + /// filling the read buffer until it has no free space. + fn maybeReadMore(self: *ConnectionSelf) void { + if (self.state == .closed or self.write.peer_initiated_close) return; + if (self.read_completion.state() == .active) return; + if (self.read_paused) { + // Reclaim consumed bytes if buffer is full so we can buffer as much as + // possible without growing the read buffer. + self.reader.compactIfFull(); + // If still full then just return to avoid growing the buffer + if (self.reader.availableSpace() == 0) return; + } + self.startRead(); + } + + fn startRead(self: *ConnectionSelf) void { + const slice = self.reader.readSlice() catch |err| { + log.debug("readSlice failed: {}", .{err}); + self.handleDisconnect(); + return; + }; + self.socket.read( + self.loop, + &self.read_completion, + .{ .slice = slice }, + ConnectionSelf, + self, + onReadCallback, + ); + } + + fn onReadCallback( + self_opt: ?*ConnectionSelf, + _: *xev.Loop, + _: *xev.Completion, + _: xev.TCP, + _: xev.ReadBuffer, + result: xev.ReadError!usize, + ) xev.CallbackAction { + const self = self_opt orelse return .disarm; + if (self.state == .closed) { + self.checkAllDone(); + return .disarm; + } + + const bytes_read = result catch |err| { + log.debug("read failed: {}", .{err}); + self.handleDisconnect(); + return .disarm; + }; + + if (bytes_read == 0) { + log.debug("peer closed TCP connection", .{}); + self.handleDisconnect(); + return .disarm; + } + + self.reader.advancePos(bytes_read); + + if (comptime has.on_bytes_read) { + self.handler.onBytesRead(self, bytes_read); + } + + if (self.read_paused) { + self.maybeReadMore(); + return .disarm; + } + + switch (self.state) { + .open, .closing => self.processMessages(), + .closed => {}, + } + + return .disarm; + } + + // ==================================================================== + // Frame processing + // ==================================================================== + + fn processMessages(self: *ConnectionSelf) void { + // Re-entrancy guard: prevents recursion into processMessages() when a handler + // calls resumeReads() + if (self.in_process_messages) return; + + if (self.read_paused) { + self.maybeReadMore(); + return; + } + + self.in_process_messages = true; + defer self.in_process_messages = false; + + while (true) { + const maybe_msg = self.reader.nextMessage() catch |err| { + log.debug("nextMessage failed: {}", .{err}); + switch (err) { + error.ProtocolError => self.failWithClose(.protocol_error), + error.MessageTooBig => self.failWithClose(.message_too_big), 
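+                        // Out of memory: no WebSocket close code is sent;
+                        // drop the TCP connection directly.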
+ error.OutOfMemory => self.handleDisconnect(), + } + return; + }; + const msg = maybe_msg orelse break; + + switch (msg.type) { + .text, .binary => self.handler.onMessage(self, msg), + .ping => { + if (comptime has.on_ping) { + // Handler manages pong response + self.handler.onPing(self, msg.data); + } else { + // Auto-pong: store in dedicated field (latest wins) + const len: u8 = @intCast(msg.data.len); + @memcpy(self.write.auto_pong.data[0..len], msg.data[0..len]); + self.write.auto_pong.len = len; + self.write.auto_pong.pending = true; + self.trySubmitNextControl(); + } + }, + .pong => { + if (comptime has.on_pong) { + self.handler.onPong(self, msg.data); + } + }, + .close => { + self.handleCloseFrame(msg.data); + if (self.state == .closed) return; + }, + } + if (self.state == .closed or self.read_paused) break; + } + // Only start a new read if we're still active. + // When peer_initiated_close is set, we're echoing the peer's close frame + // and will disconnect as soon as the write completes — no need to read more. + self.maybeReadMore(); + } + + /// Send a close frame with the given code and disconnect. + /// Used for client-initiated error closes (protocol error, message too big). + fn failWithClose(self: *ConnectionSelf, code: types.CloseCode) void { + if (self.state == .closed) return; + self.state = .closing; + self.write.disconnect_after_close = true; + const close_payload = code.payloadBytes(); + self.enqueueClose(&close_payload); + } + + fn handleCloseFrame(self: *ConnectionSelf, payload: []const u8) void { + if (self.state == .closing) { + // Already in closing state. If we're waiting to send our close response + // (peer_initiated_close), ignore this duplicate. Otherwise, we initiated + // the close and this is the peer's response — complete the handshake. + if (!self.write.peer_initiated_close) { + self.handleDisconnect(); + } + // else: peer sent multiple close frames — ignore per RFC 6455 + } else { + // Peer initiated — validate and echo the close frame, then disconnect + self.state = .closing; + self.write.peer_initiated_close = true; + self.write.disconnect_after_close = true; + + const validation = types.validateClosePayload(payload); + switch (validation) { + .valid_payload => |vp| self.enqueueClose(vp), + .close_code => |code| { + const close_payload = code.payloadBytes(); + self.enqueueClose(&close_payload); + }, + } + } + } + + // ==================================================================== + // Write path (masked) + // ==================================================================== + + fn generateMaskKey(self: *ConnectionSelf) [4]u8 { + var key: [4]u8 = undefined; + self.csprng.fill(&key); + return key; + } + + /// True when a user data write is in-flight or deferred. + fn outstandingUserWrite(self: *ConnectionSelf) bool { + return self.write.header_len != 0; + } + + /// Start a two-phase data write: header first, then payload (masked in-place in + /// the caller's buffer). + /// If a control frame is in flight, the data write is deferred and will + /// start automatically when the control frame completes. + /// Returns error.WriteBusy if another data write is already in flight or pending. 
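+        ///
+        /// Callers that can produce data faster than the socket drains should
+        /// queue payloads and retry from `onWriteComplete`, as the echo
+        /// example's pending-message list does.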
+ fn startDataWrite(self: *ConnectionSelf, opcode: types.Opcode, payload: []u8) !void { + // Only one data write can be pending/in-flight at a time + if (self.outstandingUserWrite()) { + return error.WriteBusy; + } + + const mask_key = self.generateMaskKey(); + + // Build masked header + const header = frame.writeClientFrameHeader( + &self.write.header_buf, + opcode, + payload.len, + mask_key, + false, + ); + self.write.header_len = header.len; + + // Mask payload in-place in the caller's buffer + if (payload.len > 0) { + mask_mod.mask(mask_key, payload); + } + self.write.payload = payload; + + if (self.write.in_flight == .control) { + // Control frame in flight — defer data write until it completes + return; + } + + // Idle — start immediately + self.write.in_flight = .{ .data = .{ .phase = .header, .offset = 0 } }; + self.submitWrite(self.write.header_buf[0..self.write.header_len]); + } + + /// Submit a write to the socket using the shared write completion. + fn submitWrite(self: *ConnectionSelf, slice: []const u8) void { + self.socket.write( + self.loop, + &self.write_completion, + .{ .slice = slice }, + ConnectionSelf, + self, + onWriteCallback, + ); + } + + /// Advance a write offset by bytes_written. If the full slice has been + /// written, returns true. Otherwise reissues a write for the remainder. + fn advanceWrite( + self: *ConnectionSelf, + offset: *usize, + bytes_written: usize, + full_slice: []const u8, + ) bool { + offset.* += bytes_written; + if (offset.* < full_slice.len) { + self.submitWrite(full_slice[offset.*..]); + return false; + } + return true; + } + + /// Build a masked control frame into control_buf and issue the socket write. + fn writeControlFrame(self: *ConnectionSelf, opcode: types.Opcode, payload: []const u8) void { + const mask_key = self.generateMaskKey(); + + const header = frame.writeClientFrameHeader( + self.write.control_buf[0..14], + opcode, + payload.len, + mask_key, + false, + ); + const header_len = header.len; + if (payload.len > 0) { + @memcpy(self.write.control_buf[header_len..][0..payload.len], payload); + // Mask control payload in-place + mask_mod.mask(mask_key, self.write.control_buf[header_len..][0..payload.len]); + } + self.write.control_len = header_len + payload.len; + self.write.in_flight = .{ .control = .{ .offset = 0 } }; + self.submitWrite(self.write.control_buf[0..self.write.control_len]); + } + + fn onWriteCallback( + self_opt: ?*ConnectionSelf, + _: *xev.Loop, + _: *xev.Completion, + _: xev.TCP, + _: xev.WriteBuffer, + result: xev.WriteError!usize, + ) xev.CallbackAction { + const self = self_opt orelse return .disarm; + + if (self.state == .closed) { + self.checkAllDone(); + return .disarm; + } + + const bytes_written = result catch |err| { + log.debug("write failed: {}", .{err}); + self.handleDisconnect(); + return .disarm; + }; + + switch (self.write.in_flight) { + .data => |*d| switch (d.phase) { + .header => { + const header_data = self.write.header_buf[0..self.write.header_len]; + if (self.advanceWrite(&d.offset, bytes_written, header_data)) { + if (self.write.payload.len == 0) { + self.finishWrite(); + } else { + d.phase = .payload; + d.offset = 0; + self.submitWrite(self.write.payload); + } + } + }, + .payload => { + if (self.advanceWrite(&d.offset, bytes_written, self.write.payload)) { + self.finishWrite(); + } + }, + }, + .control => |*c| { + const control_data = self.write.control_buf[0..self.write.control_len]; + if (self.advanceWrite(&c.offset, bytes_written, control_data)) { + self.finishControlWrite(); + } + }, + 
.idle => {}, + } + + return .disarm; + } + + fn finishWrite(self: *ConnectionSelf) void { + self.write.in_flight = .idle; + self.write.payload = &.{}; + self.write.header_len = 0; + // Flush pending controls first (priority over data) before invoking user callback + self.trySubmitNextControl(); + self.handler.onWriteComplete(self); + } + + fn finishControlWrite(self: *ConnectionSelf) void { + self.write.in_flight = .idle; + + if (self.write.disconnect_after_close) { + // Close frame sent (peer echo or protocol error) — tear down + self.handleDisconnect(); + return; + } + + // Pending controls have priority over pending data writes. + self.trySubmitNextControl(); + if (self.write.in_flight == .idle and self.outstandingUserWrite()) { + // Start deferred data write (header already built in header_buf). + self.write.in_flight = .{ .data = .{ .phase = .header, .offset = 0 } }; + self.submitWrite(self.write.header_buf[0..self.write.header_len]); + } + } + + /// Enqueue a control frame and flush if the write path is idle. + fn enqueueAndFlush(self: *ConnectionSelf, opcode: types.Opcode, data: []const u8) !void { + try self.write.control_queue.enqueue(opcode, data); + self.trySubmitNextControl(); + } + + /// Enqueue a close frame, clearing any pending controls (close takes priority). + fn enqueueClose(self: *ConnectionSelf, payload: []const u8) void { + self.write.control_queue.clear(); + if (comptime !has.on_ping) { + self.write.auto_pong.pending = false; + } + self.write.control_queue.enqueue(.close, payload) catch unreachable; + self.trySubmitNextControl(); + } + + /// Flush the next pending control frame, if any. No-op when a write is + /// already in flight. Implements the priority chain: + /// 1. Close frame in the queue (always wins) + /// 2. Pending auto-pong (latest ping payload) + /// 3. 
Other queued control frames (FIFO) + fn trySubmitNextControl(self: *ConnectionSelf) void { + if (self.write.in_flight != .idle) return; + + var payload_buf: [125]u8 = undefined; + // Priority 1: Close frame in queue + if (self.write.control_queue.isNextClose()) { + const entry = self.write.control_queue.dequeue(&payload_buf).?; + self.writeControlFrame(entry.opcode, payload_buf[0..entry.len]); + return; + } + // Priority 2: Pending auto-pong + if (comptime !has.on_ping) { + if (self.write.auto_pong.pending) { + self.write.auto_pong.pending = false; + const pong_data = self.write.auto_pong.data[0..self.write.auto_pong.len]; + self.writeControlFrame(.pong, pong_data); + return; + } + } + // Priority 3: Other queued control frames + if (self.write.control_queue.dequeue(&payload_buf)) |entry| { + self.writeControlFrame(entry.opcode, payload_buf[0..entry.len]); + } + } + + // ==================================================================== + // Timer callbacks + // ==================================================================== + + fn onCloseTimerCallback( + self_opt: ?*ConnectionSelf, + _: *xev.Loop, + _: *xev.Completion, + result: xev.Timer.RunError!void, + ) xev.CallbackAction { + const self = self_opt orelse return .disarm; + + result catch |err| switch (err) { + error.Canceled => { + if (self.state == .closed) self.checkAllDone(); + return .disarm; + }, + error.Unexpected => |e| { + // should never happen + log.err("close timer error: {}", .{e}); + // fallthrough to enforce close/cleanup + }, + }; + + if (self.state == .closing and !self.write.peer_initiated_close) { + // Close timeout expired (client-initiated) — force disconnect + self.handleDisconnect(); + } else if (self.state == .closed) { + self.checkAllDone(); + } + + return .disarm; + } + + fn onTimerCancelled( + self_opt: ?*ConnectionSelf, + _: *xev.Loop, + _: *xev.Completion, + result: xev.Timer.CancelError!void, + ) xev.CallbackAction { + // Cancel completion exists so we can wait for a backend-specific + // timer removal operation to complete (io_uring uses timer_remove). + result catch |err| log.debug("timer cancel error: {}", .{err}); + + // Only relevant for teardown if the connection closed while the + // cancellation was in-flight. + if (self_opt) |self| { + if (self.state == .closed) self.checkAllDone(); + } + return .disarm; + } + + // ==================================================================== + // Cleanup + // ==================================================================== + + fn handleDisconnect(self: *ConnectionSelf) void { + if (self.state == .closed) return; + self.state = .closed; + + // Notify handler if a data write was in flight or pending so buffer + // can be cleaned up. Not fired for internal control frame writes. + if (self.outstandingUserWrite()) { + self.handler.onWriteComplete(self); + } + + self.write = .{}; + self.handler.onClose(self); + + const timer_active = self.close_timer_completion.state() == .active; + const cancel_active = self.close_timer_cancel_completion.state() == .active; + if (timer_active and !cancel_active) { + // Cancel via xev.Timer.cancel() (not a raw `.cancel` op) since some + // backends (e.g. io_uring) require a timer-specific remove. 
+ self.timer.cancel( + self.loop, + &self.close_timer_completion, + &self.close_timer_cancel_completion, + ConnectionSelf, + self, + onTimerCancelled, + ); + } + + self.cancelActive(&self.read_completion, &self.cancel_completion); + self.cancelActive(&self.write_completion, &self.close_completion); + + self.checkAllDone(); + } + + /// Raw xev callback for cancel completions. + fn cancelCallback( + ud: ?*anyopaque, + _: *xev.Loop, + _: *xev.Completion, + result: xev.Result, + ) xev.CallbackAction { + result.cancel catch |err| log.debug("cancel failed: {}", .{err}); + if (ud) |ptr| { + const self: *ConnectionSelf = @ptrCast(@alignCast(ptr)); + self.checkAllDone(); + } + return .disarm; + } + + /// Cancel an active completion using a cancel slot. + fn cancelActive( + self: *ConnectionSelf, + target: *xev.Completion, + cancel_slot: *xev.Completion, + ) void { + if (target.state() == .active) { + cancel_slot.* = .{ + .op = .{ .cancel = .{ .c = target } }, + .userdata = @ptrCast(self), + .callback = cancelCallback, + }; + self.loop.add(cancel_slot); + } + } + + /// Check if all completions (read, write, cancel, close) are inactive. + /// If so, close the socket. + fn checkAllDone(self: *ConnectionSelf) void { + const completions = [_]*xev.Completion{ + &self.read_completion, + &self.write_completion, + &self.cancel_completion, + &self.close_completion, + &self.close_timer_completion, + &self.close_timer_cancel_completion, + }; + for (completions) |c| { + if (c.state() == .active) return; + } + self.closeSocket(); + } + + fn closeSocket(self: *ConnectionSelf) void { + self.socket.close( + self.loop, + &self.close_completion, + ConnectionSelf, + self, + onCloseComplete, + ); + } + + fn onCloseComplete( + self_opt: ?*ConnectionSelf, + _: *xev.Loop, + _: *xev.Completion, + _: xev.TCP, + result: xev.CloseError!void, + ) xev.CallbackAction { + result catch |err| log.debug("close failed: {}", .{err}); + if (self_opt) |self| { + if (comptime has.on_socket_close) { + self.handler.onSocketClose(); + } + } + return .disarm; + } + }; +} diff --git a/src/rpc/webzockets/src/client/handshake.zig b/src/rpc/webzockets/src/client/handshake.zig new file mode 100644 index 0000000000..655451494a --- /dev/null +++ b/src/rpc/webzockets/src/client/handshake.zig @@ -0,0 +1,257 @@ +const std = @import("std"); +const xev = @import("xev"); +const types = @import("../types.zig"); +const http = @import("../http.zig"); + +/// Async HTTP upgrade client handshake wrapping libxev I/O. +/// +/// Writes the HTTP upgrade request to the socket, reads the 101 response, +/// validates it, and calls the context's `onSuccess` or `onError` callback. +/// On error, the caller is responsible for closing the socket via +/// `closeSocketWithCallback`. 
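+///
+/// A sketch of an `onError` that does exactly that (`MyContext` and
+/// `onSocketClosed` are hypothetical caller-side names; `onSocketClosed`
+/// must match the callback signature of `closeSocketWithCallback`):
+/// ```
+/// fn onError(ctx: *MyContext, hs: *ClientHandshake(MyContext)) void {
+///     // Handshake failed; the socket is still open, so close it here.
+///     hs.closeSocketWithCallback(MyContext, ctx, onSocketClosed);
+/// }
+/// ```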
+/// +/// Comptime-parameterized by `Context`, which must declare: +/// - `onSuccess(ctx: *Context, hs: *ClientHandshake(Context)) void` +/// - `onError(ctx: *Context, hs: *ClientHandshake(Context)) void` +pub fn ClientHandshake(comptime Context: type) type { + comptime { + if (!@hasDecl(Context, "onSuccess")) + @compileError("Context must declare an onSuccess method"); + if (!@hasDecl(Context, "onError")) + @compileError("Context must declare an onError method"); + } + + return struct { + state: types.HandshakeState, + socket: xev.TCP, + loop: *xev.Loop, + context: *Context, + + read_buf: []u8, + read_pos: usize, + request_buf: [512]u8, + request_len: usize, + write_pos: usize, + head_parser: http.HeadParser, + + key_buf: [24]u8, + key: []const u8, + + /// Caller-provided PRNG used to generate the WebSocket key. + /// Must only be used from the loop thread (not thread-safe). + csprng: *types.ClientMaskPRNG, + + read_completion: xev.Completion, + write_completion: xev.Completion, + close_completion: xev.Completion, + + header_len: usize, + + const ClientHandshakeSelf = @This(); + + const log = std.log.scoped(.client_handshake); + + pub fn init( + socket: xev.TCP, + loop: *xev.Loop, + read_buf: []u8, + context: *Context, + csprng: *types.ClientMaskPRNG, + ) ClientHandshakeSelf { + return .{ + .state = .writing, + .socket = socket, + .loop = loop, + .context = context, + .read_buf = read_buf, + .read_pos = 0, + .request_buf = undefined, + .request_len = 0, + .write_pos = 0, + .head_parser = .{}, + .key_buf = undefined, + .key = &.{}, + .csprng = csprng, + .read_completion = .{}, + .write_completion = .{}, + .close_completion = .{}, + .header_len = 0, + }; + } + + /// Begin the handshake: generate key, build request, write to socket. + pub fn start(self: *ClientHandshakeSelf, address: std.net.Address, path: []const u8) void { + log.debug("start: address={}, path={s}", .{ address, path }); + + var raw_key: [16]u8 = undefined; + self.csprng.fill(&raw_key); + self.key = http.encodeKey(&self.key_buf, &raw_key); + + const request = http.writeRequest(&self.request_buf, address, path, self.key) catch { + log.debug("start: writeRequest failed (buffer too small)", .{}); + self.fail(); + return; + }; + self.request_len = request.len; + self.write_pos = 0; + self.state = .writing; + log.debug("start: sending {d} byte request", .{self.request_len}); + + self.socket.write( + self.loop, + &self.write_completion, + .{ .slice = self.request_buf[0..self.request_len] }, + ClientHandshakeSelf, + self, + onWriteComplete, + ); + } + + /// Close the socket with a caller-provided callback (so the caller + /// can be notified when the close completes). 
+ pub fn closeSocketWithCallback( + self: *ClientHandshakeSelf, + comptime Ctx: type, + ctx: *Ctx, + comptime cb: fn ( + ?*Ctx, + *xev.Loop, + *xev.Completion, + xev.TCP, + xev.CloseError!void, + ) xev.CallbackAction, + ) void { + self.socket.close(self.loop, &self.close_completion, Ctx, ctx, cb); + } + + // ==================================================================== + // Internal + // ==================================================================== + + fn fail(self: *ClientHandshakeSelf) void { + self.state = .failed; + self.context.onError(self); + } + + fn onWriteComplete( + self_opt: ?*ClientHandshakeSelf, + _: *xev.Loop, + _: *xev.Completion, + _: xev.TCP, + _: xev.WriteBuffer, + result: xev.WriteError!usize, + ) xev.CallbackAction { + const self = self_opt orelse return .disarm; + if (self.state != .writing) return .disarm; + + const bytes_written = result catch |err| { + log.debug("onWriteComplete: write error: {s}", .{@errorName(err)}); + self.fail(); + return .disarm; + }; + + self.write_pos += bytes_written; + log.debug("onWriteComplete: wrote {d}/{d} bytes", .{ self.write_pos, self.request_len }); + + if (self.write_pos < self.request_len) { + self.socket.write( + self.loop, + &self.write_completion, + .{ .slice = self.request_buf[self.write_pos..self.request_len] }, + ClientHandshakeSelf, + self, + onWriteComplete, + ); + return .disarm; + } + + // Request fully sent — start reading response + log.debug("onWriteComplete: request fully sent, reading response", .{}); + self.state = .reading; + self.startRead(); + return .disarm; + } + + fn startRead(self: *ClientHandshakeSelf) void { + if (self.read_pos >= self.read_buf.len) { + log.debug("startRead: buffer full ({d} bytes), failing", .{self.read_pos}); + self.fail(); + return; + } + + log.debug("startRead: reading at offset {d}, buf remaining {d}", .{ + self.read_pos, + self.read_buf.len - self.read_pos, + }); + self.socket.read( + self.loop, + &self.read_completion, + .{ .slice = self.read_buf[self.read_pos..] }, + ClientHandshakeSelf, + self, + onReadCallback, + ); + } + + fn onReadCallback( + self_opt: ?*ClientHandshakeSelf, + _: *xev.Loop, + _: *xev.Completion, + _: xev.TCP, + _: xev.ReadBuffer, + result: xev.ReadError!usize, + ) xev.CallbackAction { + const self = self_opt orelse return .disarm; + if (self.state != .reading) return .disarm; + + const bytes_read = result catch |err| { + log.debug("onReadCallback: read error: {s}", .{@errorName(err)}); + self.fail(); + return .disarm; + }; + + if (bytes_read == 0) { + log.debug("onReadCallback: EOF (0 bytes read), total read so far: {d}", .{ + self.read_pos, + }); + self.fail(); + return .disarm; + } + + log.debug("onReadCallback: read {d} bytes, total {d}", .{ + bytes_read, + self.read_pos + bytes_read, + }); + const old_pos = self.read_pos; + self.read_pos += bytes_read; + + // Feed only the new bytes to the incremental head parser. + const new_data = self.read_buf[old_pos..self.read_pos]; + const consumed = self.head_parser.feed(new_data); + if (self.head_parser.state != .finished) { + // Headers not yet complete — keep reading. 
+ self.startRead(); + return .disarm; + } + + self.header_len = old_pos + consumed; + self.processResponse(); + return .disarm; + } + + fn processResponse(self: *ClientHandshakeSelf) void { + http.validateResponse(self.read_buf[0..self.header_len], self.key) catch |err| { + log.debug("processResponse: parse error: {s}", .{@errorName(err)}); + self.fail(); + return; + }; + + self.state = .completed; + log.debug("processResponse: success, header_len={d}, leftover={d}", .{ + self.header_len, + self.read_pos - self.header_len, + }); + self.context.onSuccess(self); + } + }; +} diff --git a/src/rpc/webzockets/src/control_queue.zig b/src/rpc/webzockets/src/control_queue.zig new file mode 100644 index 0000000000..e3602899ff --- /dev/null +++ b/src/rpc/webzockets/src/control_queue.zig @@ -0,0 +1,336 @@ +const std = @import("std"); +const testing = std.testing; +const Opcode = @import("types.zig").Opcode; + +/// A fixed-capacity inline circular buffer queue for WebSocket control frames. +/// +/// Stores variable-length entries of the form `[opcode: u8][length: u8][payload: 0-125 bytes]` +/// in a 256-byte ring buffer with no heap allocation. Used by both client and server +/// connections to queue outbound control frames (close, ping, pong) independently of +/// the data write path. +/// +/// ``` +/// buf (256 bytes, wraps around) +/// ┌──────────────────────────────────────────────────────────────────┐ +/// │ free │◄── entry 1 ──►│◄──── entry 2 ────►│◄─ entry 3 ─►│free │ +/// └──────────────────────────────────────────────────────────────────┘ +/// ▲ ▲ +/// head tail +/// +/// Each entry is variable-length: +/// ┌────────┬────────┬─────────────────────────┐ +/// │ opcode │ length │ payload (0-125 bytes) │ +/// │ (u8) │ (u8) │ │ +/// └────────┴────────┴─────────────────────────┘ +/// ◄─ header_size=2 ─► +/// ◄──────── 2 to 127 bytes total ─────────────► +/// +/// head/tail are u8 and wrap naturally at 256 (the buffer capacity). +/// count (u16) tracks total bytes used, enabling availableSpace = 256 - count. +/// +/// enqueue → writes [opcode][len][payload...] at tail, advances tail +/// dequeue → reads [opcode][len][payload...] at head, advances head +/// +/// Wraparound example (entry spans the buffer boundary): +/// ┌──────────────────────────────────────────────────────────────────┐ +/// │ yload...]│ free │[op][len][pa... │ +/// └──────────────────────────────────────────────────────────────────┘ +/// ▲ ▲ +/// head tail +/// ``` +pub const ControlQueue = struct { + buf: [capacity]u8 = undefined, + head: u8 = 0, + tail: u8 = 0, + count: u16 = 0, + + const capacity = 256; + comptime { + // IMPORTANT: capacity must be 256 for head/tail wraparound logic with u8. + std.debug.assert(capacity == 256); + } + /// Overhead per entry: 1 byte opcode + 1 byte length. + const header_size = 2; + + pub const Entry = struct { + opcode: Opcode, + len: u8, + }; + + pub const Error = error{QueueFull}; + + /// Returns a new empty ControlQueue. + pub fn init() ControlQueue { + return .{}; + } + + /// Appends a control frame entry to the back of the queue. + /// Returns `error.QueueFull` if there is insufficient space for the entry. + /// Payload must be at most 125 bytes (per RFC 6455 control frame limit). 
+ pub fn enqueue(self: *ControlQueue, opcode: Opcode, payload: []const u8) Error!void { + std.debug.assert(payload.len <= 125); + std.debug.assert(opcode.isControl()); + const entry_size: u8 = @intCast(header_size + payload.len); + if (self.availableSpace() < entry_size) return error.QueueFull; + + self.writeByte(@intFromEnum(opcode)); + self.writeByte(@intCast(payload.len)); + for (payload) |b| { + self.writeByte(b); + } + self.count += entry_size; + } + + /// Pops the front entry from the queue, copying its payload into `out_buf`. + /// Returns the entry metadata, or `null` if the queue is empty. + /// `out_buf` must be at least 125 bytes to accommodate any control frame payload. + pub fn dequeue(self: *ControlQueue, out_buf: []u8) ?Entry { + std.debug.assert(out_buf.len >= 125); + if (self.count == 0) return null; + + const opcode_byte = self.readByte(); + const len = self.readByte(); + const entry_size: u8 = header_size + len; + + for (0..len) |i| { + out_buf[i] = self.readByte(); + } + + self.count -= entry_size; + return .{ + .opcode = @enumFromInt(opcode_byte), + .len = len, + }; + } + + /// Inspects the front entry without removing it. + /// Returns the entry metadata, or `null` if the queue is empty. + pub fn peek(self: *const ControlQueue) ?Entry { + if (self.count == 0) return null; + + const opcode_byte = self.buf[self.head]; + const len = self.buf[self.head +% 1]; + return .{ + .opcode = @enumFromInt(opcode_byte), + .len = len, + }; + } + + /// Resets the queue to empty. + pub fn clear(self: *ControlQueue) void { + self.head = 0; + self.tail = 0; + self.count = 0; + } + + /// Returns `true` if the queue contains no entries. + pub fn isEmpty(self: *const ControlQueue) bool { + return self.count == 0; + } + + /// Returns `true` if the front entry is a close frame. + pub fn isNextClose(self: *const ControlQueue) bool { + const entry = self.peek() orelse return false; + return entry.opcode == .close; + } + + /// Returns the number of free bytes in the buffer. 
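+ /// For example, a fresh queue reports 256; after `enqueue(.ping, "abc")`
+ /// it reports 251 (2 header bytes + 3 payload bytes), as the
+ /// "availableSpace tracks usage correctly" test below exercises.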
+ pub fn availableSpace(self: *const ControlQueue) usize { + return capacity - @as(usize, self.count); + } + + // -- internal helpers -- + + fn writeByte(self: *ControlQueue, byte: u8) void { + self.buf[self.tail] = byte; + self.tail +%= 1; + } + + fn readByte(self: *ControlQueue) u8 { + const byte = self.buf[self.head]; + self.head +%= 1; + return byte; + } +}; + +test "enqueue and dequeue single entry — FIFO order" { + var q = ControlQueue.init(); + const payload = "hello"; + try q.enqueue(.ping, payload); + + var out: [125]u8 = undefined; + const entry = q.dequeue(&out).?; + try testing.expectEqual(Opcode.ping, entry.opcode); + try testing.expectEqual(@as(u8, 5), entry.len); + try testing.expectEqualSlices(u8, payload, out[0..entry.len]); + try testing.expect(q.isEmpty()); +} + +test "enqueue and dequeue multiple entries — FIFO order" { + var q = ControlQueue.init(); + try q.enqueue(.ping, "aaa"); + try q.enqueue(.pong, "bb"); + try q.enqueue(.close, "c"); + + var out: [125]u8 = undefined; + + const e1 = q.dequeue(&out).?; + try testing.expectEqual(Opcode.ping, e1.opcode); + try testing.expectEqualSlices(u8, "aaa", out[0..e1.len]); + + const e2 = q.dequeue(&out).?; + try testing.expectEqual(Opcode.pong, e2.opcode); + try testing.expectEqualSlices(u8, "bb", out[0..e2.len]); + + const e3 = q.dequeue(&out).?; + try testing.expectEqual(Opcode.close, e3.opcode); + try testing.expectEqualSlices(u8, "c", out[0..e3.len]); + + try testing.expect(q.isEmpty()); +} + +test "wraparound — dequeue some, enqueue more that wraps around buffer" { + var q = ControlQueue.init(); + var out: [125]u8 = undefined; + + // Fill most of the buffer with a large payload (header_size + 120 = 122 bytes). + const large = "x" ** 120; + try q.enqueue(.ping, large); + try q.enqueue(.pong, large); // 244 bytes used + + // Dequeue first to free space at the front. + const e1 = q.dequeue(&out).?; + try testing.expectEqual(Opcode.ping, e1.opcode); + try testing.expectEqualSlices(u8, large, out[0..e1.len]); + // 122 bytes free at front, tail is at 244. + + // Enqueue an entry that wraps around the end of the buffer. + const wrap_payload = "wrap-around!"; + try q.enqueue(.ping, wrap_payload); + + // Dequeue the pong, then the wrapped entry. + const e2 = q.dequeue(&out).?; + try testing.expectEqual(Opcode.pong, e2.opcode); + + const e3 = q.dequeue(&out).?; + try testing.expectEqual(Opcode.ping, e3.opcode); + try testing.expectEqualSlices(u8, wrap_payload, out[0..e3.len]); + try testing.expect(q.isEmpty()); +} + +test "full queue returns error.QueueFull" { + var q = ControlQueue.init(); + + // Each entry with 125-byte payload takes 127 bytes. Two fit in 256 (254 bytes). + const max_payload = "M" ** 125; + try q.enqueue(.ping, max_payload); // 127 bytes + try q.enqueue(.pong, max_payload); // 127 bytes → 254 used, 2 free + + // Even an empty-payload entry needs 2 bytes — exactly fits. + try q.enqueue(.ping, ""); + + // Now 0 bytes free — any enqueue should fail. 
+ try testing.expectError(error.QueueFull, q.enqueue(.ping, "")); +} + +test "empty dequeue returns null" { + var q = ControlQueue.init(); + var out: [125]u8 = undefined; + try testing.expect(q.dequeue(&out) == null); +} + +test "variable-length entries — empty, small, max 125-byte payload" { + var q = ControlQueue.init(); + var out: [125]u8 = undefined; + + try q.enqueue(.pong, ""); + try q.enqueue(.ping, "hi"); + + const e1 = q.dequeue(&out).?; + try testing.expectEqual(@as(u8, 0), e1.len); + + const e2 = q.dequeue(&out).?; + try testing.expectEqual(@as(u8, 2), e2.len); + try testing.expectEqualSlices(u8, "hi", out[0..e2.len]); + + // Max payload in isolation. + const max_payload = "Z" ** 125; + try q.enqueue(.close, max_payload); + const e3 = q.dequeue(&out).?; + try testing.expectEqual(Opcode.close, e3.opcode); + try testing.expectEqual(@as(u8, 125), e3.len); + try testing.expectEqualSlices(u8, max_payload, out[0..e3.len]); +} + +test "clear resets to empty" { + var q = ControlQueue.init(); + try q.enqueue(.ping, "data"); + try q.enqueue(.pong, "more"); + try testing.expect(!q.isEmpty()); + + q.clear(); + try testing.expect(q.isEmpty()); + try testing.expectEqual(@as(usize, 256), q.availableSpace()); + + var out: [125]u8 = undefined; + try testing.expect(q.dequeue(&out) == null); +} + +test "isNextClose detection" { + var q = ControlQueue.init(); + try testing.expect(!q.isNextClose()); + + try q.enqueue(.ping, ""); + try testing.expect(!q.isNextClose()); + + // Dequeue the ping, then enqueue a close. + var out: [125]u8 = undefined; + const e1 = q.dequeue(&out).?; + try testing.expectEqual(Opcode.ping, e1.opcode); + try testing.expectEqual(@as(u8, 0), e1.len); + + try q.enqueue(.close, &[_]u8{ 0x03, 0xE8 }); + try testing.expect(q.isNextClose()); +} + +test "peek returns null on empty queue" { + var q = ControlQueue.init(); + try testing.expect(q.peek() == null); +} + +test "peek does not consume the entry" { + var q = ControlQueue.init(); + try q.enqueue(.pong, "peek-test"); + + // Peek twice — should return the same entry each time. + const p1 = q.peek().?; + const p2 = q.peek().?; + try testing.expectEqual(p1.opcode, p2.opcode); + try testing.expectEqual(p1.len, p2.len); + + // Queue should still have exactly one entry. 
+ try testing.expect(!q.isEmpty()); + + var out: [125]u8 = undefined; + const entry = q.dequeue(&out).?; + try testing.expectEqual(Opcode.pong, entry.opcode); + try testing.expectEqualSlices(u8, "peek-test", out[0..entry.len]); + try testing.expect(q.isEmpty()); +} + +test "availableSpace tracks usage correctly" { + var q = ControlQueue.init(); + try testing.expectEqual(@as(usize, 256), q.availableSpace()); + + try q.enqueue(.ping, "abc"); // 2 + 3 = 5 bytes + try testing.expectEqual(@as(usize, 251), q.availableSpace()); + + try q.enqueue(.pong, ""); // 2 bytes + try testing.expectEqual(@as(usize, 249), q.availableSpace()); + + var out: [125]u8 = undefined; + const e1 = q.dequeue(&out).?; // free 5 bytes + try testing.expectEqual(Opcode.ping, e1.opcode); + try testing.expectEqualSlices(u8, "abc", out[0..e1.len]); + try testing.expectEqual(@as(usize, 254), q.availableSpace()); +} diff --git a/src/rpc/webzockets/src/frame.zig b/src/rpc/webzockets/src/frame.zig new file mode 100644 index 0000000000..64f3dfe989 --- /dev/null +++ b/src/rpc/webzockets/src/frame.zig @@ -0,0 +1,1091 @@ +const std = @import("std"); + +const types = @import("types.zig"); +const mask_mod = @import("mask.zig"); + +const Opcode = types.Opcode; + +/// Errors that can occur during frame parsing or validation. +pub const FrameError = error{ + InsufficientData, + InvalidOpcode, + ReservedFlags, + FragmentedControlFrame, + OversizedControlFrame, + MaskRequired, + MaskForbidden, + PayloadTooLarge, +}; + +/// Parsed WebSocket frame header. +/// Represents the decoded header fields from the wire format without the payload. +pub const FrameHeader = struct { + fin: bool, + rsv1: bool, + rsv2: bool, + rsv3: bool, + opcode: Opcode, + masked: bool, + payload_len: u64, + mask_key: [4]u8, + /// Total byte length of the header on the wire (2-14 bytes). + header_len: u8, + + /// Returns total frame length on the wire (header bytes + payload bytes). + pub fn totalLen(self: FrameHeader) u64 { + return @as(u64, self.header_len) + self.payload_len; + } + + /// Unmasks the payload in-place using the mask key from the header. + /// No-op if the frame is not masked. + pub fn unmaskPayload(self: FrameHeader, payload: []u8) void { + if (self.masked) { + mask_mod.mask(self.mask_key, payload); + } + } + + /// Validates the frame header per RFC 6455 requirements. + /// + /// Checks: + /// - RSV bits must all be zero (compression not yet supported) -> `ReservedFlags` + /// - Control frames (opcode >= 0x8) payload must be <= 125 bytes -> `OversizedControlFrame` + /// - Control frames must have FIN set -> `FragmentedControlFrame` + pub fn validate(self: FrameHeader) FrameError!void { + // RSV bits must be zero (no extensions negotiated) + if (self.rsv1 or self.rsv2 or self.rsv3) { + return FrameError.ReservedFlags; + } + + // Control frame checks + if (self.opcode.isControl()) { + if (self.payload_len > 125) { + return FrameError.OversizedControlFrame; + } + if (!self.fin) { + return FrameError.FragmentedControlFrame; + } + } + } + + /// Validates that the frame is properly masked for client-to-server transmission. + /// Per RFC 6455 Section 5.1, all client-to-server frames MUST be masked. + pub fn validateServerBound(self: FrameHeader) FrameError!void { + if (!self.masked) { + return FrameError.MaskRequired; + } + } + + /// Validates that the frame is not masked, as required for server-to-client transmission. + /// Per RFC 6455 Section 5.1, server-to-client frames MUST NOT be masked. 
+ pub fn validateClientBound(self: FrameHeader) FrameError!void { + if (self.masked) { + return FrameError.MaskForbidden; + } + } +}; + +/// Parse a WebSocket frame header from raw bytes per RFC 6455 Section 5.2. +/// +/// Returns the decoded `FrameHeader` with all fields populated including the +/// computed `header_len` (total bytes consumed by the header on the wire). +/// +/// Returns `FrameError.InsufficientData` if `data` does not contain enough +/// bytes to fully decode the header. +/// Returns `FrameError.InvalidOpcode` if the opcode is a reserved value. +pub fn parseHeader(data: []const u8) FrameError!FrameHeader { + if (data.len < 2) return FrameError.InsufficientData; + + const byte0 = data[0]; + const byte1 = data[1]; + + // byte0: [FIN | RSV1 | RSV2 | RSV3 | OP3 | OP2 | OP1 | OP0] + const fin = (byte0 & 0x80) != 0; + const rsv1 = (byte0 & 0x40) != 0; + const rsv2 = (byte0 & 0x20) != 0; + const rsv3 = (byte0 & 0x10) != 0; + const raw_opcode: u4 = @truncate(byte0); // low 4 bits + const opcode = std.meta.intToEnum(Opcode, raw_opcode) catch return FrameError.InvalidOpcode; + + // byte1: [MASK | LEN6 | LEN5 | LEN4 | LEN3 | LEN2 | LEN1 | LEN0] + const masked = (byte1 & 0x80) != 0; + const len7: u7 = @truncate(byte1); // low 7 bits + + // Calculate header length and decode payload length + var header_len: u8 = 2; + var payload_len: u64 = undefined; + + if (len7 <= 125) { + payload_len = len7; + } else if (len7 == 126) { + header_len += 2; + if (data.len < header_len) return FrameError.InsufficientData; + payload_len = std.mem.readInt(u16, data[2..4], .big); + // NOTE: RFC 6455 requires minimal encoding (e.g. payload_len > 125 here). + // We are currently permissive and allow non-minimal encodings. + } else { + // len7 == 127 + header_len += 8; + if (data.len < header_len) return FrameError.InsufficientData; + payload_len = std.mem.readInt(u64, data[2..10], .big); + // NOTE: RFC 6455 requires the most significant bit to be 0 and minimal + // encoding (payload_len > 65535). We are currently permissive. + } + + var mask_key: [4]u8 = .{ 0, 0, 0, 0 }; + if (masked) { + const mask_start = header_len; + header_len += 4; + if (data.len < header_len) return FrameError.InsufficientData; + @memcpy(&mask_key, data[mask_start..][0..4]); + } + + return FrameHeader{ + .fin = fin, + .rsv1 = rsv1, + .rsv2 = rsv2, + .rsv3 = rsv3, + .opcode = opcode, + .masked = masked, + .payload_len = payload_len, + .mask_key = mask_key, + .header_len = header_len, + }; +} + +/// Writes a server-to-client frame header into `buf`. +/// +/// Server frames are always unmasked with FIN=1. Sets RSV1 if `compressed` is true. +/// Returns the slice of `buf` that was written (2, 4, or 10 bytes). +/// +/// The caller must ensure `buf` is at least 10 bytes long. +pub fn writeFrameHeader( + buf: []u8, + opcode: Opcode, + payload_len: u64, + compressed: bool, +) []u8 { + std.debug.assert(buf.len >= 10); + return writeHeader(buf, opcode, payload_len, compressed, null); +} + +/// Writes a client-to-server (masked) frame header into `buf`. +/// +/// Client frames always have MASK=1 and FIN=1. Sets RSV1 if `compressed` is true. +/// Returns the slice of `buf` that was written (6, 8, or 14 bytes). +/// +/// The caller must ensure `buf` is at least 14 bytes long. 
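+///
+/// Sketch of assembling a full masked frame (`payload` is an assumed
+/// caller-owned mutable buffer; `mask_mod` is this file's `mask.zig` import,
+/// applied with the same key, as the connection write path does):
+/// ```
+/// var hdr: [14]u8 = undefined;
+/// const key: [4]u8 = .{ 0xAA, 0xBB, 0xCC, 0xDD };
+/// const header = writeClientFrameHeader(&hdr, .text, payload.len, key, false);
+/// mask_mod.mask(key, payload); // XOR-mask the payload in place
+/// // Wire format: header bytes followed by the masked payload.
+/// ```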
+pub fn writeClientFrameHeader( + buf: []u8, + opcode: Opcode, + payload_len: u64, + mask_key: [4]u8, + compressed: bool, +) []u8 { + std.debug.assert(buf.len >= 14); + return writeHeader(buf, opcode, payload_len, compressed, mask_key); +} + +/// Writes a WebSocket frame header into `buf`. When `mask_key` is non-null, +/// the MASK bit is set and the 4-byte key is appended after the length bytes. +/// Always sets FIN=1. Sets RSV1 when `compressed` is true and opcode is not a control frame. +fn writeHeader( + buf: []u8, + opcode: Opcode, + payload_len: u64, + compressed: bool, + mask_key: ?[4]u8, +) []u8 { + // Byte 0: FIN=1, RSV bits, opcode + var byte0: u8 = 0x80; // FIN bit set + if (compressed and !opcode.isControl()) byte0 |= 0x40; // RSV1 bit + byte0 |= @intFromEnum(opcode); + buf[0] = byte0; + + // Byte 1+: MASK flag, payload length + const mask_bit: u8 = if (mask_key != null) 0x80 else 0; + + const len_end: usize = blk: { + if (payload_len <= 125) { + buf[1] = mask_bit | @as(u8, @truncate(payload_len)); + break :blk 2; + } else if (payload_len <= 65535) { + buf[1] = mask_bit | 126; + std.mem.writeInt(u16, buf[2..4], @truncate(payload_len), .big); + break :blk 4; + } else { + buf[1] = mask_bit | 127; + std.mem.writeInt(u64, buf[2..10], payload_len, .big); + break :blk 10; + } + }; + + // Append mask key if present + if (mask_key) |key| { + @memcpy(buf[len_end..][0..4], &key); + return buf[0 .. len_end + 4]; + } + + return buf[0..len_end]; +} + +/// Comptime: returns the total frame size for a server-to-client message of +/// the given payload length (header bytes + payload, no mask). +pub fn calculateFrameLen(comptime msg_len: usize) usize { + if (msg_len <= 125) { + return 2 + msg_len; + } else if (msg_len <= 65535) { + return 4 + msg_len; + } else { + return 10 + msg_len; + } +} + +/// Comptime: returns a complete pre-framed WebSocket message as a fixed-size array. +/// +/// The frame has FIN=1, the given opcode, and appropriate length encoding. +/// Server frames are never masked. 
+/// +/// Usage: +/// ``` +/// const framed = comptime frame.frameStatic(.text, "Hello"); +/// ``` +pub fn frameStatic( + comptime opcode: Opcode, + comptime msg: []const u8, +) [calculateFrameLen(msg.len)]u8 { + comptime { + const total_len = calculateFrameLen(msg.len); + var buf: [total_len]u8 = undefined; + + // Byte 0: FIN=1, no RSV, opcode + buf[0] = 0x80 | @as(u8, @intFromEnum(opcode)); + + // Encode length and copy payload + if (msg.len <= 125) { + buf[1] = @truncate(msg.len); + @memcpy(buf[2..], msg); + } else if (msg.len <= 65535) { + buf[1] = 126; + buf[2] = @truncate(msg.len >> 8); + buf[3] = @truncate(msg.len); + @memcpy(buf[4..], msg); + } else { + buf[1] = 127; + // 8 bytes big-endian length + for (0..8) |i| { + buf[2 + i] = @truncate(msg.len >> @intCast((7 - i) * 8)); + } + @memcpy(buf[10..], msg); + } + + return buf; + } +} + +const testing = std.testing; + +test "parseHeader: 7-bit payload length (unmasked)" { + // FIN=1, opcode=text(1), MASK=0, len=5 + const data = [_]u8{ 0x81, 0x05 } ++ [_]u8{ 'H', 'e', 'l', 'l', 'o' }; + const header = try parseHeader(&data); + + try testing.expect(header.fin); + try testing.expect(!header.rsv1); + try testing.expect(!header.rsv2); + try testing.expect(!header.rsv3); + try testing.expectEqual(Opcode.text, header.opcode); + try testing.expect(!header.masked); + try testing.expectEqual(@as(u64, 5), header.payload_len); + try testing.expectEqual(@as(u8, 2), header.header_len); +} + +test "parseHeader: 7-bit payload length (masked)" { + // FIN=1, opcode=text(1), MASK=1, len=5, mask key=0x37FA213D + const data = [_]u8{ 0x81, 0x85, 0x37, 0xFA, 0x21, 0x3D, 0x7F, 0x9F, 0x4D, 0x51, 0x58 }; + const header = try parseHeader(&data); + + try testing.expect(header.fin); + try testing.expectEqual(Opcode.text, header.opcode); + try testing.expect(header.masked); + try testing.expectEqual(@as(u64, 5), header.payload_len); + try testing.expectEqualSlices(u8, &[_]u8{ 0x37, 0xFA, 0x21, 0x3D }, &header.mask_key); + try testing.expectEqual(@as(u8, 6), header.header_len); +} + +test "parseHeader: 16-bit payload length" { + // FIN=1, opcode=binary(2), MASK=0, len=256 (126 + 2-byte BE) + var data: [4]u8 = undefined; + data[0] = 0x82; // FIN=1, binary + data[1] = 126; // 16-bit extended length + std.mem.writeInt(u16, data[2..4], 256, .big); + + const header = try parseHeader(&data); + + try testing.expect(header.fin); + try testing.expectEqual(Opcode.binary, header.opcode); + try testing.expect(!header.masked); + try testing.expectEqual(@as(u64, 256), header.payload_len); + try testing.expectEqual(@as(u8, 4), header.header_len); +} + +test "parseHeader: 16-bit payload length (masked)" { + // FIN=1, opcode=text(1), MASK=1, len=300 + var data: [8]u8 = undefined; + data[0] = 0x81; // FIN=1, text + data[1] = 0xFE; // MASK=1, 126 -> 16-bit extended length + std.mem.writeInt(u16, data[2..4], 300, .big); + data[4] = 0xAA; + data[5] = 0xBB; + data[6] = 0xCC; + data[7] = 0xDD; + + const header = try parseHeader(&data); + + try testing.expect(header.fin); + try testing.expectEqual(Opcode.text, header.opcode); + try testing.expect(header.masked); + try testing.expectEqual(@as(u64, 300), header.payload_len); + try testing.expectEqualSlices(u8, &[_]u8{ 0xAA, 0xBB, 0xCC, 0xDD }, &header.mask_key); + try testing.expectEqual(@as(u8, 8), header.header_len); +} + +test "parseHeader: 64-bit payload length" { + // FIN=1, opcode=binary(2), MASK=0, len=70000 (127 + 8-byte BE) + var data: [10]u8 = undefined; + data[0] = 0x82; // FIN=1, binary + data[1] = 127; // 64-bit extended length 
+ std.mem.writeInt(u64, data[2..10], 70000, .big); + + const header = try parseHeader(&data); + + try testing.expect(header.fin); + try testing.expectEqual(Opcode.binary, header.opcode); + try testing.expect(!header.masked); + try testing.expectEqual(@as(u64, 70000), header.payload_len); + try testing.expectEqual(@as(u8, 10), header.header_len); +} + +test "parseHeader: 64-bit payload length (masked)" { + // FIN=1, opcode=binary(2), MASK=1, len=70000 + var data: [14]u8 = undefined; + data[0] = 0x82; // FIN=1, binary + data[1] = 0xFF; // MASK=1, 127 -> 64-bit extended length + std.mem.writeInt(u64, data[2..10], 70000, .big); + data[10] = 0x11; + data[11] = 0x22; + data[12] = 0x33; + data[13] = 0x44; + + const header = try parseHeader(&data); + + try testing.expect(header.fin); + try testing.expectEqual(Opcode.binary, header.opcode); + try testing.expect(header.masked); + try testing.expectEqual(@as(u64, 70000), header.payload_len); + try testing.expectEqualSlices(u8, &[_]u8{ 0x11, 0x22, 0x33, 0x44 }, &header.mask_key); + try testing.expectEqual(@as(u8, 14), header.header_len); +} + +test "parseHeader: continuation frame (FIN=0)" { + // FIN=0, opcode=continuation(0), MASK=0, len=10 + const data = [_]u8{ 0x00, 0x0A }; + const header = try parseHeader(&data); + + try testing.expect(!header.fin); + try testing.expectEqual(Opcode.continuation, header.opcode); + try testing.expectEqual(@as(u64, 10), header.payload_len); +} + +test "parseHeader: control frames (close, ping, pong)" { + // Close frame: FIN=1, opcode=0x8, MASK=0, len=2 + const close_data = [_]u8{ 0x88, 0x02 }; + const close_header = try parseHeader(&close_data); + try testing.expectEqual(Opcode.close, close_header.opcode); + try testing.expect(close_header.fin); + try testing.expectEqual(@as(u64, 2), close_header.payload_len); + + // Ping frame: FIN=1, opcode=0x9, MASK=0, len=0 + const ping_data = [_]u8{ 0x89, 0x00 }; + const ping_header = try parseHeader(&ping_data); + try testing.expectEqual(Opcode.ping, ping_header.opcode); + + // Pong frame: FIN=1, opcode=0xA, MASK=0, len=0 + const pong_data = [_]u8{ 0x8A, 0x00 }; + const pong_header = try parseHeader(&pong_data); + try testing.expectEqual(Opcode.pong, pong_header.opcode); +} + +test "parseHeader: InsufficientData for empty input" { + const data = [_]u8{}; + try testing.expectError(FrameError.InsufficientData, parseHeader(&data)); +} + +test "parseHeader: InsufficientData for single byte" { + const data = [_]u8{0x81}; + try testing.expectError(FrameError.InsufficientData, parseHeader(&data)); +} + +test "parseHeader: InsufficientData for truncated 16-bit length" { + // Needs 4 bytes for header but only 3 provided + const data = [_]u8{ 0x81, 126, 0x01 }; + try testing.expectError(FrameError.InsufficientData, parseHeader(&data)); +} + +test "parseHeader: InsufficientData for truncated 64-bit length" { + // Needs 10 bytes for header but only 5 provided + const data = [_]u8{ 0x81, 127, 0x00, 0x00, 0x00 }; + try testing.expectError(FrameError.InsufficientData, parseHeader(&data)); +} + +test "parseHeader: InsufficientData for truncated mask key" { + // FIN=1, text, MASK=1, len=5 but only 2 bytes of mask key + const data = [_]u8{ 0x81, 0x85, 0x37, 0xFA }; + try testing.expectError(FrameError.InsufficientData, parseHeader(&data)); +} + +test "parseHeader: InvalidOpcode for reserved non-control opcodes (3-7)" { + const reserved_opcodes = [_]u8{ 3, 4, 5, 6, 7 }; + for (reserved_opcodes) |op| { + const data = [_]u8{ 0x80 | op, 0x00 }; + try testing.expectError(FrameError.InvalidOpcode, 
parseHeader(&data)); + } +} + +test "parseHeader: InvalidOpcode for reserved control opcodes (0xB-0xF)" { + const reserved_opcodes = [_]u8{ 0xB, 0xC, 0xD, 0xE, 0xF }; + for (reserved_opcodes) |op| { + const data = [_]u8{ 0x80 | op, 0x00 }; + try testing.expectError(FrameError.InvalidOpcode, parseHeader(&data)); + } +} + +test "parseHeader: zero payload length" { + // FIN=1, ping, MASK=0, len=0 + const data = [_]u8{ 0x89, 0x00 }; + const header = try parseHeader(&data); + try testing.expectEqual(@as(u64, 0), header.payload_len); + try testing.expectEqual(@as(u8, 2), header.header_len); +} + +test "parseHeader: max 7-bit payload length (125)" { + const data = [_]u8{ 0x81, 125 }; + const header = try parseHeader(&data); + try testing.expectEqual(@as(u64, 125), header.payload_len); + try testing.expectEqual(@as(u8, 2), header.header_len); +} + +test "FrameHeader.totalLen: basic correctness" { + const header = FrameHeader{ + .fin = true, + .rsv1 = false, + .rsv2 = false, + .rsv3 = false, + .opcode = .text, + .masked = false, + .payload_len = 100, + .mask_key = .{ 0, 0, 0, 0 }, + .header_len = 2, + }; + try testing.expectEqual(@as(u64, 102), header.totalLen()); +} + +test "FrameHeader.totalLen: with 16-bit length" { + const header = FrameHeader{ + .fin = true, + .rsv1 = false, + .rsv2 = false, + .rsv3 = false, + .opcode = .binary, + .masked = false, + .payload_len = 1000, + .mask_key = .{ 0, 0, 0, 0 }, + .header_len = 4, + }; + try testing.expectEqual(@as(u64, 1004), header.totalLen()); +} + +test "FrameHeader.totalLen: masked with 64-bit length" { + const header = FrameHeader{ + .fin = true, + .rsv1 = false, + .rsv2 = false, + .rsv3 = false, + .opcode = .binary, + .masked = true, + .payload_len = 70000, + .mask_key = .{ 0x11, 0x22, 0x33, 0x44 }, + .header_len = 14, + }; + try testing.expectEqual(@as(u64, 70014), header.totalLen()); +} + +test "FrameHeader.unmaskPayload: masks payload when header.masked is true" { + var payload = [_]u8{ 0x48, 0x65, 0x6c, 0x6c, 0x6f }; // "Hello" + const header = FrameHeader{ + .fin = true, + .rsv1 = false, + .rsv2 = false, + .rsv3 = false, + .opcode = .text, + .masked = true, + .payload_len = 5, + .mask_key = .{ 0x37, 0xFA, 0x21, 0x3D }, + .header_len = 6, + }; + + header.unmaskPayload(&payload); + // "Hello" XOR'd with mask key + try testing.expectEqualSlices(u8, &[_]u8{ 0x7F, 0x9F, 0x4D, 0x51, 0x58 }, &payload); +} + +test "FrameHeader.unmaskPayload: no-op when header.masked is false" { + const original = [_]u8{ 0x48, 0x65, 0x6c, 0x6c, 0x6f }; + var payload = original; + const header = FrameHeader{ + .fin = true, + .rsv1 = false, + .rsv2 = false, + .rsv3 = false, + .opcode = .text, + .masked = false, + .payload_len = 5, + .mask_key = .{ 0x37, 0xFA, 0x21, 0x3D }, + .header_len = 2, + }; + + header.unmaskPayload(&payload); + try testing.expectEqualSlices(u8, &original, &payload); +} + +test "FrameHeader.unmaskPayload: round-trip mask and unmask" { + const original = "Hello, WebSocket!"; + var payload: [original.len]u8 = undefined; + @memcpy(&payload, original); + + const header = FrameHeader{ + .fin = true, + .rsv1 = false, + .rsv2 = false, + .rsv3 = false, + .opcode = .text, + .masked = true, + .payload_len = original.len, + .mask_key = .{ 0xAB, 0xCD, 0xEF, 0x01 }, + .header_len = 6, + }; + + // Mask + header.unmaskPayload(&payload); + try testing.expect(!std.mem.eql(u8, &payload, original)); + // Unmask (same operation) + header.unmaskPayload(&payload); + try testing.expectEqualSlices(u8, original, &payload); +} + +test "FrameHeader.validate: valid data frame 
passes" { + const header = FrameHeader{ + .fin = true, + .rsv1 = false, + .rsv2 = false, + .rsv3 = false, + .opcode = .text, + .masked = true, + .payload_len = 100, + .mask_key = .{ 0, 0, 0, 0 }, + .header_len = 6, + }; + try header.validate(); +} + +test "FrameHeader.validate: valid continuation frame (FIN=0) passes" { + const header = FrameHeader{ + .fin = false, + .rsv1 = false, + .rsv2 = false, + .rsv3 = false, + .opcode = .continuation, + .masked = true, + .payload_len = 50, + .mask_key = .{ 0, 0, 0, 0 }, + .header_len = 6, + }; + try header.validate(); +} + +test "FrameHeader.validate: RSV1 set rejects with ReservedFlags" { + const header = FrameHeader{ + .fin = true, + .rsv1 = true, + .rsv2 = false, + .rsv3 = false, + .opcode = .text, + .masked = true, + .payload_len = 5, + .mask_key = .{ 0, 0, 0, 0 }, + .header_len = 6, + }; + try testing.expectError(FrameError.ReservedFlags, header.validate()); +} + +test "FrameHeader.validate: RSV2 set rejects with ReservedFlags" { + const header = FrameHeader{ + .fin = true, + .rsv1 = false, + .rsv2 = true, + .rsv3 = false, + .opcode = .text, + .masked = true, + .payload_len = 5, + .mask_key = .{ 0, 0, 0, 0 }, + .header_len = 6, + }; + try testing.expectError(FrameError.ReservedFlags, header.validate()); +} + +test "FrameHeader.validate: RSV3 set rejects with ReservedFlags" { + const header = FrameHeader{ + .fin = true, + .rsv1 = false, + .rsv2 = false, + .rsv3 = true, + .opcode = .text, + .masked = true, + .payload_len = 5, + .mask_key = .{ 0, 0, 0, 0 }, + .header_len = 6, + }; + try testing.expectError(FrameError.ReservedFlags, header.validate()); +} + +test "FrameHeader.validate: oversized control frame rejects" { + const header = FrameHeader{ + .fin = true, + .rsv1 = false, + .rsv2 = false, + .rsv3 = false, + .opcode = .close, + .masked = true, + .payload_len = 126, + .mask_key = .{ 0, 0, 0, 0 }, + .header_len = 6, + }; + try testing.expectError(FrameError.OversizedControlFrame, header.validate()); +} + +test "FrameHeader.validate: control frame at max size (125) passes" { + const header = FrameHeader{ + .fin = true, + .rsv1 = false, + .rsv2 = false, + .rsv3 = false, + .opcode = .ping, + .masked = true, + .payload_len = 125, + .mask_key = .{ 0, 0, 0, 0 }, + .header_len = 6, + }; + try header.validate(); +} + +test "FrameHeader.validate: fragmented control frame rejects" { + const header = FrameHeader{ + .fin = false, + .rsv1 = false, + .rsv2 = false, + .rsv3 = false, + .opcode = .ping, + .masked = true, + .payload_len = 5, + .mask_key = .{ 0, 0, 0, 0 }, + .header_len = 6, + }; + try testing.expectError(FrameError.FragmentedControlFrame, header.validate()); +} + +test "FrameHeader.validate: data frame with FIN=0 is allowed" { + const header = FrameHeader{ + .fin = false, + .rsv1 = false, + .rsv2 = false, + .rsv3 = false, + .opcode = .text, + .masked = true, + .payload_len = 1000, + .mask_key = .{ 0, 0, 0, 0 }, + .header_len = 6, + }; + try header.validate(); +} + +test "FrameHeader.validateServerBound: masked frame passes" { + const header = FrameHeader{ + .fin = true, + .rsv1 = false, + .rsv2 = false, + .rsv3 = false, + .opcode = .text, + .masked = true, + .payload_len = 5, + .mask_key = .{ 0x37, 0xFA, 0x21, 0x3D }, + .header_len = 6, + }; + try header.validateServerBound(); +} + +test "FrameHeader.validateServerBound: unmasked frame rejects with MaskRequired" { + const header = FrameHeader{ + .fin = true, + .rsv1 = false, + .rsv2 = false, + .rsv3 = false, + .opcode = .text, + .masked = false, + .payload_len = 5, + .mask_key = .{ 0, 0, 0, 0 
}, + .header_len = 2, + }; + try testing.expectError(FrameError.MaskRequired, header.validateServerBound()); +} + +test "writeFrameHeader: 7-bit, 16-bit, and 64-bit payload lengths" { + var buf: [10]u8 = undefined; + + // 7-bit: payload fits in single length byte + const r7 = writeFrameHeader(&buf, .text, 5, false); + try testing.expectEqual(@as(usize, 2), r7.len); + try testing.expectEqual(@as(u8, 0x81), r7[0]); // FIN=1, text + try testing.expectEqual(@as(u8, 5), r7[1]); // len=5, MASK=0 + + // 16-bit: extended length + const r16 = writeFrameHeader(&buf, .binary, 300, false); + try testing.expectEqual(@as(usize, 4), r16.len); + try testing.expectEqual(@as(u8, 0x82), r16[0]); // FIN=1, binary + try testing.expectEqual(@as(u8, 126), r16[1]); + try testing.expectEqual(@as(u16, 300), std.mem.readInt(u16, r16[2..4], .big)); + + // 64-bit: extended length + const r64 = writeFrameHeader(&buf, .binary, 70000, false); + try testing.expectEqual(@as(usize, 10), r64.len); + try testing.expectEqual(@as(u8, 0x82), r64[0]); // FIN=1, binary + try testing.expectEqual(@as(u8, 127), r64[1]); + try testing.expectEqual(@as(u64, 70000), std.mem.readInt(u64, r64[2..10], .big)); +} + +test "writeFrameHeader: compressed flag and close opcode" { + var buf: [10]u8 = undefined; + + const compressed = writeFrameHeader(&buf, .text, 10, true); + try testing.expectEqual(@as(u8, 0xC1), compressed[0]); // FIN=1, RSV1=1, text + + const close = writeFrameHeader(&buf, .close, 2, false); + try testing.expectEqual(@as(usize, 2), close.len); + try testing.expectEqual(@as(u8, 0x88), close[0]); // FIN=1, close + try testing.expectEqual(@as(u8, 2), close[1]); +} + +test "writeFrameHeader: compressed flag ignored for control opcodes" { + var buf: [10]u8 = undefined; + + // Control frames with compressed=true should NOT have RSV1 bit set + const close = writeFrameHeader(&buf, .close, 0, true); + try testing.expectEqual(@as(u8, 0x88), close[0]); // FIN=1, opcode=8 (NOT 0xC8) + + const ping = writeFrameHeader(&buf, .ping, 0, true); + try testing.expectEqual(@as(u8, 0x89), ping[0]); // FIN=1, opcode=9 (NOT 0xC9) + + // Data frame with compressed=true SHOULD have RSV1 bit set + const text = writeFrameHeader(&buf, .text, 0, true); + try testing.expectEqual(@as(u8, 0xC1), text[0]); // FIN=1, RSV1=1, opcode=1 +} + +test "writeFrameHeader: length tier boundaries (125/126, 65535/65536)" { + var buf: [10]u8 = undefined; + + // 125 → 7-bit tier (last value in 7-bit range) + const r125 = writeFrameHeader(&buf, .text, 125, false); + try testing.expectEqual(@as(usize, 2), r125.len); + try testing.expectEqual(@as(u8, 125), r125[1]); + + // 126 → 16-bit tier (first value in 16-bit range) + const r126 = writeFrameHeader(&buf, .text, 126, false); + try testing.expectEqual(@as(usize, 4), r126.len); + try testing.expectEqual(@as(u8, 126), r126[1]); + try testing.expectEqual(@as(u16, 126), std.mem.readInt(u16, r126[2..4], .big)); + + // 65535 → 16-bit tier (last value in 16-bit range) + const r65535 = writeFrameHeader(&buf, .text, 65535, false); + try testing.expectEqual(@as(usize, 4), r65535.len); + try testing.expectEqual(@as(u16, 65535), std.mem.readInt(u16, r65535[2..4], .big)); + + // 65536 → 64-bit tier (first value in 64-bit range) + const r65536 = writeFrameHeader(&buf, .text, 65536, false); + try testing.expectEqual(@as(usize, 10), r65536.len); + try testing.expectEqual(@as(u8, 127), r65536[1]); + try testing.expectEqual(@as(u64, 65536), std.mem.readInt(u64, r65536[2..10], .big)); +} + +test "writeFrameHeader + parseHeader round-trip: all 
length tiers and compressed" { + const Case = struct { opcode: Opcode, payload_len: u64, compressed: bool }; + const cases = [_]Case{ + .{ .opcode = .text, .payload_len = 42, .compressed = false }, + .{ .opcode = .binary, .payload_len = 1000, .compressed = false }, + .{ .opcode = .binary, .payload_len = 100000, .compressed = false }, + .{ .opcode = .text, .payload_len = 50, .compressed = true }, + }; + + for (cases) |c| { + var buf: [10]u8 = undefined; + const written = writeFrameHeader(&buf, c.opcode, c.payload_len, c.compressed); + const header = try parseHeader(written); + + try testing.expect(header.fin); + try testing.expectEqual(c.opcode, header.opcode); + try testing.expect(!header.masked); + try testing.expectEqual(c.payload_len, header.payload_len); + try testing.expectEqual(c.compressed, header.rsv1); + try testing.expectEqual(@as(u8, @intCast(written.len)), header.header_len); + } +} + +test "calculateFrameLen: 7-bit tier" { + try testing.expectEqual(@as(usize, 2 + 0), calculateFrameLen(0)); + try testing.expectEqual(@as(usize, 2 + 1), calculateFrameLen(1)); + try testing.expectEqual(@as(usize, 2 + 125), calculateFrameLen(125)); +} + +test "calculateFrameLen: 16-bit tier" { + try testing.expectEqual(@as(usize, 4 + 126), calculateFrameLen(126)); + try testing.expectEqual(@as(usize, 4 + 65535), calculateFrameLen(65535)); +} + +test "calculateFrameLen: 64-bit tier" { + try testing.expectEqual(@as(usize, 10 + 65536), calculateFrameLen(65536)); +} + +test "frameStatic: text message 'Hello'" { + const framed = comptime frameStatic(.text, "Hello"); + try testing.expectEqual(@as(usize, 7), framed.len); // 2 header + 5 payload + try testing.expectEqual(@as(u8, 0x81), framed[0]); // FIN=1, text + try testing.expectEqual(@as(u8, 5), framed[1]); // len=5 + try testing.expectEqualSlices(u8, "Hello", framed[2..]); +} + +test "frameStatic: empty payload" { + const framed = comptime frameStatic(.ping, ""); + try testing.expectEqual(@as(usize, 2), framed.len); + try testing.expectEqual(@as(u8, 0x89), framed[0]); // FIN=1, ping + try testing.expectEqual(@as(u8, 0), framed[1]); +} + +test "frameStatic: close frame with status code bytes" { + const framed = comptime frameStatic(.close, &[_]u8{ 0x03, 0xE8 }); // 1000 big-endian + try testing.expectEqual(@as(usize, 4), framed.len); + try testing.expectEqual(@as(u8, 0x88), framed[0]); // FIN=1, close + try testing.expectEqual(@as(u8, 2), framed[1]); + try testing.expectEqual(@as(u8, 0x03), framed[2]); + try testing.expectEqual(@as(u8, 0xE8), framed[3]); +} + +test "frameStatic: binary frame" { + const framed = comptime frameStatic(.binary, &[_]u8{ 0xDE, 0xAD, 0xBE, 0xEF }); + try testing.expectEqual(@as(usize, 6), framed.len); + try testing.expectEqual(@as(u8, 0x82), framed[0]); // FIN=1, binary + try testing.expectEqual(@as(u8, 4), framed[1]); + try testing.expectEqualSlices(u8, &[_]u8{ 0xDE, 0xAD, 0xBE, 0xEF }, framed[2..]); +} + +test "frameStatic: verify parseHeader can parse the result" { + const framed = comptime frameStatic(.text, "Hello, World!"); + const header = try parseHeader(&framed); + try testing.expect(header.fin); + try testing.expect(!header.rsv1); + try testing.expectEqual(Opcode.text, header.opcode); + try testing.expect(!header.masked); + try testing.expectEqual(@as(u64, 13), header.payload_len); + try testing.expectEqual(@as(u8, 2), header.header_len); + try testing.expectEqualSlices(u8, "Hello, World!", framed[header.header_len..]); +} + +test "frameStatic: 126-byte payload uses 16-bit length" { + const payload = [_]u8{0x42} ** 
126; + const framed = comptime frameStatic(.text, &payload); + try testing.expectEqual(@as(usize, 4 + 126), framed.len); + try testing.expectEqual(@as(u8, 0x81), framed[0]); // FIN=1, text + try testing.expectEqual(@as(u8, 126), framed[1]); // 16-bit extended + try testing.expectEqual(@as(u16, 126), std.mem.readInt(u16, framed[2..4], .big)); + // Verify payload content + for (framed[4..]) |b| { + try testing.expectEqual(@as(u8, 0x42), b); + } +} + +test "frameStatic: pong frame" { + const framed = comptime frameStatic(.pong, "pong-data"); + try testing.expectEqual(@as(u8, 0x8A), framed[0]); // FIN=1, pong + try testing.expectEqual(@as(u8, 9), framed[1]); + try testing.expectEqualSlices(u8, "pong-data", framed[2..]); +} + +// --- validateClientBound tests --- + +test "FrameHeader.validateClientBound: unmasked frame passes" { + const header = FrameHeader{ + .fin = true, + .rsv1 = false, + .rsv2 = false, + .rsv3 = false, + .opcode = .text, + .masked = false, + .payload_len = 5, + .mask_key = .{ 0, 0, 0, 0 }, + .header_len = 2, + }; + try header.validateClientBound(); +} + +test "FrameHeader.validateClientBound: masked frame rejects with MaskForbidden" { + const header = FrameHeader{ + .fin = true, + .rsv1 = false, + .rsv2 = false, + .rsv3 = false, + .opcode = .text, + .masked = true, + .payload_len = 5, + .mask_key = .{ 0x37, 0xFA, 0x21, 0x3D }, + .header_len = 6, + }; + try testing.expectError(FrameError.MaskForbidden, header.validateClientBound()); +} + +// --- writeClientFrameHeader tests --- + +test "writeClientFrameHeader: 7-bit, 16-bit, and 64-bit payload lengths" { + const mask_key = [_]u8{ 0xAA, 0xBB, 0xCC, 0xDD }; + var buf: [14]u8 = undefined; + + // 7-bit + const r7 = writeClientFrameHeader(&buf, .text, 5, mask_key, false); + try testing.expectEqual(@as(usize, 6), r7.len); + try testing.expectEqual(@as(u8, 0x81), r7[0]); // FIN=1, text + try testing.expectEqual(@as(u8, 0x85), r7[1]); // MASK=1, len=5 + try testing.expectEqualSlices(u8, &mask_key, r7[2..6]); + + // 16-bit + const r16 = writeClientFrameHeader(&buf, .binary, 300, mask_key, false); + try testing.expectEqual(@as(usize, 8), r16.len); + try testing.expectEqual(@as(u8, 0x82), r16[0]); // FIN=1, binary + try testing.expectEqual(@as(u8, 0xFE), r16[1]); // MASK=1, 126 + try testing.expectEqual(@as(u16, 300), std.mem.readInt(u16, r16[2..4], .big)); + try testing.expectEqualSlices(u8, &mask_key, r16[4..8]); + + // 64-bit + const r64 = writeClientFrameHeader(&buf, .binary, 70000, mask_key, false); + try testing.expectEqual(@as(usize, 14), r64.len); + try testing.expectEqual(@as(u8, 0x82), r64[0]); // FIN=1, binary + try testing.expectEqual(@as(u8, 0xFF), r64[1]); // MASK=1, 127 + try testing.expectEqual(@as(u64, 70000), std.mem.readInt(u64, r64[2..10], .big)); + try testing.expectEqualSlices(u8, &mask_key, r64[10..14]); +} + +test "writeClientFrameHeader: compressed flag sets RSV1" { + var buf: [14]u8 = undefined; + const result = writeClientFrameHeader(&buf, .text, 10, .{ 0x01, 0x02, 0x03, 0x04 }, true); + try testing.expectEqual(@as(u8, 0xC1), result[0]); // FIN=1, RSV1=1, text +} + +test "writeClientFrameHeader + parseHeader round-trip" { + const mask_key = [_]u8{ 0xAB, 0xCD, 0xEF, 0x01 }; + var buf: [14]u8 = undefined; + const written = writeClientFrameHeader(&buf, .text, 42, mask_key, false); + + const header = try parseHeader(written); + try testing.expect(header.fin); + try testing.expect(!header.rsv1); + try testing.expectEqual(Opcode.text, header.opcode); + try testing.expect(header.masked); + try 
testing.expectEqual(@as(u64, 42), header.payload_len); + try testing.expectEqualSlices(u8, &mask_key, &header.mask_key); + try testing.expectEqual(@as(u8, @intCast(written.len)), header.header_len); +} + +// --- Fuzz tests --- + +test "fuzz parseHeader: no crash on arbitrary input" { + try testing.fuzz({}, struct { + fn run(_: void, input: []const u8) anyerror!void { + // parseHeader must either return a valid header or an error, never crash. + _ = parseHeader(input) catch return; + } + }.run, .{ + .corpus = &[_][]const u8{ + // Minimal valid frames + &[_]u8{ 0x81, 0x05 }, // text, len=5 + &[_]u8{ 0x88, 0x00 }, // close, len=0 + &[_]u8{ 0x89, 0x00 }, // ping, len=0 + &[_]u8{ 0x8A, 0x00 }, // pong, len=0 + // Masked frame + &[_]u8{ 0x81, 0x85, 0x37, 0xFA, 0x21, 0x3D }, + // 16-bit length + &[_]u8{ 0x82, 126, 0x01, 0x00 }, + // 64-bit length + &[_]u8{ 0x82, 127, 0, 0, 0, 0, 0, 1, 0x11, 0x70 }, + // Reserved opcodes + &[_]u8{ 0x83, 0x00 }, + &[_]u8{ 0x8F, 0x00 }, + // Empty + &[_]u8{}, + // Single byte + &[_]u8{0x81}, + }, + }); +} + +test "fuzz writeHeader + parseHeader round-trip (server and client)" { + try testing.fuzz({}, struct { + // Input layout: [0] opcode selector, [1..9] payload_len, [9..13] mask_key, [13] flags + // flags bit 0 = compressed, flags bit 1 = masked (client frame) + fn run(_: void, input: []const u8) anyerror!void { + if (input.len < 14) return; + + const valid_opcodes = [_]Opcode{ .continuation, .text, .binary, .close, .ping, .pong }; + const opcode = valid_opcodes[input[0] % valid_opcodes.len]; + const payload_len = std.mem.readInt(u64, input[1..9], .big); + const mask_key: [4]u8 = input[9..13].*; + const compressed = (input[13] & 1) != 0; + const masked = (input[13] & 2) != 0; + + var buf: [14]u8 = undefined; + const written = if (masked) + writeClientFrameHeader(&buf, opcode, payload_len, mask_key, compressed) + else + writeFrameHeader(buf[0..10], opcode, payload_len, compressed); + + const header = try parseHeader(written); + + try testing.expect(header.fin); + try testing.expectEqual(opcode, header.opcode); + try testing.expectEqual(masked, header.masked); + try testing.expectEqual(payload_len, header.payload_len); + try testing.expectEqual(compressed, header.rsv1); + try testing.expectEqual(@as(u8, @intCast(written.len)), header.header_len); + if (masked) { + try testing.expectEqualSlices(u8, &mask_key, &header.mask_key); + } + } + }.run, .{ + .corpus = &[_][]const u8{ + // Server (unmasked): flags bit 1 = 0 + &[_]u8{ 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0 }, + &[_]u8{ 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0 }, + &[_]u8{ 2, 0, 0, 0, 0, 0, 0, 0, 125, 0, 0, 0, 0, 0 }, + &[_]u8{ 2, 0, 0, 0, 0, 0, 0, 0, 126, 0, 0, 0, 0, 0 }, + &[_]u8{ 2, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0 }, + &[_]u8{ 0, 0, 0, 0, 0, 0, 0, 0, 42, 0, 0, 0, 0, 1 }, // compressed + // Client (masked): flags bit 1 = 2 + &[_]u8{ 0, 0, 0, 0, 0, 0, 0, 0, 5, 0xAA, 0xBB, 0xCC, 0xDD, 2 }, + &[_]u8{ 1, 0, 0, 0, 0, 0, 0, 1, 0, 0x11, 0x22, 0x33, 0x44, 2 }, + &[_]u8{ 2, 0, 0, 0, 0, 0, 0, 0, 126, 0xDE, 0xAD, 0xBE, 0xEF, 3 }, // masked + compressed + }, + }); +} diff --git a/src/rpc/webzockets/src/http.zig b/src/rpc/webzockets/src/http.zig new file mode 100644 index 0000000000..887360e8ef --- /dev/null +++ b/src/rpc/webzockets/src/http.zig @@ -0,0 +1,666 @@ +const std = @import("std"); +const types = @import("types.zig"); + +const HandshakeError = types.HandshakeError; + +pub const HeadParser = std.http.HeadParser; + +/// RFC 6455 magic GUID used in Sec-WebSocket-Accept computation. 
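+/// A compliant server proves it understood the handshake by echoing
+/// `base64(SHA-1(client_key ++ websocket_guid))`; see `computeAcceptKey` below.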
+const websocket_guid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + +/// Parsed HTTP upgrade request. +/// All string fields borrow from the input buffer (no allocations). +pub const Request = struct { + /// HTTP method — must be "GET". + method: []const u8, + /// Request URI path. + path: []const u8, + /// Sec-WebSocket-Key header value. + websocket_key: []const u8, + /// Host header value (empty slice if absent). + /// NOTE: permissive to missing Host when parsing despite RFC 6455 + /// and HTTP requiring it (validate in Handler init if you care). + host: []const u8, + /// Origin header value (empty slice if absent). + origin: []const u8, + /// Sec-WebSocket-Protocol header value (empty slice if absent). + protocols: []const u8, + /// Sec-WebSocket-Extensions header value (empty slice if absent). + extensions: []const u8, +}; + +/// Bitmask bits for tracking required headers during parsing. +const required_upgrade: u4 = 0b0001; +const required_connection: u4 = 0b0010; +const required_key: u4 = 0b0100; +const required_version: u4 = 0b1000; + +/// Parses an HTTP WebSocket upgrade request from `buf`. +/// +/// `buf` must contain a complete HTTP head as delimited by `HeadParser`. +/// Headers must use `\r\n` line endings; bare-LF lines are not correctly +/// parsed and will likely result in missing-header errors. +/// +/// Returns a `HandshakeError` if the request is malformed. +/// On success, returns a `Request` struct with all fields borrowing from `buf`. +pub fn parseRequest(buf: []const u8) HandshakeError!Request { + + // Step 1: Parse the request line. + var lines = std.mem.splitSequence(u8, buf, "\r\n"); + const request_line = lines.next() orelse return HandshakeError.MalformedRequest; + + // Parse "METHOD PATH HTTP/1.1" + const method_end = std.mem.indexOf(u8, request_line, " ") orelse + return HandshakeError.MalformedRequest; + const method = request_line[0..method_end]; + + if (!std.mem.eql(u8, method, "GET")) { + return HandshakeError.InvalidMethod; + } + + const after_method = request_line[method_end + 1 ..]; + const path_end = std.mem.indexOf(u8, after_method, " ") orelse + return HandshakeError.MalformedRequest; + const path = after_method[0..path_end]; + + const version = after_method[path_end + 1 ..]; + if (!std.mem.eql(u8, version, "HTTP/1.1")) { + return HandshakeError.InvalidHttpVersion; + } + + // Step 2: Parse headers. 
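+ // Each required header sets one bit in `found` (see the `required_*`
+ // constants above), so one pass over the headers suffices and Step 3
+ // can report exactly which requirement was not met.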
+ var found: u4 = 0; + var version_seen = false; + + var websocket_key: []const u8 = &[_]u8{}; + var host: []const u8 = &[_]u8{}; + var origin: []const u8 = &[_]u8{}; + var protocols: []const u8 = &[_]u8{}; + var extensions: []const u8 = &[_]u8{}; + + while (lines.next()) |line| { + if (line.len == 0) break; + const colon_pos = std.mem.indexOf(u8, line, ":") orelse continue; + const header_name = line[0..colon_pos]; + const header_value = std.mem.trim(u8, line[colon_pos + 1 ..], " \t"); + + if (std.ascii.eqlIgnoreCase(header_name, "Upgrade")) { + if (headerContainsToken(header_value, "websocket")) { + found |= required_upgrade; + } + } else if (std.ascii.eqlIgnoreCase(header_name, "Connection")) { + if (headerContainsToken(header_value, "upgrade")) { + found |= required_connection; + } + } else if (std.ascii.eqlIgnoreCase(header_name, "Sec-WebSocket-Key")) { + websocket_key = header_value; + found |= required_key; + } else if (std.ascii.eqlIgnoreCase(header_name, "Sec-WebSocket-Version")) { + version_seen = true; + if (std.mem.eql(u8, header_value, "13")) { + found |= required_version; + } + } else if (std.ascii.eqlIgnoreCase(header_name, "Host")) { + host = header_value; + } else if (std.ascii.eqlIgnoreCase(header_name, "Origin")) { + origin = header_value; + } else if (std.ascii.eqlIgnoreCase(header_name, "Sec-WebSocket-Protocol")) { + protocols = header_value; + } else if (std.ascii.eqlIgnoreCase(header_name, "Sec-WebSocket-Extensions")) { + extensions = header_value; + } + } + + // Step 3: Check all required headers were found. + if (found & required_upgrade == 0) return HandshakeError.MissingUpgradeHeader; + if (found & required_connection == 0) return HandshakeError.MissingConnectionHeader; + if (found & required_key == 0) return HandshakeError.MissingWebSocketKey; + if (found & required_version == 0) { + if (version_seen) { + return HandshakeError.UnsupportedWebSocketVersion; + } + return HandshakeError.MissingWebSocketVersion; + } + + return Request{ + .method = method, + .path = path, + .websocket_key = websocket_key, + .host = host, + .origin = origin, + .protocols = protocols, + .extensions = extensions, + }; +} + +/// Computes the Sec-WebSocket-Accept value per RFC 6455 Section 4.2.2. +/// +/// Concatenates `client_key` with the magic GUID, computes SHA-1, and base64-encodes +/// the result into `accept_buf`. Returns the valid slice of `accept_buf`. +pub fn computeAcceptKey(accept_buf: *[28]u8, client_key: []const u8) []const u8 { + var hasher = std.crypto.hash.Sha1.init(.{}); + hasher.update(client_key); + hasher.update(websocket_guid); + const digest = hasher.finalResult(); + + const encoded = std.base64.standard.Encoder.encode(accept_buf, &digest); + return encoded; +} + +/// Error set for writeResponse. +pub const WriteError = error{ + BufferTooSmall, +}; + +/// Writes the HTTP 101 Switching Protocols response into `buf`. +/// +/// Returns the slice of `buf` that was written, or `error.BufferTooSmall` if the +/// buffer is not large enough to hold the complete response. 
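+///
+/// The written response always has this shape (CRLF line endings):
+///
+///   HTTP/1.1 101 Switching Protocols
+///   Upgrade: websocket
+///   Connection: Upgrade
+///   Sec-WebSocket-Accept: <key computed from `client_key`>
+///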
+pub fn writeResponse(buf: []u8, client_key: []const u8) WriteError![]const u8 { + var accept_buf: [28]u8 = undefined; + const accept_key = computeAcceptKey(&accept_buf, client_key); + + const prefix = "HTTP/1.1 101 Switching Protocols\r\n" ++ + "Upgrade: websocket\r\n" ++ + "Connection: Upgrade\r\n" ++ + "Sec-WebSocket-Accept: "; + const suffix = "\r\n\r\n"; + + const total_len = prefix.len + accept_key.len + suffix.len; + if (buf.len < total_len) return WriteError.BufferTooSmall; + + @memcpy(buf[0..prefix.len], prefix); + @memcpy(buf[prefix.len .. prefix.len + accept_key.len], accept_key); + @memcpy(buf[prefix.len + accept_key.len .. total_len], suffix); + + return buf[0..total_len]; +} + +/// Base64-encode a 16-byte raw key into `key_buf`. +/// Returns the 24-byte encoded slice. +pub fn encodeKey(key_buf: *[24]u8, raw: *const [16]u8) []const u8 { + return std.base64.standard.Encoder.encode(key_buf, raw); +} + +/// Write an HTTP GET upgrade request into `buf`. +/// Returns the slice of `buf` that was written, or error if the buffer is too small. +pub fn writeRequest( + buf: []u8, + address: std.net.Address, + path: []const u8, + key: []const u8, +) ![]const u8 { + var fbs = std.io.fixedBufferStream(buf); + const w = fbs.writer(); + w.print("GET {s} HTTP/1.1\r\n", .{path}) catch return error.BufferTooSmall; + w.print("Host: {}\r\n", .{address}) catch return error.BufferTooSmall; + w.writeAll("Upgrade: websocket\r\n") catch return error.BufferTooSmall; + w.writeAll("Connection: Upgrade\r\n") catch return error.BufferTooSmall; + w.print("Sec-WebSocket-Key: {s}\r\n", .{key}) catch return error.BufferTooSmall; + w.writeAll("Sec-WebSocket-Version: 13\r\n") catch return error.BufferTooSmall; + w.writeAll("\r\n") catch return error.BufferTooSmall; + return buf[0..fbs.pos]; +} + +/// Validates an HTTP 101 Switching Protocols response. +/// +/// `buf` must contain a complete HTTP head as delimited by `HeadParser`. +/// Headers must use `\r\n` line endings; bare-LF lines are not correctly +/// parsed and will likely result in missing-header errors. +/// +/// Returns a `HandshakeError` if the response is malformed or invalid. +pub fn validateResponse(buf: []const u8, expected_key: []const u8) HandshakeError!void { + // Parse status line + var lines = std.mem.splitSequence(u8, buf, "\r\n"); + const status_line = lines.next() orelse return HandshakeError.MalformedResponse; + + if (!std.mem.startsWith(u8, status_line, "HTTP/1.1 ")) { + return HandshakeError.MalformedResponse; + } + const after_version = status_line["HTTP/1.1 ".len..]; + if (after_version.len < 3) { + return HandshakeError.MalformedResponse; + } + if (!std.mem.eql(u8, after_version[0..3], "101")) { + return HandshakeError.InvalidStatusCode; + } + + // Parse headers — track all three required response headers per RFC 6455 §4.2.2. 
+ var found_upgrade = false; + var found_connection = false; + var found_accept = false; + + while (lines.next()) |line| { + if (line.len == 0) break; + const colon_pos = std.mem.indexOf(u8, line, ":") orelse continue; + const header_name = line[0..colon_pos]; + const header_value = std.mem.trim(u8, line[colon_pos + 1 ..], " \t"); + + if (std.ascii.eqlIgnoreCase(header_name, "Upgrade")) { + if (headerContainsToken(header_value, "websocket")) { + found_upgrade = true; + } + } else if (std.ascii.eqlIgnoreCase(header_name, "Connection")) { + if (headerContainsToken(header_value, "upgrade")) { + found_connection = true; + } + } else if (std.ascii.eqlIgnoreCase(header_name, "Sec-WebSocket-Accept")) { + var accept_buf: [28]u8 = undefined; + const expected_accept = computeAcceptKey(&accept_buf, expected_key); + if (!std.mem.eql(u8, header_value, expected_accept)) { + return HandshakeError.InvalidAcceptKey; + } + found_accept = true; + } + } + + if (!found_upgrade) return HandshakeError.MissingUpgradeHeader; + if (!found_connection) return HandshakeError.MissingConnectionHeader; + if (!found_accept) return HandshakeError.MissingAcceptHeader; +} + +// --- Internal helpers --- + +/// Returns true if a comma-separated header value contains the given token +/// (case-insensitive, trims whitespace from value items). E.g. "keep-alive, Upgrade" contains "upgrade". +fn headerContainsToken(value: []const u8, token: []const u8) bool { + var it = std.mem.splitScalar(u8, value, ','); + while (it.next()) |item| { + if (std.ascii.eqlIgnoreCase(std.mem.trim(u8, item, " \t"), token)) return true; + } + return false; +} + +const testing = std.testing; + +// =========================================================================== +// Tests — Server-side (from handshake.zig) +// =========================================================================== + +test "parseRequest: valid full upgrade request" { + const request_bytes = + "GET /chat HTTP/1.1\r\n" ++ + "Host: server.example.com\r\n" ++ + "Upgrade: websocket\r\n" ++ + "Connection: Upgrade\r\n" ++ + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" ++ + "Sec-WebSocket-Version: 13\r\n" ++ + "\r\n"; + + const req = try parseRequest(request_bytes); + + try testing.expectEqualStrings("GET", req.method); + try testing.expectEqualStrings("/chat", req.path); + try testing.expectEqualStrings("dGhlIHNhbXBsZSBub25jZQ==", req.websocket_key); + try testing.expectEqualStrings("server.example.com", req.host); + try testing.expectEqualStrings("", req.origin); + try testing.expectEqualStrings("", req.protocols); + try testing.expectEqualStrings("", req.extensions); +} + +test "parseRequest: missing Upgrade header" { + const request_bytes = + "GET /chat HTTP/1.1\r\n" ++ + "Host: server.example.com\r\n" ++ + "Connection: Upgrade\r\n" ++ + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" ++ + "Sec-WebSocket-Version: 13\r\n" ++ + "\r\n"; + + const result = parseRequest(request_bytes); + try testing.expectError(HandshakeError.MissingUpgradeHeader, result); +} + +test "parseRequest: missing Connection header" { + const request_bytes = + "GET /chat HTTP/1.1\r\n" ++ + "Host: server.example.com\r\n" ++ + "Upgrade: websocket\r\n" ++ + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" ++ + "Sec-WebSocket-Version: 13\r\n" ++ + "\r\n"; + + const result = parseRequest(request_bytes); + try testing.expectError(HandshakeError.MissingConnectionHeader, result); +} + +test "parseRequest: missing Sec-WebSocket-Key header" { + const request_bytes = + "GET /chat HTTP/1.1\r\n" ++ + "Host: 
server.example.com\r\n" ++ + "Upgrade: websocket\r\n" ++ + "Connection: Upgrade\r\n" ++ + "Sec-WebSocket-Version: 13\r\n" ++ + "\r\n"; + + const result = parseRequest(request_bytes); + try testing.expectError(HandshakeError.MissingWebSocketKey, result); +} + +test "parseRequest: missing Sec-WebSocket-Version header" { + const request_bytes = + "GET /chat HTTP/1.1\r\n" ++ + "Host: server.example.com\r\n" ++ + "Upgrade: websocket\r\n" ++ + "Connection: Upgrade\r\n" ++ + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" ++ + "\r\n"; + + const result = parseRequest(request_bytes); + try testing.expectError(HandshakeError.MissingWebSocketVersion, result); +} + +test "parseRequest: wrong HTTP version" { + const request_bytes = + "GET /chat HTTP/1.0\r\n" ++ + "Host: server.example.com\r\n" ++ + "Upgrade: websocket\r\n" ++ + "Connection: Upgrade\r\n" ++ + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" ++ + "Sec-WebSocket-Version: 13\r\n" ++ + "\r\n"; + + const result = parseRequest(request_bytes); + try testing.expectError(HandshakeError.InvalidHttpVersion, result); +} + +test "parseRequest: wrong method" { + const request_bytes = + "POST /chat HTTP/1.1\r\n" ++ + "Host: server.example.com\r\n" ++ + "Upgrade: websocket\r\n" ++ + "Connection: Upgrade\r\n" ++ + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" ++ + "Sec-WebSocket-Version: 13\r\n" ++ + "\r\n"; + + const result = parseRequest(request_bytes); + try testing.expectError(HandshakeError.InvalidMethod, result); +} + +test "parseRequest: wrong WebSocket version" { + const request_bytes = + "GET /chat HTTP/1.1\r\n" ++ + "Host: server.example.com\r\n" ++ + "Upgrade: websocket\r\n" ++ + "Connection: Upgrade\r\n" ++ + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" ++ + "Sec-WebSocket-Version: 8\r\n" ++ + "\r\n"; + + const result = parseRequest(request_bytes); + try testing.expectError(HandshakeError.UnsupportedWebSocketVersion, result); +} + +test "parseRequest: optional headers captured" { + const request_bytes = + "GET /chat HTTP/1.1\r\n" ++ + "Host: server.example.com\r\n" ++ + "Upgrade: websocket\r\n" ++ + "Connection: Upgrade\r\n" ++ + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" ++ + "Sec-WebSocket-Version: 13\r\n" ++ + "Origin: http://example.com\r\n" ++ + "Sec-WebSocket-Protocol: chat, superchat\r\n" ++ + "Sec-WebSocket-Extensions: permessage-deflate\r\n" ++ + "\r\n"; + + const req = try parseRequest(request_bytes); + + try testing.expectEqualStrings("http://example.com", req.origin); + try testing.expectEqualStrings("chat, superchat", req.protocols); + try testing.expectEqualStrings("permessage-deflate", req.extensions); +} + +test "parseRequest: case insensitive headers" { + const request_bytes = + "GET /chat HTTP/1.1\r\n" ++ + "host: server.example.com\r\n" ++ + "upgrade: WebSocket\r\n" ++ + "connection: UPGRADE\r\n" ++ + "sec-websocket-key: dGhlIHNhbXBsZSBub25jZQ==\r\n" ++ + "sec-websocket-version: 13\r\n" ++ + "\r\n"; + + const req = try parseRequest(request_bytes); + + try testing.expectEqualStrings("GET", req.method); + try testing.expectEqualStrings("dGhlIHNhbXBsZSBub25jZQ==", req.websocket_key); + try testing.expectEqualStrings("server.example.com", req.host); +} + +test "computeAcceptKey: RFC 6455 example" { + var accept_buf: [28]u8 = undefined; + const accept_key = computeAcceptKey(&accept_buf, "dGhlIHNhbXBsZSBub25jZQ=="); + try testing.expectEqualStrings("s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", accept_key); +} + +test "writeResponse: produces valid 101 response" { + var buf: [256]u8 = undefined; + const response = try 
writeResponse(&buf, "dGhlIHNhbXBsZSBub25jZQ=="); + + // Verify the response starts with 101 status. + try testing.expect(std.mem.startsWith(u8, response, "HTTP/1.1 101 Switching Protocols\r\n")); + + // Verify required headers are present. + try testing.expect(std.mem.indexOf(u8, response, "Upgrade: websocket\r\n") != null); + try testing.expect(std.mem.indexOf(u8, response, "Connection: Upgrade\r\n") != null); + const accept_header = "Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n"; + try testing.expect(std.mem.indexOf(u8, response, accept_header) != null); + + // Verify response ends with \r\n\r\n. + try testing.expect(std.mem.endsWith(u8, response, "\r\n\r\n")); +} + +test "writeResponse: buffer too small" { + var buf: [10]u8 = undefined; + const result = writeResponse(&buf, "dGhlIHNhbXBsZSBub25jZQ=="); + try testing.expectError(WriteError.BufferTooSmall, result); +} + +test "writeResponse: accept key matches computeAcceptKey" { + const client_key = "dGhlIHNhbXBsZSBub25jZQ=="; + + // Compute accept key directly. + var accept_buf: [28]u8 = undefined; + const expected_accept = computeAcceptKey(&accept_buf, client_key); + + // Write full response. + var response_buf: [256]u8 = undefined; + const response = try writeResponse(&response_buf, client_key); + + // Extract accept key from response. + const accept_prefix = "Sec-WebSocket-Accept: "; + const accept_start = (std.mem.indexOf(u8, response, accept_prefix) orelse + return error.TestFailed) + accept_prefix.len; + const accept_end = std.mem.indexOf(u8, response[accept_start..], "\r\n") orelse + return error.TestFailed; + const response_accept = response[accept_start .. accept_start + accept_end]; + + try testing.expectEqualStrings(expected_accept, response_accept); +} + +test "headerContainsToken: exact match" { + try testing.expect(headerContainsToken("websocket", "websocket")); + try testing.expect(headerContainsToken("Upgrade", "upgrade")); +} + +test "headerContainsToken: case insensitive" { + try testing.expect(headerContainsToken("WebSocket", "websocket")); + try testing.expect(headerContainsToken("WEBSOCKET", "websocket")); + try testing.expect(headerContainsToken("UPGRADE", "upgrade")); +} + +test "headerContainsToken: comma-separated list" { + try testing.expect(headerContainsToken("keep-alive, Upgrade", "upgrade")); + try testing.expect(headerContainsToken("Upgrade, keep-alive", "upgrade")); + try testing.expect(headerContainsToken("foo, Upgrade, bar", "upgrade")); +} + +test "headerContainsToken: trims whitespace" { + try testing.expect(headerContainsToken(" websocket ", "websocket")); + try testing.expect(headerContainsToken("keep-alive,\tUpgrade", "upgrade")); + try testing.expect(headerContainsToken("keep-alive , Upgrade ", "upgrade")); +} + +test "headerContainsToken: rejects substrings" { + try testing.expect(!headerContainsToken("notwebsocket", "websocket")); + try testing.expect(!headerContainsToken("websocketx", "websocket")); + try testing.expect(!headerContainsToken("noupgrade", "upgrade")); + try testing.expect(!headerContainsToken("upgrade2", "upgrade")); +} + +test "headerContainsToken: no match" { + try testing.expect(!headerContainsToken("keep-alive", "upgrade")); + try testing.expect(!headerContainsToken("", "upgrade")); + try testing.expect(!headerContainsToken("foo, bar, baz", "upgrade")); +} + +// =========================================================================== +// Tests — Client-side (from client_handshake.zig) +// =========================================================================== 
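+//
+// Client-side handshake flow exercised below: encodeKey() produces the
+// Sec-WebSocket-Key, writeRequest() embeds it in the GET upgrade request,
+// and validateResponse() later checks the server's 101 against that same key.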
+ +test "encodeKey: produces 24-byte base64 string" { + var buf: [24]u8 = undefined; + const raw = [_]u8{ + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, + 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, + }; + const key = encodeKey(&buf, &raw); + try testing.expectEqual(@as(usize, 24), key.len); + + // Verify it's valid base64 by decoding + var decoded: [16]u8 = undefined; + try std.base64.standard.Decoder.decode(&decoded, key); + try testing.expectEqualSlices(u8, &raw, &decoded); +} + +test "encodeKey: different inputs produce different keys" { + var buf1: [24]u8 = undefined; + var buf2: [24]u8 = undefined; + const raw1 = [_]u8{0x01} ** 16; + const raw2 = [_]u8{0x02} ** 16; + const key1 = encodeKey(&buf1, &raw1); + const key2 = encodeKey(&buf2, &raw2); + try testing.expect(!std.mem.eql(u8, key1, key2)); +} + +test "writeRequest: produces valid HTTP request" { + var buf: [512]u8 = undefined; + const addr = std.net.Address.initIp4(.{ 93, 184, 216, 34 }, 443); + const request = try writeRequest(&buf, addr, "/ws", "dGhlIHNhbXBsZSBub25jZQ=="); + + try testing.expect(std.mem.startsWith(u8, request, "GET /ws HTTP/1.1\r\n")); + try testing.expect(std.mem.indexOf(u8, request, "Host: 93.184.216.34:443\r\n") != null); + try testing.expect(std.mem.indexOf(u8, request, "Upgrade: websocket\r\n") != null); + try testing.expect(std.mem.indexOf(u8, request, "Connection: Upgrade\r\n") != null); + const key_header = "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"; + try testing.expect(std.mem.indexOf(u8, request, key_header) != null); + try testing.expect(std.mem.indexOf(u8, request, "Sec-WebSocket-Version: 13\r\n") != null); + try testing.expect(std.mem.endsWith(u8, request, "\r\n\r\n")); +} + +test "writeRequest: buffer too small" { + var buf: [10]u8 = undefined; + const addr = std.net.Address.initIp4(.{ 93, 184, 216, 34 }, 443); + const result = writeRequest(&buf, addr, "/ws", "dGhlIHNhbXBsZSBub25jZQ=="); + try testing.expectError(error.BufferTooSmall, result); +} + +test "validateResponse: valid 101 response" { + const client_key = "dGhlIHNhbXBsZSBub25jZQ=="; + var accept_buf: [28]u8 = undefined; + const accept_key = computeAcceptKey(&accept_buf, client_key); + + var response_buf: [256]u8 = undefined; + var fbs = std.io.fixedBufferStream(&response_buf); + const w = fbs.writer(); + try w.writeAll("HTTP/1.1 101 Switching Protocols\r\n"); + try w.writeAll("Upgrade: websocket\r\n"); + try w.writeAll("Connection: Upgrade\r\n"); + try w.print("Sec-WebSocket-Accept: {s}\r\n", .{accept_key}); + try w.writeAll("\r\n"); + + try validateResponse(response_buf[0..fbs.pos], client_key); +} + +test "validateResponse: wrong status code" { + const response = "HTTP/1.1 400 Bad Request\r\n\r\n"; + const result = validateResponse(response, "somekey"); + try testing.expectError(HandshakeError.InvalidStatusCode, result); +} + +test "validateResponse: missing accept header" { + const response = + "HTTP/1.1 101 Switching Protocols\r\n" ++ + "Upgrade: websocket\r\n" ++ + "Connection: Upgrade\r\n" ++ + "\r\n"; + const result = validateResponse(response, "somekey"); + try testing.expectError(HandshakeError.MissingAcceptHeader, result); +} + +test "validateResponse: missing upgrade header" { + const client_key = "dGhlIHNhbXBsZSBub25jZQ=="; + var accept_buf: [28]u8 = undefined; + const accept_key = computeAcceptKey(&accept_buf, client_key); + + var response_buf: [256]u8 = undefined; + var fbs = std.io.fixedBufferStream(&response_buf); + const w = fbs.writer(); + try w.writeAll("HTTP/1.1 101 Switching Protocols\r\n"); + try 
w.writeAll("Connection: Upgrade\r\n"); + try w.print("Sec-WebSocket-Accept: {s}\r\n", .{accept_key}); + try w.writeAll("\r\n"); + + const result = validateResponse(response_buf[0..fbs.pos], client_key); + try testing.expectError(HandshakeError.MissingUpgradeHeader, result); +} + +test "validateResponse: missing connection header" { + const client_key = "dGhlIHNhbXBsZSBub25jZQ=="; + var accept_buf: [28]u8 = undefined; + const accept_key = computeAcceptKey(&accept_buf, client_key); + + var response_buf: [256]u8 = undefined; + var fbs = std.io.fixedBufferStream(&response_buf); + const w = fbs.writer(); + try w.writeAll("HTTP/1.1 101 Switching Protocols\r\n"); + try w.writeAll("Upgrade: websocket\r\n"); + try w.print("Sec-WebSocket-Accept: {s}\r\n", .{accept_key}); + try w.writeAll("\r\n"); + + const result = validateResponse(response_buf[0..fbs.pos], client_key); + try testing.expectError(HandshakeError.MissingConnectionHeader, result); +} + +test "validateResponse: wrong accept key" { + const response = + "HTTP/1.1 101 Switching Protocols\r\n" ++ + "Upgrade: websocket\r\n" ++ + "Connection: Upgrade\r\n" ++ + "Sec-WebSocket-Accept: wrongkey=\r\n" ++ + "\r\n"; + const result = validateResponse(response, "dGhlIHNhbXBsZSBub25jZQ=="); + try testing.expectError(HandshakeError.InvalidAcceptKey, result); +} + +test "validateResponse: malformed status line" { + const response = "INVALID\r\n\r\n"; + const result = validateResponse(response, "somekey"); + try testing.expectError(HandshakeError.MalformedResponse, result); +} + +test "validateResponse: status-line-only response with no headers" { + // Regression: a 101 response with no header lines at all. The status line + // is terminated by the \r\n that starts the \r\n\r\n boundary. + const response = "HTTP/1.1 101 Switching Protocols\r\n\r\n"; + const result = validateResponse(response, "somekey"); + try testing.expectError(HandshakeError.MissingUpgradeHeader, result); +} + +test "validateResponse: non-101 status-line-only response" { + const response = "HTTP/1.1 403 Forbidden\r\n\r\n"; + const result = validateResponse(response, "somekey"); + try testing.expectError(HandshakeError.InvalidStatusCode, result); +} diff --git a/src/rpc/webzockets/src/mask.zig b/src/rpc/webzockets/src/mask.zig new file mode 100644 index 0000000000..7829556bbd --- /dev/null +++ b/src/rpc/webzockets/src/mask.zig @@ -0,0 +1,136 @@ +const std = @import("std"); + +/// Applies XOR masking/unmasking to payload in-place using the 4-byte mask key. +/// The same function works for both masking and unmasking (XOR is its own inverse). +pub fn mask(mask_key: [4]u8, payload: []u8) void { + var data = payload; + + const vector_size = comptime std.simd.suggestVectorLength(u8) orelse @sizeOf(usize); + if (data.len >= vector_size) { + const mask_vector: @Vector(vector_size, u8) = std.simd.repeat(vector_size, mask_key); + while (data.len >= vector_size) { + const slice = data[0..vector_size]; + const masked: @Vector(vector_size, u8) = slice.*; + slice.* = masked ^ mask_vector; + data = data[vector_size..]; + } + } + + scalarMask(mask_key, data); +} + +/// Scalar byte-by-byte XOR mask fallback for remaining bytes or unsupported backends. +fn scalarMask(mask_key: [4]u8, payload: []u8) void { + @setRuntimeSafety(false); + for (payload, 0..) 
|b, i| { + payload[i] = b ^ mask_key[i & 3]; + } +} + +const testing = std.testing; + +test "mask: empty payload is no-op" { + var buf = [_]u8{}; + mask(.{ 0xAA, 0xBB, 0xCC, 0xDD }, &buf); +} + +test "mask: single byte" { + var buf = [_]u8{0x42}; + mask(.{ 0xFF, 0x00, 0x00, 0x00 }, &buf); + try testing.expectEqual(@as(u8, 0x42 ^ 0xFF), buf[0]); +} + +test "mask: round-trip produces original" { + const original = "Hello, WebSocket!"; + var buf: [original.len]u8 = undefined; + @memcpy(&buf, original); + + const key = [_]u8{ 0x37, 0xFA, 0x21, 0x3D }; + + // Mask + mask(key, &buf); + // Should differ from original + try testing.expect(!std.mem.eql(u8, &buf, original)); + // Unmask (same operation) + mask(key, &buf); + try testing.expectEqualSlices(u8, original, &buf); +} + +test "mask: exact 4-byte alignment" { + var buf = [_]u8{ 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08 }; + const key = [_]u8{ 0xFF, 0xFF, 0xFF, 0xFF }; + mask(key, &buf); + const expected = [_]u8{ 0xFE, 0xFD, 0xFC, 0xFB, 0xFA, 0xF9, 0xF8, 0xF7 }; + try testing.expectEqualSlices(u8, &expected, &buf); +} + +test "mask: SIMD boundary sizes" { + const key = [_]u8{ 0x12, 0x34, 0x56, 0x78 }; + // Test various sizes including around typical SIMD widths (16, 32, 64) + const sizes = [_]usize{ + 1, 2, 3, 4, 5, 7, 8, 15, 16, 17, 31, + 32, 33, 63, 64, 65, 127, 128, 129, 255, 256, 257, + }; + + for (sizes) |size| { + const buf = try testing.allocator.alloc(u8, size); + defer testing.allocator.free(buf); + const expected = try testing.allocator.alloc(u8, size); + defer testing.allocator.free(expected); + + // Fill with known pattern + for (buf, 0..) |*b, i| { + b.* = @truncate(i); + } + @memcpy(expected, buf); + + // Mask then unmask should round-trip + mask(key, buf); + mask(key, buf); + try testing.expectEqualSlices(u8, expected, buf); + } +} + +test "mask: all-zero mask is identity" { + const original = [_]u8{ 0xDE, 0xAD, 0xBE, 0xEF, 0xCA, 0xFE }; + var buf = original; + mask(.{ 0x00, 0x00, 0x00, 0x00 }, &buf); + try testing.expectEqualSlices(u8, &original, &buf); +} + +test "mask: all-ones mask inverts all bits" { + var buf = [_]u8{ 0x00, 0xFF, 0xAA, 0x55 }; + mask(.{ 0xFF, 0xFF, 0xFF, 0xFF }, &buf); + try testing.expectEqualSlices(u8, &[_]u8{ 0xFF, 0x00, 0x55, 0xAA }, &buf); +} + +test "mask: RFC 6455 example verification" { + // RFC 6455 Section 5.7 examples use masking; verify XOR properties + // "Hello" = { 0x48, 0x65, 0x6c, 0x6c, 0x6f } + // With mask key { 0x37, 0xfa, 0x21, 0x3d }: + // masked = { 0x48^0x37, 0x65^0xfa, 0x6c^0x21, 0x6c^0x3d, 0x6f^0x37 } + // = { 0x7f, 0x9f, 0x4d, 0x51, 0x58 } + var buf = [_]u8{ 0x48, 0x65, 0x6c, 0x6c, 0x6f }; + mask(.{ 0x37, 0xfa, 0x21, 0x3d }, &buf); + try testing.expectEqualSlices(u8, &[_]u8{ 0x7f, 0x9f, 0x4d, 0x51, 0x58 }, &buf); +} + +test "mask: fuzz round-trip" { + const Context = struct { + fn testOne(_: @This(), input: []const u8) anyerror!void { + if (input.len < 4) return; + const key: [4]u8 = input[0..4].*; + const buf = try testing.allocator.alloc(u8, input.len - 4); + defer testing.allocator.free(buf); + @memcpy(buf, input[4..]); + const original = try testing.allocator.alloc(u8, buf.len); + defer testing.allocator.free(original); + @memcpy(original, buf); + + mask(key, buf); + mask(key, buf); + try testing.expectEqualSlices(u8, original, buf); + } + }; + try testing.fuzz(Context{}, Context.testOne, .{}); +} diff --git a/src/rpc/webzockets/src/reader.zig b/src/rpc/webzockets/src/reader.zig new file mode 100644 index 0000000000..a1a8519db0 --- /dev/null +++ 
b/src/rpc/webzockets/src/reader.zig @@ -0,0 +1,1171 @@ +const std = @import("std"); +const Allocator = std.mem.Allocator; +const frame = @import("frame.zig"); +const types = @import("types.zig"); + +/// WebSocket frame reader with automatic buffer tier escalation. +/// +/// Manages read buffer state with automatic tier escalation: +/// - Starts with an embedded buffer for small messages +/// - Upgrades to dynamic buffer when needed +/// - Retains larger buffer for subsequent messages (performance optimization) +/// - Restores to embedded only on explicit reset() +/// +/// Buffer layout: +/// +/// ``` +/// buf +/// ├──────────────────────────────────────────────────────┤ +/// │ consumed │ unprocessed data │ free space │ +/// │ (already │ (pending parse/ │ (available │ +/// │ parsed) │ delivery) │ for reads) │ +/// ├─────────────┼──────────────────────┼─────────────────┤ +/// 0 start pos buf.len +/// │ │ │ +/// ├── data() ────────────┤ │ +/// │ (start..pos) │ │ +/// ├─ readSlice() ───┤ +/// │ (pos..buf.len) │ +/// ``` +/// +/// - `buf[0..start]` — already consumed frames, reclaimable by `compact()` +/// - `buf[start..pos]` — received data not yet fully parsed/delivered (`data()`) +/// - `buf[pos..buf.len]` — writable region for the next TCP read (`readSlice()`) +/// - `compact()` shifts unprocessed data to the front: start→0, pos→dataLen +/// - `consume(n)` advances start by n after a frame is processed +/// +/// Parameterized by `Role` to select the correct frame validation: +/// - `.server`: validates that frames are masked (client-to-server) +/// - `.client`: validates that frames are unmasked (server-to-client) +pub fn Reader(comptime role: types.Role) type { + return struct { + /// Current buffer for reading frames, starts as embedded and may upgrade to dynamic. + buf: []u8, + /// Saved reference to embedded buffer. + embedded_buf: []u8, + /// Allocator for dynamic buffer allocation. + allocator: Allocator, + /// Position within buf that we've read into (end of data). + pos: usize, + /// Position in buf where the current message starts. + start: usize, + max_message_size: usize, + /// Opcode of the first fragment in an in-progress fragmented message (null if none). + fragment_opcode: ?types.Opcode, + /// Accumulator for fragmented message payloads. + fragment_buf: std.ArrayListUnmanaged(u8), + + const ReaderSelf = @This(); + + /// Initialize reader with embedded buffer. + pub fn init( + embedded_buf: []u8, + allocator: Allocator, + max_message_size: usize, + ) ReaderSelf { + return .{ + .buf = embedded_buf, + .embedded_buf = embedded_buf, + .allocator = allocator, + .pos = 0, + .start = 0, + .max_message_size = max_message_size, + .fragment_opcode = null, + .fragment_buf = .{}, + }; + } + + /// Clean up any non-embedded buffer and fragment state. + pub fn deinit(self: *ReaderSelf) void { + self.cleanupFragments(); + self.releaseCurrentBuffer(); + } + + /// Reset state for connection reuse. + pub fn reset(self: *ReaderSelf) void { + self.cleanupFragments(); + self.releaseCurrentBuffer(); + self.buf = self.embedded_buf; + self.pos = 0; + self.start = 0; + } + + fn cleanupFragments(self: *ReaderSelf) void { + if (self.fragment_buf.capacity > 0) { + self.fragment_buf.deinit(self.allocator); + self.fragment_buf = .{}; + } + self.fragment_opcode = null; + } + + /// Writable slice for xev TCP read, returns a slice with len > 0. 
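+ ///
+ /// Typical read-loop usage (sketch; `n` stands for the byte count a
+ /// socket read placed into the returned slice):
+ ///
+ ///   const dst = try reader.readSlice();
+ ///   // ... read up to dst.len bytes from the socket into dst ...
+ ///   reader.advancePos(n);
+ ///   while (try reader.nextMessage()) |msg| { /* deliver msg */ }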
+ pub fn readSlice(self: *ReaderSelf) error{OutOfMemory}![]u8 { + const tail_free = self.availableSpace(); + + if (self.start > 0) { + // Hybrid reclaim heuristic: compact early when tail free space is + // small (to avoid a tiny read), but only if shifting is + // amortized-cheap (`start >= live_len`). + const compact_threshold = @max(1, self.buf.len / 16); + if (tail_free > 0 and + tail_free <= compact_threshold and + self.start >= self.dataLen()) + { + self.compact(); + } else if (tail_free == 0) { + // Buffer full but has consumed bytes at front — reclaim them. + self.compact(); + } + } + + // Ensure a writable tail exists (len > 0). + // This is an unusual case because nextMessage will be consuming data and advancing + // start/pos and calling requireCapacity as needed. But if the caller just calls + // readSlice repeatedly without calling nextMessage then we must grow the buffer + // to avoid returning an empty slice. + if (self.availableSpace() == 0) { + self.requireCapacity(self.buf.len * 2) catch return error.OutOfMemory; + } + return self.buf[self.pos..]; + } + + /// Advance pos after successful read. + pub fn advancePos(self: *ReaderSelf, n: usize) void { + self.pos += n; + } + + /// Returns next complete message, or null if more data needed. + /// Call in a loop until null. Control frames are returned immediately + /// even mid-fragment-sequence. Fragments are assembled internally. + /// + /// NOTE: we just re-parse and validate the header each time this is called + /// rather than build a state machine since it is very cheap to do and + /// keeps things simple. For messages large enough to not fit in a single + /// TCP read the repeated work is inconsequential. + pub fn nextMessage( + self: *ReaderSelf, + ) error{ ProtocolError, MessageTooBig, OutOfMemory }!?types.Message { + while (true) { + const current_data = self.data(); + if (current_data.len == 0) return null; + + const header = frame.parseHeader(current_data) catch |err| { + return switch (err) { + error.InsufficientData => null, + else => error.ProtocolError, + }; + }; + + // Validate per RFC 6455 + header.validate() catch return error.ProtocolError; + switch (role) { + .server => header.validateServerBound() catch return error.ProtocolError, + .client => header.validateClientBound() catch return error.ProtocolError, + } + + if (header.payload_len > self.max_message_size) { + return error.MessageTooBig; + } + + const total_frame_size = header.totalLen(); + + // Have we received the full frame yet? + if (current_data.len < total_frame_size) { + // Ensure the buffer can fit the full frame for when the rest arrives. 
+ self.requireCapacity(@intCast(total_frame_size)) catch return error.OutOfMemory; + return null; // Need more data from the wire + } + + const payload_len: usize = @intCast(header.payload_len); + const payload = current_data[header.header_len..][0..payload_len]; + header.unmaskPayload(payload); + + // Consume the frame bytes + self.consume(@intCast(total_frame_size)); + + // Dispatch by opcode + switch (header.opcode) { + .text, .binary => { + if (header.fin) { + // Complete single-frame message + if (self.fragment_opcode != null) { + return error.ProtocolError; // nested fragment + } + const msg_type: types.Message.Type = + if (header.opcode == .text) .text else .binary; + + return .{ .type = msg_type, .data = payload }; + } else { + // First frame of a fragmented message + if (self.fragment_opcode != null) { + return error.ProtocolError; // nested fragment + } + self.fragment_opcode = header.opcode; + self.fragment_buf.clearRetainingCapacity(); + self.fragment_buf.appendSlice( + self.allocator, + payload, + ) catch return error.OutOfMemory; + continue; + } + }, + .continuation => { + const frag_op = self.fragment_opcode orelse { + return error.ProtocolError; // unexpected continuation + }; + const new_len = self.fragment_buf.items.len + payload.len; + if (new_len > self.max_message_size) { + return error.MessageTooBig; + } + self.fragment_buf.appendSlice( + self.allocator, + payload, + ) catch return error.OutOfMemory; + + if (header.fin) { + // Fragmented message complete + const msg_type: types.Message.Type = + if (frag_op == .text) .text else .binary; + + const result: types.Message = .{ + .type = msg_type, + .data = self.fragment_buf.items, + }; + // Clear fragment state — data remains valid until next nextMessage()/read cycle + self.fragment_opcode = null; + return result; + } + continue; + }, + .ping => return .{ .type = .ping, .data = payload }, + .pong => return .{ .type = .pong, .data = payload }, + .close => return .{ .type = .close, .data = payload }, + } + } + } + + /// Get the current unprocessed data (from start to pos). + fn data(self: *const ReaderSelf) []u8 { + return self.buf[self.start..self.pos]; + } + + /// Get length of current unprocessed data. + fn dataLen(self: *const ReaderSelf) usize { + return self.pos - self.start; + } + + /// Number of writable bytes at the tail of the buffer. + pub fn availableSpace(self: *const ReaderSelf) usize { + return self.buf.len - self.pos; + } + + /// Reclaim consumed bytes at the front of the buffer by compacting, + /// but only when the buffer is completely full. Does not allocate. + pub fn compactIfFull(self: *ReaderSelf) void { + if (self.pos >= self.buf.len and self.start > 0) { + self.compact(); + } + } + + /// Ensure buffer has total capacity for at least `required` bytes. + /// Compacts if that alone satisfies the requirement, otherwise upgrades + /// to dynamic buffer. + fn requireCapacity(self: *ReaderSelf, required: usize) !void { + // If current capacity is sufficient, compact to make free space + // contiguous at the end. No heuristic here since we've already + // determined the capacity required. 
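+ // Otherwise we allocate and copy; e.g. `required` = 300 with the
+ // 64-byte embedded buffer yields a 512-byte dynamic buffer (see the
+ // power-of-2 rounding test below).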
+ if (required <= self.buf.len) { + self.compact(); + return; + } + + // Compaction won't be enough — need to upgrade buffer + const current_data_len = self.dataLen(); + + // Dynamic allocation - round up to next power of 2 + // to avoid repeated reallocations for incrementally growing messages + const alloc_size = std.math.ceilPowerOfTwo(usize, required) catch required; + const new_buf = try self.allocator.alloc(u8, alloc_size); + if (current_data_len > 0) { + @memcpy(new_buf[0..current_data_len], self.data()); + } + self.releaseCurrentBuffer(); + self.buf = new_buf; + self.pos = current_data_len; + self.start = 0; + } + + /// Shift unprocessed data to the front of the buffer, reclaiming consumed space. + fn compact(self: *ReaderSelf) void { + if (self.start == 0) return; + const current_data_len = self.dataLen(); + if (current_data_len > 0) { + std.mem.copyForwards(u8, self.buf[0..current_data_len], self.data()); + } + self.pos = current_data_len; + self.start = 0; + } + + /// After processing a message, advance the start position. + fn consume(self: *ReaderSelf, bytes: usize) void { + self.start += bytes; + + // If we've consumed all data, reset positions but keep buffer + if (self.start == self.pos) { + self.pos = 0; + self.start = 0; + } + } + + /// Free the current buffer if it's not embedded. + fn releaseCurrentBuffer(self: *ReaderSelf) void { + if (self.buf.ptr != self.embedded_buf.ptr) { + self.allocator.free(self.buf); + } + } + }; +} + +const testing = std.testing; + +const ServerReader = Reader(.server); + +test "Reader: init with embedded buffer" { + var embedded_buf: [64]u8 = undefined; + var reader = ServerReader.init(&embedded_buf, testing.allocator, 1024); + defer reader.deinit(); + + try testing.expectEqual(reader.embedded_buf.ptr, reader.buf.ptr); + try testing.expectEqual(@as(usize, 0), reader.pos); + try testing.expectEqual(@as(usize, 64), (try reader.readSlice()).len); +} + +test "Reader: availableSpace after init" { + var embedded_buf: [64]u8 = undefined; + var reader = ServerReader.init(&embedded_buf, testing.allocator, 1024); + defer reader.deinit(); + + try testing.expectEqual(@as(usize, embedded_buf.len), reader.availableSpace()); +} + +test "Reader: availableSpace after advancePos" { + var embedded_buf: [64]u8 = undefined; + var reader = ServerReader.init(&embedded_buf, testing.allocator, 1024); + defer reader.deinit(); + + reader.advancePos(10); + try testing.expectEqual(@as(usize, embedded_buf.len - 10), reader.availableSpace()); +} + +test "Reader: availableSpace when full" { + var embedded_buf: [64]u8 = undefined; + var reader = ServerReader.init(&embedded_buf, testing.allocator, 1024); + defer reader.deinit(); + + reader.pos = embedded_buf.len; + try testing.expectEqual(@as(usize, 0), reader.availableSpace()); +} + +test "Reader: compactIfFull reclaims consumed bytes" { + var embedded_buf: [16]u8 = undefined; + var reader = ServerReader.init(&embedded_buf, testing.allocator, 1024); + defer reader.deinit(); + + // Simulate full buffer with 6 bytes already consumed. 
+ reader.start = 6; + reader.pos = embedded_buf.len; + + reader.compactIfFull(); + try testing.expectEqual(@as(usize, 6), reader.availableSpace()); + try testing.expectEqual(@as(usize, 0), reader.start); + try testing.expectEqual(@as(usize, 10), reader.pos); +} + +test "Reader: compactIfFull does nothing when not full" { + var embedded_buf: [16]u8 = undefined; + var reader = ServerReader.init(&embedded_buf, testing.allocator, 1024); + defer reader.deinit(); + + // Buffer has consumed bytes but is not full — should not compact. + reader.start = 4; + reader.pos = 10; + + reader.compactIfFull(); + try testing.expectEqual(@as(usize, 4), reader.start); + try testing.expectEqual(@as(usize, 10), reader.pos); +} + +test "Reader: requireCapacity upgrades to dynamic" { + var embedded_buf: [64]u8 = undefined; + @memcpy(embedded_buf[0..5], "hello"); + + var reader = ServerReader.init(&embedded_buf, testing.allocator, 1024); + defer reader.deinit(); + reader.pos = 5; + + // Require more than embedded size — escalates to dynamic + try reader.requireCapacity(128); + + try testing.expect(reader.buf.ptr != reader.embedded_buf.ptr); + try testing.expectEqual(@as(usize, 128), reader.buf.len); + try testing.expectEqual(@as(usize, 5), reader.pos); + try testing.expectEqual(@as(usize, 0), reader.start); + try testing.expectEqualStrings("hello", reader.data()); +} + +test "Reader: consume keeps larger buffer" { + var embedded_buf: [64]u8 = undefined; + var reader = ServerReader.init(&embedded_buf, testing.allocator, 1024); + defer reader.deinit(); + + // Upgrade to dynamic + try reader.requireCapacity(128); + try testing.expect(reader.buf.ptr != reader.embedded_buf.ptr); + + // Simulate reading 100 bytes + reader.pos = 100; + + // Consume all data + reader.consume(100); + + // Should stay at dynamic (not restore to embedded) + try testing.expect(reader.buf.ptr != reader.embedded_buf.ptr); + try testing.expectEqual(@as(usize, 0), reader.pos); + try testing.expectEqual(@as(usize, 0), reader.start); +} + +test "Reader: requireCapacity uses power-of-2 rounding" { + var embedded_buf: [64]u8 = undefined; + var reader = ServerReader.init(&embedded_buf, testing.allocator, 1024); + defer reader.deinit(); + + // Require non-power-of-2 size — should round up + try reader.requireCapacity(300); + + try testing.expect(reader.buf.ptr != reader.embedded_buf.ptr); + try testing.expectEqual(@as(usize, 512), reader.buf.len); +} + +// --- Test helpers for nextMessage --- + +const mask_mod = @import("mask.zig"); + +/// Write the frame header into `out`, returning the header length written. +fn writeFrameHeader( + out: []u8, + opcode: types.Opcode, + fin: bool, + payload_len: usize, + masked: bool, +) usize { + var byte0: u8 = @intFromEnum(opcode); + if (fin) byte0 |= 0x80; + out[0] = byte0; + + const mask_bit: u8 = if (masked) 0x80 else 0; + var header_len: usize = 2; + if (payload_len <= 125) { + out[1] = mask_bit | @as(u8, @truncate(payload_len)); + } else if (payload_len <= 65535) { + out[1] = mask_bit | 126; + std.mem.writeInt(u16, out[2..4], @truncate(payload_len), .big); + header_len = 4; + } else { + out[1] = mask_bit | 127; + std.mem.writeInt(u64, out[2..10], payload_len, .big); + header_len = 10; + } + return header_len; +} + +/// Build a masked WebSocket frame into `out`. Returns the slice of `out` that was written. 
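+/// Layout written: [frame header][4-byte mask key][payload, XOR-masked in place].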
+fn buildMaskedFrame( + out: []u8, + opcode: types.Opcode, + fin: bool, + payload: []const u8, + mask_key: [4]u8, +) []u8 { + var header_len = writeFrameHeader( + out, + opcode, + fin, + payload.len, + true, + ); + + @memcpy(out[header_len..][0..4], &mask_key); + header_len += 4; + + @memcpy(out[header_len..][0..payload.len], payload); + mask_mod.mask(mask_key, out[header_len..][0..payload.len]); + + return out[0 .. header_len + payload.len]; +} + +/// Create a server reader with data pre-loaded. +fn testReader(embedded_buf: []u8, max_msg: usize) ServerReader { + return ServerReader.init(embedded_buf, testing.allocator, max_msg); +} + +// --- nextMessage tests --- + +test "Reader.nextMessage: single text frame" { + var embedded_buf: [128]u8 = undefined; + var reader = testReader(&embedded_buf, 1024); + defer reader.deinit(); + + // Build a masked text frame with "Hello" + var frame_buf: [64]u8 = undefined; + const f = buildMaskedFrame(&frame_buf, .text, true, "Hello", .{ 0x37, 0xFA, 0x21, 0x3D }); + @memcpy(embedded_buf[0..f.len], f); + reader.pos = f.len; + + const msg = (try reader.nextMessage()).?; + try testing.expectEqual(types.Message.Type.text, msg.type); + try testing.expectEqualStrings("Hello", msg.data); + + // No more messages + try testing.expect((try reader.nextMessage()) == null); +} + +test "Reader.nextMessage: single binary frame" { + var embedded_buf: [128]u8 = undefined; + var reader = testReader(&embedded_buf, 1024); + defer reader.deinit(); + + const payload = [_]u8{ 0xDE, 0xAD, 0xBE, 0xEF }; + var frame_buf: [64]u8 = undefined; + const f = buildMaskedFrame(&frame_buf, .binary, true, &payload, .{ 0x11, 0x22, 0x33, 0x44 }); + @memcpy(embedded_buf[0..f.len], f); + reader.pos = f.len; + + const msg = (try reader.nextMessage()).?; + try testing.expectEqual(types.Message.Type.binary, msg.type); + try testing.expectEqualSlices(u8, &payload, msg.data); +} + +test "Reader.nextMessage: partial frame returns null" { + var embedded_buf: [128]u8 = undefined; + var reader = testReader(&embedded_buf, 1024); + defer reader.deinit(); + + // Build a frame but only give partial data + var frame_buf: [64]u8 = undefined; + const f = buildMaskedFrame(&frame_buf, .text, true, "Hello", .{ 0x37, 0xFA, 0x21, 0x3D }); + // Only copy part of the frame (header but not full payload) + const partial_len = 4; // just the header bytes + @memcpy(embedded_buf[0..partial_len], f[0..partial_len]); + reader.pos = partial_len; + + try testing.expect((try reader.nextMessage()) == null); +} + +test "Reader.nextMessage: control frame (ping)" { + var embedded_buf: [128]u8 = undefined; + var reader = testReader(&embedded_buf, 1024); + defer reader.deinit(); + + var frame_buf: [64]u8 = undefined; + const f = buildMaskedFrame(&frame_buf, .ping, true, "ping", .{ 0xAA, 0xBB, 0xCC, 0xDD }); + @memcpy(embedded_buf[0..f.len], f); + reader.pos = f.len; + + const msg = (try reader.nextMessage()).?; + try testing.expectEqual(types.Message.Type.ping, msg.type); + try testing.expectEqualStrings("ping", msg.data); +} + +test "Reader.nextMessage: control frame (close)" { + var embedded_buf: [128]u8 = undefined; + var reader = testReader(&embedded_buf, 1024); + defer reader.deinit(); + + // Close with status code 1000 + const close_payload = [_]u8{ 0x03, 0xE8 }; + var frame_buf: [64]u8 = undefined; + const f = buildMaskedFrame( + &frame_buf, + .close, + true, + &close_payload, + .{ 0x11, 0x22, 0x33, 0x44 }, + ); + @memcpy(embedded_buf[0..f.len], f); + reader.pos = f.len; + + const msg = (try reader.nextMessage()).?; + try 
testing.expectEqual(types.Message.Type.close, msg.type); + try testing.expectEqualSlices(u8, &close_payload, msg.data); +} + +test "Reader.nextMessage: fragment reassembly (text)" { + var embedded_buf: [256]u8 = undefined; + var reader = testReader(&embedded_buf, 1024); + defer reader.deinit(); + + const mask_key = [_]u8{ 0x12, 0x34, 0x56, 0x78 }; + var frame_buf: [64]u8 = undefined; + var offset: usize = 0; + + // Fragment 1: text FIN=0 "Hel" + const f1 = buildMaskedFrame(&frame_buf, .text, false, "Hel", mask_key); + @memcpy(embedded_buf[offset..][0..f1.len], f1); + offset += f1.len; + + // Fragment 2: continuation FIN=1 "lo" + const f2 = buildMaskedFrame(&frame_buf, .continuation, true, "lo", mask_key); + @memcpy(embedded_buf[offset..][0..f2.len], f2); + offset += f2.len; + + reader.pos = offset; + + const msg = (try reader.nextMessage()).?; + try testing.expectEqual(types.Message.Type.text, msg.type); + try testing.expectEqualStrings("Hello", msg.data); +} + +test "Reader.nextMessage: interleaved control during fragmentation" { + var embedded_buf: [256]u8 = undefined; + var reader = testReader(&embedded_buf, 1024); + defer reader.deinit(); + + const mask_key = [_]u8{ 0x12, 0x34, 0x56, 0x78 }; + var frame_buf: [64]u8 = undefined; + var offset: usize = 0; + + // Fragment 1: text FIN=0 "Hel" + const f1 = buildMaskedFrame(&frame_buf, .text, false, "Hel", mask_key); + @memcpy(embedded_buf[offset..][0..f1.len], f1); + offset += f1.len; + + // Interleaved ping + const f2 = buildMaskedFrame(&frame_buf, .ping, true, "p", mask_key); + @memcpy(embedded_buf[offset..][0..f2.len], f2); + offset += f2.len; + + // Fragment 2: continuation FIN=1 "lo" + const f3 = buildMaskedFrame(&frame_buf, .continuation, true, "lo", mask_key); + @memcpy(embedded_buf[offset..][0..f3.len], f3); + offset += f3.len; + + reader.pos = offset; + + // First: should get the ping control frame + const msg1 = (try reader.nextMessage()).?; + try testing.expectEqual(types.Message.Type.ping, msg1.type); + try testing.expectEqualStrings("p", msg1.data); + + // Second: should get the reassembled message + const msg2 = (try reader.nextMessage()).?; + try testing.expectEqual(types.Message.Type.text, msg2.type); + try testing.expectEqualStrings("Hello", msg2.data); + + // No more + try testing.expect((try reader.nextMessage()) == null); +} + +test "Reader.nextMessage: readSlice grows when full" { + + // Tiny inline buffer that will fill up + var embedded_buf: [16]u8 = undefined; + var reader = testReader(&embedded_buf, 1024); + defer reader.deinit(); + + // Fill the buffer completely + reader.pos = 16; + + // readSlice should compact/grow instead of failing + const slice = try reader.readSlice(); + try testing.expect(slice.len > 0); +} + +test "Reader.nextMessage: readSlice compacts early when tail is tiny" { + var embedded_buf: [64]u8 = undefined; + var reader = testReader(&embedded_buf, 1024); + defer reader.deinit(); + + // Remaining tail = 2 bytes, reclaimable consumed prefix = 40 bytes. + // Heuristic should compact before we hit a full buffer. 
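+ // Here threshold = max(1, 64/16) = 4, tail free = 64 - 62 = 2 <= 4,
+ // and start (40) >= live data (62 - 40 = 22), so compaction fires.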
+ reader.start = 40; + reader.pos = 62; + + const slice = try reader.readSlice(); + try testing.expectEqual(@as(usize, 0), reader.start); + try testing.expectEqual(@as(usize, 22), reader.pos); + try testing.expectEqual(@as(usize, 42), slice.len); +} + +test "Reader.nextMessage: readSlice does not compact early when tail is ample" { + var embedded_buf: [64]u8 = undefined; + var reader = testReader(&embedded_buf, 1024); + defer reader.deinit(); + + // start >= live_len holds (40 >= 10), but tail (14 bytes) is not tiny, + // so early compaction should not run. + reader.start = 40; + reader.pos = 50; + + const slice = try reader.readSlice(); + try testing.expectEqual(@as(usize, 40), reader.start); + try testing.expectEqual(@as(usize, 50), reader.pos); + try testing.expectEqual(@as(usize, 14), slice.len); +} + +test "Reader.nextMessage: readSlice compacts before growing" { + var embedded_buf: [16]u8 = undefined; + var reader = testReader(&embedded_buf, 1024); + defer reader.deinit(); + + // Consume half the buffer so start > 0, then fill to end + reader.start = 8; + reader.pos = 16; + + // readSlice should compact first (moving 8 bytes of data to front), + // giving us 8 free bytes without needing to grow + const slice = try reader.readSlice(); + try testing.expect(slice.len == 8); + try testing.expectEqual(@as(usize, 0), reader.start); + try testing.expectEqual(@as(usize, 8), reader.pos); +} + +test "Reader.nextMessage: protocol error on nested fragments" { + var embedded_buf: [256]u8 = undefined; + var reader = testReader(&embedded_buf, 1024); + defer reader.deinit(); + + const mask_key = [_]u8{ 0x12, 0x34, 0x56, 0x78 }; + var frame_buf: [64]u8 = undefined; + var offset: usize = 0; + + // Fragment 1: text FIN=0 "A" + const f1 = buildMaskedFrame(&frame_buf, .text, false, "A", mask_key); + @memcpy(embedded_buf[offset..][0..f1.len], f1); + offset += f1.len; + + // Another text FIN=0 "B" — protocol violation (nested fragment) + const f2 = buildMaskedFrame(&frame_buf, .text, false, "B", mask_key); + @memcpy(embedded_buf[offset..][0..f2.len], f2); + offset += f2.len; + + reader.pos = offset; + + // First call consumes fragment start, second should error + try testing.expectError(error.ProtocolError, reader.nextMessage()); +} + +test "Reader.nextMessage: protocol error on unexpected continuation" { + var embedded_buf: [128]u8 = undefined; + var reader = testReader(&embedded_buf, 1024); + defer reader.deinit(); + + const mask_key = [_]u8{ 0x12, 0x34, 0x56, 0x78 }; + var frame_buf: [64]u8 = undefined; + + // Continuation frame without a preceding fragment start + const f = buildMaskedFrame(&frame_buf, .continuation, true, "data", mask_key); + @memcpy(embedded_buf[0..f.len], f); + reader.pos = f.len; + + try testing.expectError(error.ProtocolError, reader.nextMessage()); +} + +test "Reader.nextMessage: protocol error on unmasked frame" { + var embedded_buf: [128]u8 = undefined; + var reader = testReader(&embedded_buf, 1024); + defer reader.deinit(); + + // Build an unmasked text frame manually + embedded_buf[0] = 0x81; // FIN=1, text + embedded_buf[1] = 0x05; // MASK=0, len=5 + @memcpy(embedded_buf[2..7], "Hello"); + reader.pos = 7; + + try testing.expectError(error.ProtocolError, reader.nextMessage()); +} + +test "Reader.nextMessage: message too large" { + var embedded_buf: [128]u8 = undefined; + // Very small max message size + var reader = testReader(&embedded_buf, 10); + defer reader.deinit(); + + const mask_key = [_]u8{ 0x12, 0x34, 0x56, 0x78 }; + var frame_buf: [64]u8 = undefined; + + // Frame 
with payload larger than max_message_size + const f = buildMaskedFrame(&frame_buf, .text, true, "This is too long!", mask_key); + @memcpy(embedded_buf[0..f.len], f); + reader.pos = f.len; + + try testing.expectError(error.MessageTooBig, reader.nextMessage()); +} + +test "Reader.nextMessage: multiple frames in buffer" { + var embedded_buf: [256]u8 = undefined; + var reader = testReader(&embedded_buf, 1024); + defer reader.deinit(); + + const mask_key = [_]u8{ 0xAA, 0xBB, 0xCC, 0xDD }; + var frame_buf: [64]u8 = undefined; + var offset: usize = 0; + + // Frame 1: text "Hi" + const f1 = buildMaskedFrame(&frame_buf, .text, true, "Hi", mask_key); + @memcpy(embedded_buf[offset..][0..f1.len], f1); + offset += f1.len; + + // Frame 2: text "Bye" + const f2 = buildMaskedFrame(&frame_buf, .text, true, "Bye", mask_key); + @memcpy(embedded_buf[offset..][0..f2.len], f2); + offset += f2.len; + + reader.pos = offset; + + // First message + const msg1 = (try reader.nextMessage()).?; + try testing.expectEqualStrings("Hi", msg1.data); + + // Second message + const msg2 = (try reader.nextMessage()).?; + try testing.expectEqualStrings("Bye", msg2.data); + + // No more + try testing.expect((try reader.nextMessage()) == null); +} + +test "Reader.nextMessage: pong control frame" { + var embedded_buf: [128]u8 = undefined; + var reader = testReader(&embedded_buf, 1024); + defer reader.deinit(); + + var frame_buf: [64]u8 = undefined; + const f = buildMaskedFrame(&frame_buf, .pong, true, "pong", .{ 0x11, 0x22, 0x33, 0x44 }); + @memcpy(embedded_buf[0..f.len], f); + reader.pos = f.len; + + const msg = (try reader.nextMessage()).?; + try testing.expectEqual(types.Message.Type.pong, msg.type); + try testing.expectEqualStrings("pong", msg.data); +} + +test "Reader.nextMessage: empty buffer returns null" { + var embedded_buf: [128]u8 = undefined; + var reader = testReader(&embedded_buf, 1024); + defer reader.deinit(); + + try testing.expect((try reader.nextMessage()) == null); +} + +test "Reader.nextMessage: fragment reassembly (binary)" { + var embedded_buf: [256]u8 = undefined; + var reader = testReader(&embedded_buf, 1024); + defer reader.deinit(); + + const mask_key = [_]u8{ 0x12, 0x34, 0x56, 0x78 }; + var frame_buf: [64]u8 = undefined; + var offset: usize = 0; + + const part1 = [_]u8{ 0xDE, 0xAD }; + const part2 = [_]u8{ 0xBE, 0xEF }; + + // Fragment 1: binary FIN=0 + const f1 = buildMaskedFrame(&frame_buf, .binary, false, &part1, mask_key); + @memcpy(embedded_buf[offset..][0..f1.len], f1); + offset += f1.len; + + // Fragment 2: continuation FIN=1 + const f2 = buildMaskedFrame(&frame_buf, .continuation, true, &part2, mask_key); + @memcpy(embedded_buf[offset..][0..f2.len], f2); + offset += f2.len; + + reader.pos = offset; + + const msg = (try reader.nextMessage()).?; + try testing.expectEqual(types.Message.Type.binary, msg.type); + try testing.expectEqualSlices(u8, &[_]u8{ 0xDE, 0xAD, 0xBE, 0xEF }, msg.data); +} + +test "Reader.nextMessage: three-fragment reassembly" { + var embedded_buf: [256]u8 = undefined; + var reader = testReader(&embedded_buf, 1024); + defer reader.deinit(); + + const mask_key = [_]u8{ 0x12, 0x34, 0x56, 0x78 }; + var frame_buf: [64]u8 = undefined; + var offset: usize = 0; + + // Fragment 1: text FIN=0 "He" + const f1 = buildMaskedFrame(&frame_buf, .text, false, "He", mask_key); + @memcpy(embedded_buf[offset..][0..f1.len], f1); + offset += f1.len; + + // Fragment 2: continuation FIN=0 "ll" + const f2 = buildMaskedFrame(&frame_buf, .continuation, false, "ll", mask_key); + 
@memcpy(embedded_buf[offset..][0..f2.len], f2); + offset += f2.len; + + // Fragment 3: continuation FIN=1 "o" + const f3 = buildMaskedFrame(&frame_buf, .continuation, true, "o", mask_key); + @memcpy(embedded_buf[offset..][0..f3.len], f3); + offset += f3.len; + + reader.pos = offset; + + const msg = (try reader.nextMessage()).?; + try testing.expectEqual(types.Message.Type.text, msg.type); + try testing.expectEqualStrings("Hello", msg.data); + + try testing.expect((try reader.nextMessage()) == null); +} + +test "Reader.nextMessage: fragmented message exceeds max size" { + var embedded_buf: [256]u8 = undefined; + // max_message_size = 8, so fragments totalling more than 8 should fail + var reader = testReader(&embedded_buf, 8); + defer reader.deinit(); + + const mask_key = [_]u8{ 0x12, 0x34, 0x56, 0x78 }; + var frame_buf: [64]u8 = undefined; + var offset: usize = 0; + + // Fragment 1: 5 bytes + const f1 = buildMaskedFrame(&frame_buf, .text, false, "AAAAA", mask_key); + @memcpy(embedded_buf[offset..][0..f1.len], f1); + offset += f1.len; + + // Fragment 2: 5 more bytes — total 10, exceeds max of 8 + const f2 = buildMaskedFrame(&frame_buf, .continuation, true, "BBBBB", mask_key); + @memcpy(embedded_buf[offset..][0..f2.len], f2); + offset += f2.len; + + reader.pos = offset; + + try testing.expectError(error.MessageTooBig, reader.nextMessage()); +} + +test "Reader.nextMessage: multiple fragmented messages in sequence" { + var embedded_buf: [512]u8 = undefined; + var reader = testReader(&embedded_buf, 1024); + defer reader.deinit(); + + const mask_key = [_]u8{ 0x12, 0x34, 0x56, 0x78 }; + var frame_buf: [64]u8 = undefined; + var offset: usize = 0; + + // First fragmented message: "AB" + const f1 = buildMaskedFrame(&frame_buf, .text, false, "A", mask_key); + @memcpy(embedded_buf[offset..][0..f1.len], f1); + offset += f1.len; + + const f2 = buildMaskedFrame(&frame_buf, .continuation, true, "B", mask_key); + @memcpy(embedded_buf[offset..][0..f2.len], f2); + offset += f2.len; + + // Second fragmented message: "CD" + const f3 = buildMaskedFrame(&frame_buf, .binary, false, "C", mask_key); + @memcpy(embedded_buf[offset..][0..f3.len], f3); + offset += f3.len; + + const f4 = buildMaskedFrame(&frame_buf, .continuation, true, "D", mask_key); + @memcpy(embedded_buf[offset..][0..f4.len], f4); + offset += f4.len; + + reader.pos = offset; + + // First fragmented message + const msg1 = (try reader.nextMessage()).?; + try testing.expectEqual(types.Message.Type.text, msg1.type); + try testing.expectEqualStrings("AB", msg1.data); + + // Second fragmented message + const msg2 = (try reader.nextMessage()).?; + try testing.expectEqual(types.Message.Type.binary, msg2.type); + try testing.expectEqualStrings("CD", msg2.data); + + try testing.expect((try reader.nextMessage()) == null); +} + +test "Reader.nextMessage: reset cleans up active fragments" { + var embedded_buf: [256]u8 = undefined; + var reader = testReader(&embedded_buf, 1024); + defer reader.deinit(); + + const mask_key = [_]u8{ 0x12, 0x34, 0x56, 0x78 }; + var frame_buf: [64]u8 = undefined; + + // Start a fragment but don't finish it + const f1 = buildMaskedFrame(&frame_buf, .text, false, "partial", mask_key); + @memcpy(embedded_buf[0..f1.len], f1); + reader.pos = f1.len; + + // Process the first fragment (sets fragment_opcode, appends to fragment_buf) + try testing.expect((try reader.nextMessage()) == null); + try testing.expect(reader.fragment_opcode != null); + + // Reset should clean up fragment state + reader.reset(); + + try 
testing.expectEqual(@as(?types.Opcode, null), reader.fragment_opcode); + try testing.expectEqual(@as(usize, 0), reader.fragment_buf.items.len); + try testing.expectEqual(reader.embedded_buf.ptr, reader.buf.ptr); + try testing.expectEqual(@as(usize, 0), reader.pos); + try testing.expectEqual(@as(usize, 0), reader.start); +} + +test "Reader.nextMessage: interleaved close during fragmentation" { + var embedded_buf: [256]u8 = undefined; + var reader = testReader(&embedded_buf, 1024); + defer reader.deinit(); + + const mask_key = [_]u8{ 0x12, 0x34, 0x56, 0x78 }; + var frame_buf: [64]u8 = undefined; + var offset: usize = 0; + + // Fragment 1: text FIN=0 "Hel" + const f1 = buildMaskedFrame(&frame_buf, .text, false, "Hel", mask_key); + @memcpy(embedded_buf[offset..][0..f1.len], f1); + offset += f1.len; + + // Interleaved close (status 1000) + const close_payload = [_]u8{ 0x03, 0xE8 }; + const f2 = buildMaskedFrame(&frame_buf, .close, true, &close_payload, mask_key); + @memcpy(embedded_buf[offset..][0..f2.len], f2); + offset += f2.len; + + // Fragment 2: continuation FIN=1 "lo" + const f3 = buildMaskedFrame(&frame_buf, .continuation, true, "lo", mask_key); + @memcpy(embedded_buf[offset..][0..f3.len], f3); + offset += f3.len; + + reader.pos = offset; + + // First: should get the close control frame + const msg1 = (try reader.nextMessage()).?; + try testing.expectEqual(types.Message.Type.close, msg1.type); + try testing.expectEqualSlices(u8, &close_payload, msg1.data); + + // Second: should get the reassembled message + const msg2 = (try reader.nextMessage()).?; + try testing.expectEqual(types.Message.Type.text, msg2.type); + try testing.expectEqualStrings("Hello", msg2.data); + + try testing.expect((try reader.nextMessage()) == null); +} + +test "Reader: set pos for pre-loaded data" { + var embedded_buf: [128]u8 = undefined; + @memcpy(embedded_buf[0..5], "hello"); + var reader = testReader(&embedded_buf, 1024); + defer reader.deinit(); + + reader.pos = 5; + try testing.expectEqual(@as(usize, 5), reader.pos); + try testing.expectEqualStrings("hello", reader.data()); +} + +// --- Reader(.client) tests --- + +const ClientReader = Reader(.client); + +/// Build an unmasked WebSocket frame into `out`. Returns the slice of `out` that was written. +fn buildUnmaskedFrame(out: []u8, opcode: types.Opcode, fin: bool, payload: []const u8) []u8 { + const header_len = writeFrameHeader(out, opcode, fin, payload.len, false); + @memcpy(out[header_len..][0..payload.len], payload); + return out[0 .. 
header_len + payload.len]; +} + +fn testClientReader(embedded_buf: []u8, max_msg: usize) ClientReader { + return ClientReader.init(embedded_buf, testing.allocator, max_msg); +} + +test "Reader(.client): single unmasked text frame" { + var embedded_buf: [128]u8 = undefined; + var reader = testClientReader(&embedded_buf, 1024); + defer reader.deinit(); + + var frame_buf: [64]u8 = undefined; + const f = buildUnmaskedFrame(&frame_buf, .text, true, "Hello"); + @memcpy(embedded_buf[0..f.len], f); + reader.pos = f.len; + + const msg = (try reader.nextMessage()).?; + try testing.expectEqual(types.Message.Type.text, msg.type); + try testing.expectEqualStrings("Hello", msg.data); + try testing.expect((try reader.nextMessage()) == null); +} + +test "Reader(.client): rejects masked frame from server" { + var embedded_buf: [128]u8 = undefined; + var reader = testClientReader(&embedded_buf, 1024); + defer reader.deinit(); + + // Build a masked frame — invalid for server-to-client + var frame_buf: [64]u8 = undefined; + const f = buildMaskedFrame(&frame_buf, .text, true, "Hello", .{ 0x37, 0xFA, 0x21, 0x3D }); + @memcpy(embedded_buf[0..f.len], f); + reader.pos = f.len; + + try testing.expectError(error.ProtocolError, reader.nextMessage()); +} + +test "Reader(.client): fragment reassembly with unmasked frames" { + var embedded_buf: [256]u8 = undefined; + var reader = testClientReader(&embedded_buf, 1024); + defer reader.deinit(); + + var frame_buf: [64]u8 = undefined; + var offset: usize = 0; + + // Fragment 1: text FIN=0 "Hel" + const f1 = buildUnmaskedFrame(&frame_buf, .text, false, "Hel"); + @memcpy(embedded_buf[offset..][0..f1.len], f1); + offset += f1.len; + + // Fragment 2: continuation FIN=1 "lo" + const f2 = buildUnmaskedFrame(&frame_buf, .continuation, true, "lo"); + @memcpy(embedded_buf[offset..][0..f2.len], f2); + offset += f2.len; + + reader.pos = offset; + + const msg = (try reader.nextMessage()).?; + try testing.expectEqual(types.Message.Type.text, msg.type); + try testing.expectEqualStrings("Hello", msg.data); +} + +test "Reader: payload exactly at max_message_size via read loop" { + const max_msg_size = 100; + const payload = "A" ** max_msg_size; + const mask_key = [_]u8{ 0x37, 0xFA, 0x21, 0x3D }; + + // Build the full wire frame (header + mask + payload) + var wire_buf: [256]u8 = undefined; + const wire_frame = buildMaskedFrame(&wire_buf, .text, true, payload, mask_key); + // Sanity: wire frame is larger than max_message_size due to header overhead + try testing.expect(wire_frame.len > max_msg_size); + + // Small embedded buffer forces growth during the read loop + var embedded_buf: [32]u8 = undefined; + var reader = testReader(&embedded_buf, max_msg_size); + defer reader.deinit(); + + // Simulate the real read loop: readSlice → copy chunk → advancePos → nextMessage + var wire_offset: usize = 0; + while (wire_offset < wire_frame.len) { + const dest = try reader.readSlice(); + const remaining = wire_frame.len - wire_offset; + const chunk = @min(dest.len, remaining); + @memcpy(dest[0..chunk], wire_frame[wire_offset..][0..chunk]); + reader.advancePos(chunk); + wire_offset += chunk; + + if (try reader.nextMessage()) |msg| { + try testing.expectEqual(types.Message.Type.text, msg.type); + try testing.expectEqualStrings(payload, msg.data); + return; // success + } + } + // If we get here, all wire data was fed but no message was produced + return error.TestUnexpectedResult; +} diff --git a/src/rpc/webzockets/src/root.zig b/src/rpc/webzockets/src/root.zig new file mode 100644 index 
0000000000..c4a3b0b157 --- /dev/null +++ b/src/rpc/webzockets/src/root.zig @@ -0,0 +1,35 @@ +const std = @import("std"); + +pub const types = @import("types.zig"); +pub const mask = @import("mask.zig"); +pub const http = @import("http.zig"); +pub const frame = @import("frame.zig"); +pub const reader = @import("reader.zig"); +pub const server = @import("server/server.zig"); +// Client modules +pub const client = @import("client/client.zig"); +pub const client_handshake = @import("client/handshake.zig"); +pub const client_connection = @import("client/connection.zig"); + +pub const Opcode = types.Opcode; +pub const Message = types.Message; +pub const ConnectionState = types.ConnectionState; +pub const CloseCode = types.CloseCode; +pub const FrameHeader = frame.FrameHeader; +pub const FrameError = frame.FrameError; +pub const HandshakeError = types.HandshakeError; +pub const HandshakeState = types.HandshakeState; +pub const ConnectionError = types.ConnectionError; +pub const Role = types.Role; +pub const Server = server.Server; +pub const ClientHandshake = client_handshake.ClientHandshake; +pub const ClientConnection = client_connection.ClientConnection; +pub const Client = client.Client; +pub const ClientMaskPRNG = types.ClientMaskPRNG; + +test { + // Uncomment to see logs during tests + // std.testing.log_level = .debug; + std.testing.refAllDecls(@This()); + std.testing.refAllDecls(@import("control_queue.zig")); +} diff --git a/src/rpc/webzockets/src/server/connection.zig b/src/rpc/webzockets/src/server/connection.zig new file mode 100644 index 0000000000..bb260cd496 --- /dev/null +++ b/src/rpc/webzockets/src/server/connection.zig @@ -0,0 +1,1037 @@ +const std = @import("std"); +const xev = @import("xev"); + +const types = @import("../types.zig"); +const frame = @import("../frame.zig"); +const Reader = @import("../reader.zig").Reader(.server); +const ControlQueue = @import("../control_queue.zig").ControlQueue; + +/// Poolable server connection that handles the WebSocket protocol phase +/// (after HTTP upgrade). Owns its read buffer, manages write scheduling, +/// and handles lifecycle (pool release, list removal) internally. +/// +/// Comptime parameters: +/// - `ServerType`: The server struct (provides pools, loop, active list). +/// - `Handler`: User handler for protocol events (see below). +/// - `read_buf_size`: Size of the embedded per-connection read buffer. +/// +/// **Required handler methods** (comptime-enforced): +/// - `onMessage(*Handler, *Conn, types.Message)` — complete message received. +/// - `onWriteComplete(*Handler, *Conn)` — send finished; caller buffer may be +/// freed. Also called on disconnect if a data write was in-flight/pending. +/// - `onClose(*Handler, *Conn)` — connection torn down (exactly once). The +/// connection is released back to the pool after this returns. +/// +/// **Optional handler methods** (detected via `@hasDecl`): +/// - `onOpen(*Handler, *Conn)` — protocol phase started, ready to send. +/// - `onPing(*Handler, *Conn, []const u8)` — ping received. When declared, +/// the handler is responsible for sending pong via `sendPong()`, which +/// enqueues each pong individually into the control queue. When absent, +/// the library auto-pongs using "latest wins" semantics: if multiple +/// pings arrive before a pong can be sent, only the most recent ping's +/// payload is used (permitted by RFC 6455 §5.5.3). Implement `onPing` +/// if you need to respond to every ping. +/// - `onPong(*Handler, *Conn, []const u8)` — pong received. 
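+///   Useful for latency measurement with application-level ping/pong.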
+/// - `onBytesRead(*Handler, *Conn, usize)` — raw TCP data received. +/// Fires on every read completion regardless of whether reads are paused. +/// `usize` parameter is the number of bytes received. Combine with +/// `peekBufferedBytes()` to inspect raw data as it arrives. +/// +/// All data slices passed to read callbacks point into internal buffers and are +/// only valid for the duration of that callback. +pub fn Connection( + comptime ServerType: type, + comptime Handler: type, + comptime read_buf_size: usize, +) type { + comptime { + if (!@hasDecl(Handler, "onMessage")) + @compileError("Handler must declare an onMessage method"); + if (!@hasDecl(Handler, "onWriteComplete")) + @compileError("Handler must declare an onWriteComplete method (sends are zero-copy)"); + if (!@hasDecl(Handler, "onClose")) + @compileError("Handler must declare an onClose method"); + } + + return struct { + // -- Owned read buffer -- + read_buf: [read_buf_size]u8, + + // -- Core state -- + state: types.ConnectionState, + socket: xev.TCP, + allocator: std.mem.Allocator, + config: Config, + /// When true all message bytes received are buffered (up to current reader capacity) + /// but onMessage/onPing/onPong/close-frame handling is deferred until reads + /// are resumed. + read_paused: bool, + /// Re-entrancy guard: true while processMessages is on the call stack. + in_process_messages: bool, + + // -- Tiered read buffer management -- + reader: Reader, + + // -- Completions -- + read_completion: xev.Completion, + close_completion: xev.Completion, + write_completion: xev.Completion, + cancel_completion: xev.Completion, + + // -- Write state -- + write: WriteState, + + // -- Timer state -- + timer: xev.Timer, + idle_timer_completion: xev.Completion, + idle_timer_cancel_completion: xev.Completion, + close_timer_completion: xev.Completion, + close_timer_cancel_completion: xev.Completion, + received_data: bool, + + // -- Server back-pointer -- + server: *ServerType, + + // -- User handler, called directly -- + user_handler: Handler, + + // -- Intrusive list node for server's active_connections list -- + node: List.Node, + + const ConnectionSelf = @This(); + + const log = std.log.scoped(.server_connection); + + /// Comptime flags for optional handler capabilities. + const has = struct { + const on_open = @hasDecl(Handler, "onOpen"); + const on_ping = @hasDecl(Handler, "onPing"); + const on_pong = @hasDecl(Handler, "onPong"); + const on_bytes_read = @hasDecl(Handler, "onBytesRead"); + }; + + pub const List = std.DoublyLinkedList(*ConnectionSelf); + + pub const Config = struct { + /// Maximum total size of a reassembled fragmented message. + max_message_size: usize = 16 * 1024 * 1024, + /// Idle timeout in ms. Server sends close (going_away) if no data + /// received for this long. null = disabled (default). + idle_timeout_ms: ?u32 = null, + /// Close handshake timeout in ms. Force disconnect if peer doesn't + /// respond to our close frame within this duration. Default: 5000. + close_timeout_ms: u32 = 5_000, + }; + + /// Holds pending auto-pong payload when the handler does not declare + /// `onPing` (the library auto-responds). When `onPing` is declared this + /// is a zero-sized struct, adding no bytes to WriteState. + const AutoPongState = if (has.on_ping) + // Handler manages pong responses; no auto-pong state needed. + struct {} + else + struct { + /// Whether an auto-pong is pending (latest ping payload wins). + pending: bool = false, + /// Payload for the pending auto-pong. 
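+            /// (125 bytes is the RFC 6455 maximum control-frame payload.)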
+ data: [125]u8 = undefined, + /// Length of the pending auto-pong payload. + len: u8 = 0, + }; + + /// All write-path state, grouped for clarity. + const WriteState = struct { + /// What is currently being written to the socket. + const InFlight = union(enum) { + idle, + /// Two-phase data write: header bytes first, then caller's payload. + data: struct { + phase: enum { header, payload }, + offset: usize, + }, + /// Single-phase control frame write from control_buf. + control: struct { + offset: usize, + }, + }; + + /// Current in-flight write operation state. + in_flight: InFlight = .idle, + + // Data write buffers — populated in startDataWrite, used by both + // in-flight (.data) and deferred data writes. + /// Header buffer for current or pending data write. + header_buf: [10]u8 = undefined, + /// Length of actual header in header_buf. + header_len: usize = 0, + /// Payload for current or pending data write. + payload: []const u8 = &.{}, + + // Control frame buffers — control_buf holds the in-flight frame, + // control_queue holds deferred control frames. + /// Control frame buffer for in-flight control writes. + control_buf: [127]u8 = undefined, + /// Length of actual control frame in control_buf. + control_len: usize = 0, + /// Queue for pending control frames (close, ping, pong). + control_queue: ControlQueue = ControlQueue.init(), + /// Auto-pong state (only present when Handler lacks onPing). + auto_pong: AutoPongState = .{}, + /// True when we should disconnect after sending the current close frame. + /// Set when the peer initiates a close (we echo and disconnect) or when + /// we detect a protocol error (we send 1002 and disconnect). + peer_caused_close: bool = false, + }; + + // ==================================================================== + // Lifecycle + // ==================================================================== + + /// Initialize the connection in-place. Sets all runtime fields except + /// `read_buf` (the pool provides the struct and the buffer is embedded). + /// Reader is initialized with `&self.read_buf`. + pub fn init( + self: *ConnectionSelf, + socket: xev.TCP, + server: *ServerType, + user_handler: Handler, + allocator: std.mem.Allocator, + config: Config, + ) void { + self.socket = socket; + self.server = server; + self.user_handler = user_handler; + self.allocator = allocator; + self.config = config; + self.reader = Reader.init( + &self.read_buf, + allocator, + config.max_message_size, + ); + self.timer = xev.Timer.init() catch unreachable; + self.resetOpenRuntimeState(); + } + + /// Release any resources allocated by the connection (reader buffers). + /// Note: not used in pool-managed server connections — `SlotPool.release` + /// calls `reset()` which performs the same cleanup. Provided for manual + /// lifecycle management outside of a pool. + pub fn deinit(self: *ConnectionSelf) void { + self.reader.deinit(); + self.timer.deinit(); + } + + /// Reset the connection state for reuse from a pool. + /// Cleans up any fragment buffers and resets all state. 
+ pub fn reset(self: *ConnectionSelf) void { + self.reader.reset(); + self.resetOpenRuntimeState(); + } + + fn resetOpenRuntimeState(self: *ConnectionSelf) void { + self.state = .open; + self.read_completion = .{}; + self.close_completion = .{}; + self.write_completion = .{}; + self.cancel_completion = .{}; + self.write = .{}; + self.idle_timer_completion = .{}; + self.idle_timer_cancel_completion = .{}; + self.close_timer_completion = .{}; + self.close_timer_cancel_completion = .{}; + self.received_data = false; + self.read_paused = false; + self.in_process_messages = false; + } + + /// Begin the WebSocket protocol phase. Sets the reader position to + /// account for leftover handshake bytes already in the read buffer, + /// calls `onOpen` (if defined), then processes any leftover data and + /// starts reading new data. + pub fn start(self: *ConnectionSelf, data_start: usize, data_end: usize) void { + std.debug.assert(data_start <= data_end); + std.debug.assert(data_end <= self.reader.buf.len); + self.reader.start = data_start; + self.reader.pos = data_end; + + if (comptime has.on_open) { + self.user_handler.onOpen(self); + } + + // NOTE: we check state is still open here to avoid arming the idle + // timer if the handler.onOpen() called close(). + if (self.state == .open) { + // Start idle timer if configured + if (self.config.idle_timeout_ms) |timeout_ms| { + self.timer.run( + self.server.loop, + &self.idle_timer_completion, + timeout_ms, + ConnectionSelf, + self, + onIdleTimerCallback, + ); + } + } + + // Process any leftover handshake data, and start reading + self.processMessages(); + } + + // ==================================================================== + // Public send API + // ==================================================================== + + /// Send a text message. The payload is written zero-copy; keep the buffer + /// alive and unmodified until `onWriteComplete` fires. Only one data + /// write can be in flight; returns `error.WriteBusy` if another send is + /// pending. Queue additional sends and retry from `onWriteComplete`. + pub fn sendText(self: *ConnectionSelf, data: []const u8) !void { + if (self.state != .open) return error.InvalidState; + try self.startDataWrite(.text, data); + } + + /// Send a binary message. The payload is written zero-copy; keep the buffer + /// alive and unmodified until `onWriteComplete` fires. Only one data + /// write can be in flight; returns `error.WriteBusy` if another send is + /// pending. Queue additional sends and retry from `onWriteComplete`. + pub fn sendBinary(self: *ConnectionSelf, data: []const u8) !void { + if (self.state != .open) return error.InvalidState; + try self.startDataWrite(.binary, data); + } + + /// Send a ping frame. The payload is copied into an internal control queue, + /// so the caller can free the buffer immediately after this returns. + /// No `onWriteComplete` callback fires for ping sends. + /// Returns `error.ControlFrameTooBig` if payload exceeds 125 bytes, + /// or `error.QueueFull` if the control queue has insufficient space. + pub fn sendPing(self: *ConnectionSelf, data: []const u8) !void { + if (self.state != .open) return error.InvalidState; + if (data.len > 125) return error.ControlFrameTooBig; + try self.enqueueAndFlush(.ping, data); + } + + /// Send a pong frame. The payload is copied into an internal control queue, + /// so the caller can free the buffer immediately after this returns. + /// No `onWriteComplete` callback fires for pong sends. 
+        /// Returns `error.ControlFrameTooBig` if payload exceeds 125 bytes,
+        /// or `error.QueueFull` if the control queue has insufficient space.
+        pub fn sendPong(self: *ConnectionSelf, data: []const u8) !void {
+            if (self.state != .open) return error.InvalidState;
+            if (data.len > 125) return error.ControlFrameTooBig;
+            try self.enqueueAndFlush(.pong, data);
+        }
+
+        /// Initiate a close handshake with the given status code and optional reason.
+        /// The connection transitions to `.closing` and waits for the peer's close response.
+        /// The reason is silently truncated to 123 bytes (the maximum allowed by RFC 6455
+        /// after the 2-byte close code in a 125-byte control frame payload).
+        pub fn close(self: *ConnectionSelf, code: types.CloseCode, reason: []const u8) void {
+            if (self.state != .open) return;
+            self.initiateClose(code, reason);
+
+            // Cancel the idle timer to avoid repeated wakeups during the close handshake.
+            self.cancelTimer(&self.idle_timer_completion, &self.idle_timer_cancel_completion);
+        }
+
+        /// Pause frame dispatch. While paused, onMessage/onPing/onPong/close-frame
+        /// handling stops until `resumeReads()` is called. TCP reads continue until
+        /// the read buffer is full, but the buffer will not grow while paused.
+        pub fn pauseReads(self: *ConnectionSelf) void {
+            self.read_paused = true;
+        }
+
+        /// Resume frame dispatch and drain any already-buffered frames; this
+        /// restarts onMessage/onPing/onPong/close-frame handling.
+        pub fn resumeReads(self: *ConnectionSelf) void {
+            if (self.state == .closed or !self.read_paused) return;
+            self.read_paused = false;
+            self.processMessages();
+        }
+
+        /// Peek at the raw bytes currently buffered in the reader (received
+        /// from TCP but not yet consumed as websocket frames). The returned
+        /// slice points into an internal buffer and may be invalidated as soon
+        /// as the xev loop ticks again.
+        pub fn peekBufferedBytes(self: *ConnectionSelf) []const u8 {
+            return self.reader.buf[self.reader.start..self.reader.pos];
+        }
+
+        /// Perform the state transition and close frame enqueue.
+        /// Ensures the close-handshake timer is running.
+        fn initiateClose(self: *ConnectionSelf, code: types.CloseCode, reason: []const u8) void {
+            self.state = .closing;
+
+            // Build close payload: 2-byte big-endian status code + reason text
+            var payload: [125]u8 = undefined;
+            payload[0..2].* = code.payloadBytes();
+            const reason_len = @min(reason.len, 123); // control frame payload max 125
+            @memcpy(payload[2..][0..reason_len], reason[0..reason_len]);
+            const total_len: u8 = @intCast(2 + reason_len);
+
+            self.write.peer_caused_close = false;
+            self.enqueueClose(payload[0..total_len]);
+
+            // Start close-handshake deadline timer.
+            // Uses a dedicated completion so it can't be clobbered by the idle
+            // timer's guaranteed `error.Canceled` callback.
+            if (self.close_timer_completion.state() != .active) {
+                self.timer.run(
+                    self.server.loop,
+                    &self.close_timer_completion,
+                    self.config.close_timeout_ms,
+                    ConnectionSelf,
+                    self,
+                    onCloseTimerCallback,
+                );
+            }
+        }
+
+        // ====================================================================
+        // Read path
+        // ====================================================================
+
+        /// Arm another socket read when legal; while paused this continues
+        /// filling the read buffer until it has no free space.
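+        /// No-op when a read is already in flight, the connection is closed,
+        /// or a peer-initiated close is being echoed.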
+ fn maybeReadMore(self: *ConnectionSelf) void { + if (self.state == .closed or self.write.peer_caused_close) return; + if (self.read_completion.state() == .active) return; + if (self.read_paused) { + // Reclaim consumed bytes if buffer is full so we can buffer as much as + // possible without growing the read buffer. + self.reader.compactIfFull(); + // If still full then just return to avoid growing the buffer + if (self.reader.availableSpace() == 0) return; + } + self.startRead(); + } + + fn startRead(self: *ConnectionSelf) void { + const slice = self.reader.readSlice() catch |err| { + log.debug("readSlice failed: {}", .{err}); + self.handleDisconnect(); + return; + }; + self.socket.read( + self.server.loop, + &self.read_completion, + .{ .slice = slice }, + ConnectionSelf, + self, + onReadCallback, + ); + } + + fn onReadCallback( + self_opt: ?*ConnectionSelf, + _: *xev.Loop, + _: *xev.Completion, + _: xev.TCP, + _: xev.ReadBuffer, + result: xev.ReadError!usize, + ) xev.CallbackAction { + const self = self_opt orelse return .disarm; + if (self.state == .closed) { + self.checkAllDone(); + return .disarm; + } + + const bytes_read = result catch |err| { + log.debug("read failed: {}", .{err}); + self.handleDisconnect(); + return .disarm; + }; + + if (bytes_read == 0) { + log.debug("peer closed TCP connection", .{}); + self.handleDisconnect(); + return .disarm; + } + + self.received_data = true; + self.reader.advancePos(bytes_read); + + if (comptime has.on_bytes_read) { + self.user_handler.onBytesRead(self, bytes_read); + } + + if (self.read_paused) { + self.maybeReadMore(); + return .disarm; + } + + switch (self.state) { + .open, .closing => self.processMessages(), + .closed => {}, + } + + return .disarm; + } + + // ==================================================================== + // Frame processing + // ==================================================================== + + fn processMessages(self: *ConnectionSelf) void { + // Re-entrancy guard: prevents recursion into processMessages() when a handler + // calls resumeReads() + if (self.in_process_messages) return; + + if (self.read_paused) { + self.maybeReadMore(); + return; + } + + self.in_process_messages = true; + defer self.in_process_messages = false; + + while (true) { + const maybe_msg = self.reader.nextMessage() catch |err| { + log.debug("nextMessage failed: {}", .{err}); + switch (err) { + error.ProtocolError => self.failWithClose(.protocol_error), + error.MessageTooBig => self.failWithClose(.message_too_big), + error.OutOfMemory => self.handleDisconnect(), + } + return; + }; + const msg = maybe_msg orelse break; + + switch (msg.type) { + .text, .binary => self.user_handler.onMessage(self, msg), + .ping => { + if (comptime has.on_ping) { + self.user_handler.onPing(self, msg.data); + } else { + // Auto-pong: store in dedicated field (latest wins) + const len: u8 = @intCast(msg.data.len); + @memcpy(self.write.auto_pong.data[0..len], msg.data[0..len]); + self.write.auto_pong.len = len; + self.write.auto_pong.pending = true; + self.trySubmitNextControl(); + } + }, + .pong => { + if (comptime has.on_pong) { + self.user_handler.onPong(self, msg.data); + } + }, + .close => { + self.handleCloseFrame(msg.data); + if (self.state == .closed) return; + }, + } + if (self.state == .closed or self.read_paused) break; + } + // Only start a new read if we're still active. + // When peer_caused_close is set, we're echoing the peer's close frame + // and will disconnect as soon as the write completes — no need to read more. 
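+            // (maybeReadMore itself re-checks closed state and peer_caused_close.)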
+ self.maybeReadMore(); + } + + /// Send a close frame with the given code and disconnect. + /// Used for server-initiated error closes (protocol error, message too big). + fn failWithClose(self: *ConnectionSelf, code: types.CloseCode) void { + if (self.state == .closed) return; + self.state = .closing; + self.write.peer_caused_close = true; + const close_payload = code.payloadBytes(); + self.enqueueClose(&close_payload); + } + + fn handleCloseFrame(self: *ConnectionSelf, payload: []const u8) void { + if (self.state == .closing) { + // Already in closing state. If we're waiting to send our close response + // (peer_caused_close), ignore this duplicate. Otherwise, we initiated + // the close and this is the peer's response — complete the handshake. + if (!self.write.peer_caused_close) { + self.handleDisconnect(); + } + // else: peer sent multiple close frames — ignore per RFC 6455 + } else { + // Peer initiated — validate and echo the close frame, then disconnect + self.state = .closing; + self.write.peer_caused_close = true; + + const validation = types.validateClosePayload(payload); + switch (validation) { + .valid_payload => |vp| self.enqueueClose(vp), + .close_code => |code| { + const close_payload = code.payloadBytes(); + self.enqueueClose(&close_payload); + }, + } + } + } + + // ==================================================================== + // Write path + // ==================================================================== + + /// True when a user data write is in-flight or deferred. + fn outstandingUserWrite(self: *ConnectionSelf) bool { + return self.write.header_len != 0; + } + + /// Start a two-phase data write: header first, then payload. + /// If a control frame is in flight, the data write is deferred and will + /// start automatically when the control frame completes. + /// Returns error.WriteBusy if another data write is already in flight or pending. + fn startDataWrite(self: *ConnectionSelf, opcode: types.Opcode, payload: []const u8) !void { + // Only one data write can be pending/in-flight at a time + if (self.outstandingUserWrite()) { + return error.WriteBusy; + } + + // Build header into inline buffer (safe — header_buf is not used by control writes) + // Compression integration point: pass compressed=true when permessage-deflate is active + const header = frame.writeFrameHeader( + &self.write.header_buf, + opcode, + payload.len, + false, + ); + self.write.header_len = header.len; + self.write.payload = payload; + + if (self.write.in_flight == .control) { + // Control frame in flight — defer data write until it completes + return; + } + + // Idle — start immediately + self.write.in_flight = .{ .data = .{ .phase = .header, .offset = 0 } }; + self.submitWrite(self.write.header_buf[0..self.write.header_len]); + } + + /// Build a control frame into control_buf and issue the socket write. + fn writeControlFrame(self: *ConnectionSelf, opcode: types.Opcode, payload: []const u8) void { + const header = frame.writeFrameHeader( + self.write.control_buf[0..10], + opcode, + payload.len, + false, + ); + const header_len = header.len; + if (payload.len > 0) { + @memcpy(self.write.control_buf[header_len..][0..payload.len], payload); + } + self.write.control_len = header_len + payload.len; + self.write.in_flight = .{ .control = .{ .offset = 0 } }; + + self.submitWrite(self.write.control_buf[0..self.write.control_len]); + } + + /// Issue a socket write for the given slice. 
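+        /// Partial writes are resumed from `onWriteCallback` via `advanceWrite`.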
+ fn submitWrite(self: *ConnectionSelf, slice: []const u8) void { + self.socket.write( + self.server.loop, + &self.write_completion, + .{ .slice = slice }, + ConnectionSelf, + self, + onWriteCallback, + ); + } + + /// Advance a write offset by bytes_written. If the full slice has been + /// written, returns true. Otherwise reissues a write for the remainder. + fn advanceWrite( + self: *ConnectionSelf, + offset: *usize, + bytes_written: usize, + full_slice: []const u8, + ) bool { + offset.* += bytes_written; + if (offset.* < full_slice.len) { + self.submitWrite(full_slice[offset.*..]); + return false; + } + return true; + } + + fn onWriteCallback( + self_opt: ?*ConnectionSelf, + _: *xev.Loop, + _: *xev.Completion, + _: xev.TCP, + _: xev.WriteBuffer, + result: xev.WriteError!usize, + ) xev.CallbackAction { + const self = self_opt orelse return .disarm; + + if (self.state == .closed) { + self.checkAllDone(); + return .disarm; + } + + const bytes_written = result catch |err| { + log.debug("write failed: {}", .{err}); + self.handleDisconnect(); + return .disarm; + }; + + switch (self.write.in_flight) { + .data => |*d| switch (d.phase) { + .header => { + const header_data = self.write.header_buf[0..self.write.header_len]; + if (self.advanceWrite(&d.offset, bytes_written, header_data)) { + if (self.write.payload.len == 0) { + self.finishWrite(); + } else { + d.phase = .payload; + d.offset = 0; + self.submitWrite(self.write.payload); + } + } + }, + .payload => { + if (self.advanceWrite(&d.offset, bytes_written, self.write.payload)) { + self.finishWrite(); + } + }, + }, + .control => |*ctrl| { + const control_data = self.write.control_buf[0..self.write.control_len]; + if (self.advanceWrite(&ctrl.offset, bytes_written, control_data)) { + self.finishControlWrite(); + } + }, + .idle => {}, + } + + return .disarm; + } + + fn finishWrite(self: *ConnectionSelf) void { + self.write.in_flight = .idle; + self.write.payload = &.{}; + self.write.header_len = 0; + // Flush pending controls first (priority over data) before invoking user callback + self.trySubmitNextControl(); + self.user_handler.onWriteComplete(self); + } + + fn finishControlWrite(self: *ConnectionSelf) void { + self.write.in_flight = .idle; + + if (self.write.peer_caused_close) { + // Close frame sent (peer-initiated echo or protocol error) — tear down + self.write.peer_caused_close = false; + self.handleDisconnect(); + return; + } + + // Pending controls have priority over pending data writes + self.trySubmitNextControl(); + if (self.write.in_flight == .idle and self.outstandingUserWrite()) { + // Start deferred data write (header already built in header_buf) + self.write.in_flight = .{ .data = .{ .phase = .header, .offset = 0 } }; + self.submitWrite(self.write.header_buf[0..self.write.header_len]); + } + } + + /// Enqueue a control frame and flush if the write path is idle. + fn enqueueAndFlush(self: *ConnectionSelf, opcode: types.Opcode, data: []const u8) !void { + try self.write.control_queue.enqueue(opcode, data); + self.trySubmitNextControl(); + } + + /// Enqueue a close frame, clearing any pending controls (close takes priority). + fn enqueueClose(self: *ConnectionSelf, payload: []const u8) void { + self.write.control_queue.clear(); + if (comptime !has.on_ping) { + self.write.auto_pong.pending = false; + } + self.write.control_queue.enqueue(.close, payload) catch unreachable; + self.trySubmitNextControl(); + } + + /// Flush the next pending control frame, if any. No-op when a write is + /// already in flight. 
Implements the priority chain: + /// 1. Close frame in the queue (always wins) + /// 2. Pending auto-pong (latest ping payload) + /// 3. Other queued control frames (FIFO) + fn trySubmitNextControl(self: *ConnectionSelf) void { + if (self.write.in_flight != .idle) return; + + var payload_buf: [125]u8 = undefined; + // Priority 1: Close frame in queue + if (self.write.control_queue.isNextClose()) { + const entry = self.write.control_queue.dequeue(&payload_buf).?; + self.writeControlFrame(entry.opcode, payload_buf[0..entry.len]); + return; + } + // Priority 2: Pending auto-pong + if (comptime !has.on_ping) { + if (self.write.auto_pong.pending) { + self.write.auto_pong.pending = false; + const pong_data = self.write.auto_pong.data[0..self.write.auto_pong.len]; + self.writeControlFrame(.pong, pong_data); + return; + } + } + // Priority 3: Other queued control frames + if (self.write.control_queue.dequeue(&payload_buf)) |entry| { + self.writeControlFrame(entry.opcode, payload_buf[0..entry.len]); + } + } + + // ==================================================================== + // Timer callbacks + // ==================================================================== + + fn onIdleTimerCallback( + self_opt: ?*ConnectionSelf, + _: *xev.Loop, + _: *xev.Completion, + result: xev.Timer.RunError!void, + ) xev.CallbackAction { + const self = self_opt orelse return .disarm; + + result catch |err| switch (err) { + error.Canceled => { + if (self.state == .closed) self.checkAllDone(); + return .disarm; + }, + error.Unexpected => |e| { + // should never happen + log.err("idle timer error: {}", .{e}); + // just disconnect to avoid leaking resources + self.handleDisconnect(); + return .disarm; + }, + }; + + switch (self.state) { + .open => { + if (self.received_data) { + // Activity detected — clear flag, re-run idle timer + self.received_data = false; + self.timer.run( + self.server.loop, + &self.idle_timer_completion, + self.config.idle_timeout_ms.?, + ConnectionSelf, + self, + onIdleTimerCallback, + ); + } else { + // Idle timeout + self.initiateClose(.going_away, ""); + } + }, + .closing => { + // We may be in the middle of a server-initiated close while the + // idle timer is being canceled; never treat an idle timer fire + // as a close timeout. + }, + .closed => self.checkAllDone(), + } + + return .disarm; + } + + fn onCloseTimerCallback( + self_opt: ?*ConnectionSelf, + _: *xev.Loop, + _: *xev.Completion, + result: xev.Timer.RunError!void, + ) xev.CallbackAction { + const self = self_opt orelse return .disarm; + + result catch |err| switch (err) { + error.Canceled => { + if (self.state == .closed) self.checkAllDone(); + return .disarm; + }, + error.Unexpected => |e| { + // should never happen + log.err("close timer error: {}", .{e}); + // fallthrough to enforce close/cleanup + }, + }; + + switch (self.state) { + .closing => { + if (!self.write.peer_caused_close) { + // Close timeout expired (server-initiated) — force disconnect + self.handleDisconnect(); + } + // else: peer-initiated close — write path handles disconnect + }, + .open => { + // Shouldn't happen: close timer is only armed for server-initiated closes. + }, + .closed => self.checkAllDone(), + } + + return .disarm; + } + + fn onTimerCancelled( + self_opt: ?*ConnectionSelf, + _: *xev.Loop, + _: *xev.Completion, + result: xev.Timer.CancelError!void, + ) xev.CallbackAction { + // Cancel completion exists so we can wait for a backend-specific + // timer removal operation to complete (io_uring uses timer_remove). 
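+            // Cancel errors are non-fatal here; log at debug level and move on.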
+            result catch |err| log.debug("timer cancel error: {}", .{err});
+
+            // Only relevant for teardown if the connection closed while the
+            // cancellation was in-flight.
+            if (self_opt) |self| {
+                if (self.state == .closed) self.checkAllDone();
+            }
+            return .disarm;
+        }
+
+        // ====================================================================
+        // Cleanup
+        // ====================================================================
+
+        fn handleDisconnect(self: *ConnectionSelf) void {
+            if (self.state == .closed) return;
+            self.state = .closed;
+
+            // Notify the handler if a data write was in flight or pending so the
+            // buffer can be cleaned up. Not fired for internal control frame writes.
+            if (self.outstandingUserWrite()) {
+                self.user_handler.onWriteComplete(self);
+            }
+
+            self.write = .{};
+            self.user_handler.onClose(self);
+
+            self.cancelTimer(&self.idle_timer_completion, &self.idle_timer_cancel_completion);
+            self.cancelTimer(&self.close_timer_completion, &self.close_timer_cancel_completion);
+            self.cancelActive(&self.read_completion, &self.cancel_completion);
+            self.cancelActive(&self.write_completion, &self.close_completion);
+
+            self.checkAllDone();
+        }
+
+        /// Raw xev callback for cancel completions.
+        fn cancelCallback(
+            ud: ?*anyopaque,
+            _: *xev.Loop,
+            _: *xev.Completion,
+            result: xev.Result,
+        ) xev.CallbackAction {
+            result.cancel catch |err| log.debug("cancel failed: {}", .{err});
+            if (ud) |ptr| {
+                const self: *ConnectionSelf = @ptrCast(@alignCast(ptr));
+                self.checkAllDone();
+            }
+            return .disarm;
+        }
+
+        /// Cancel an active timer using its cancel completion slot.
+        fn cancelTimer(
+            self: *ConnectionSelf,
+            timer_completion: *xev.Completion,
+            cancel_slot: *xev.Completion,
+        ) void {
+            const is_active = timer_completion.state() == .active;
+            if (is_active and cancel_slot.state() != .active) {
+                // Cancel via xev.Timer.cancel() (not a raw `.cancel` op) since some
+                // backends (e.g. io_uring) require a timer-specific remove.
+                self.timer.cancel(
+                    self.server.loop,
+                    timer_completion,
+                    cancel_slot,
+                    ConnectionSelf,
+                    self,
+                    onTimerCancelled,
+                );
+            }
+        }
+
+        /// Cancel an active completion using a cancel slot.
+        fn cancelActive(
+            self: *ConnectionSelf,
+            target: *xev.Completion,
+            cancel_slot: *xev.Completion,
+        ) void {
+            if (target.state() == .active) {
+                cancel_slot.* = .{
+                    .op = .{ .cancel = .{ .c = target } },
+                    .userdata = @ptrCast(self),
+                    .callback = cancelCallback,
+                };
+                self.server.loop.add(cancel_slot);
+            }
+        }
+
+        /// Check whether all completions (read, write, cancels, timers) are
+        /// inactive. If so, initiate async shutdown then close.
+        fn checkAllDone(self: *ConnectionSelf) void {
+            const completions = [_]*xev.Completion{
+                &self.read_completion,
+                &self.write_completion,
+                &self.cancel_completion,
+                &self.close_completion,
+                &self.idle_timer_completion,
+                &self.idle_timer_cancel_completion,
+                &self.close_timer_completion,
+                &self.close_timer_cancel_completion,
+            };
+            for (completions) |c| {
+                if (c.state() == .active) return;
+            }
+
+            // Shut down the write side to send FIN, as required by RFC 6455.
+            // Reuses close_completion since all cancellations are done. The
+            // shutdown callback then reuses it again for the final socket close.
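+            // Teardown sequence from here: shutdown -> onShutdownComplete ->
+            // closeSocket -> onCloseComplete -> pool release.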
+ self.socket.shutdown( + self.server.loop, + &self.close_completion, + ConnectionSelf, + self, + onShutdownComplete, + ); + } + + fn onShutdownComplete( + self_opt: ?*ConnectionSelf, + _: *xev.Loop, + _: *xev.Completion, + _: xev.TCP, + result: xev.ShutdownError!void, + ) xev.CallbackAction { + // Shutdown errors are expected — socket may already be closed by client (ENOTCONN). + result catch |err| log.debug("shutdown failed: {}", .{err}); + if (self_opt) |self| { + self.closeSocket(); + } + return .disarm; + } + + fn closeSocket(self: *ConnectionSelf) void { + self.socket.close( + self.server.loop, + &self.close_completion, + ConnectionSelf, + self, + onCloseComplete, + ); + } + + fn onCloseComplete( + self_opt: ?*ConnectionSelf, + _: *xev.Loop, + _: *xev.Completion, + _: xev.TCP, + result: xev.CloseError!void, + ) xev.CallbackAction { + result catch |err| log.debug("close failed: {}", .{err}); + if (self_opt) |self| { + // Remove from server's active connections list and release to pool. + self.server.active_connections.remove(&self.node); + self.server.connection_pool.release(self); + } + return .disarm; + } + }; +} diff --git a/src/rpc/webzockets/src/server/handshake.zig b/src/rpc/webzockets/src/server/handshake.zig new file mode 100644 index 0000000000..87477f1335 --- /dev/null +++ b/src/rpc/webzockets/src/server/handshake.zig @@ -0,0 +1,271 @@ +const std = @import("std"); +const xev = @import("xev"); +const types = @import("../types.zig"); +const http = @import("../http.zig"); + +/// Poolable server handshake — reads the HTTP upgrade request, validates it, +/// writes the 101 response, creates the user handler, and transitions to +/// the connection phase. Owns its read buffer and holds a server back-pointer. +/// +/// Comptime parameters: +/// - `ServerType`: Server struct providing pools, loop, allocator, config, +/// active_connections, and shutting_down flag. +/// - `Handler`: User handler type. See `Server` doc comment for the full +/// handler contract including `init`, `onHandshakeFailed`, and Context. +/// - `read_buf_size`: Size of the embedded HTTP read buffer. +pub fn Handshake( + comptime ServerType: type, + comptime Handler: type, + comptime read_buf_size: usize, +) type { + return struct { + read_buf: [read_buf_size]u8, + response_buf: [response_buf_len]u8, + state: types.HandshakeState, + socket: xev.TCP, + read_pos: usize, + response_len: usize, + write_pos: usize, + head_parser: http.HeadParser, + read_completion: xev.Completion, + write_completion: xev.Completion, + close_completion: xev.Completion, + header_len: usize, + server: *ServerType, + user_handler: ?Handler, + + const HandshakeSelf = @This(); + + const log = std.log.scoped(.server_handshake); + const has_on_handshake_failed = @hasDecl(Handler, "onHandshakeFailed"); + + /// Fixed-size buffer length for the HTTP 101 Switching Protocols response. + /// 129 bytes: status line + headers + base64-encoded accept key + \r\n\r\n. + const response_buf_len = 129; + + /// Initialize the handshake in-place. Sets all runtime fields except + /// `read_buf` (the pool provides the struct and the buffer is embedded). + pub fn init(self: *HandshakeSelf, socket: xev.TCP, server: *ServerType) void { + self.reset(); + self.socket = socket; + self.server = server; + self.response_buf = undefined; + } + + /// Begin reading the HTTP upgrade request from the socket. + pub fn start(self: *HandshakeSelf) void { + self.startRead(); + } + + /// Reset all state for pool reuse. Does not touch `read_buf`. 
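+        /// Called from `init` before the socket and server back-pointer are set.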
+ pub fn reset(self: *HandshakeSelf) void { + self.state = .reading; + self.read_pos = 0; + self.response_len = 0; + self.write_pos = 0; + self.head_parser = .{}; + self.read_completion = .{}; + self.write_completion = .{}; + self.close_completion = .{}; + self.header_len = 0; + self.user_handler = null; + } + + fn startRead(self: *HandshakeSelf) void { + if (self.read_pos >= self.read_buf.len) { + log.debug("handshake failed: read buffer full", .{}); + self.fail(); + return; + } + + self.socket.read( + self.server.loop, + &self.read_completion, + .{ .slice = self.read_buf[self.read_pos..] }, + HandshakeSelf, + self, + onReadCallback, + ); + } + + fn onReadCallback( + self_opt: ?*HandshakeSelf, + _: *xev.Loop, + _: *xev.Completion, + _: xev.TCP, + _: xev.ReadBuffer, + result: xev.ReadError!usize, + ) xev.CallbackAction { + const self = self_opt orelse return .disarm; + if (self.state != .reading) return .disarm; + + const bytes_read = result catch |err| { + log.debug("handshake read failed: {}", .{err}); + self.fail(); + return .disarm; + }; + + if (bytes_read == 0) { + log.debug("handshake failed: peer closed connection", .{}); + self.fail(); + return .disarm; + } + + const old_pos = self.read_pos; + self.read_pos += bytes_read; + + // Feed only the new bytes to the incremental head parser. + const consumed = self.head_parser.feed(self.read_buf[old_pos..self.read_pos]); + if (self.head_parser.state != .finished) { + self.startRead(); + return .disarm; + } + + self.header_len = old_pos + consumed; + self.processHandshake(); + return .disarm; + } + + fn processHandshake(self: *HandshakeSelf) void { + const req = http.parseRequest(self.read_buf[0..self.header_len]) catch |err| { + log.debug("handshake failed: invalid HTTP request: {}", .{err}); + self.fail(); + return; + }; + + self.user_handler = Handler.init(req, self.server.config.handler_context) catch |err| { + log.debug("handshake failed: handler rejected connection: {}", .{err}); + self.fail(); + return; + }; + + const response = http.writeResponse(&self.response_buf, req.websocket_key) catch |err| { + log.debug("handshake failed: response write error: {}", .{err}); + self.fail(); + return; + }; + + self.response_len = response.len; + self.write_pos = 0; + self.state = .writing; + self.startWrite(); + } + + fn startWrite(self: *HandshakeSelf) void { + self.socket.write( + self.server.loop, + &self.write_completion, + .{ .slice = self.response_buf[self.write_pos..self.response_len] }, + HandshakeSelf, + self, + onWriteComplete, + ); + } + + fn onWriteComplete( + self_opt: ?*HandshakeSelf, + _: *xev.Loop, + _: *xev.Completion, + _: xev.TCP, + _: xev.WriteBuffer, + result: xev.WriteError!usize, + ) xev.CallbackAction { + const self = self_opt orelse return .disarm; + if (self.state != .writing) return .disarm; + + const bytes_written = result catch |err| { + log.debug("handshake write failed: {}", .{err}); + self.fail(); + return .disarm; + }; + + self.write_pos += bytes_written; + + if (self.write_pos < self.response_len) { + self.startWrite(); + return .disarm; + } + + self.state = .completed; + self.transitionToConnection(); + return .disarm; + } + + fn transitionToConnection(self: *HandshakeSelf) void { + if (self.server.shutting_down) { + log.debug("handshake rejected: shutting down", .{}); + self.fail(); + return; + } + + const conn = self.server.connection_pool.create() catch { + log.debug("handshake rejected: connection pool exhausted", .{}); + self.fail(); + return; + }; + + defer self.server.handshake_pool.release(self); + + 
const leftover = self.read_buf[self.header_len..self.read_pos]; + + if (leftover.len > 0) { + @memcpy(conn.read_buf[0..leftover.len], leftover); + } + + conn.init( + self.socket, + self.server, + self.user_handler.?, + self.server.allocator, + .{ + .max_message_size = self.server.config.max_message_size, + .idle_timeout_ms = self.server.config.idle_timeout_ms, + .close_timeout_ms = self.server.config.close_timeout_ms, + }, + ); + + conn.node = .{ .data = conn }; + self.server.active_connections.append(&conn.node); + conn.start(0, leftover.len); + } + + /// Mark state as failed and close the socket. + /// RFC 6455 §4.2.2 says the server should return an appropriate HTTP error + /// code for failed WebSocket validation, but we close without a response + /// for simplicity. + fn fail(self: *HandshakeSelf) void { + self.state = .failed; + if (self.user_handler) |*handler| { + if (comptime has_on_handshake_failed) { + handler.onHandshakeFailed(); + } + self.user_handler = null; + } + self.closeAndRelease(); + } + + fn closeAndRelease(self: *HandshakeSelf) void { + self.socket.close( + self.server.loop, + &self.close_completion, + HandshakeSelf, + self, + onCloseComplete, + ); + } + + fn onCloseComplete( + self_opt: ?*HandshakeSelf, + _: *xev.Loop, + _: *xev.Completion, + _: xev.TCP, + result: xev.CloseError!void, + ) xev.CallbackAction { + result catch |err| log.debug("handshake close failed: {}", .{err}); + if (self_opt) |self| { + self.server.handshake_pool.release(self); + } + return .disarm; + } + }; +} diff --git a/src/rpc/webzockets/src/server/server.zig b/src/rpc/webzockets/src/server/server.zig new file mode 100644 index 0000000000..95a4d55385 --- /dev/null +++ b/src/rpc/webzockets/src/server/server.zig @@ -0,0 +1,394 @@ +const std = @import("std"); +const xev = @import("xev"); + +const slot_pool = @import("slot_pool.zig"); +const server_hs = @import("handshake.zig"); +const server_conn = @import("connection.zig"); + +/// TCP listener wrapping a libxev accept loop with memory-pooled connections. +/// +/// Manages the full lifecycle of WebSocket connections: accepting TCP +/// connections, performing the HTTP upgrade handshake, and running the +/// WebSocket protocol. Pre-allocates memory pools for handshakes, +/// connections, and large-message buffers. +/// +/// Comptime parameters: +/// - `Handler`: User handler type for protocol events (see below). +/// - `read_buf_size`: Size of per-connection embedded read buffer +/// (fast path for small messages). +/// +/// **Required handler declarations** (comptime-enforced): +/// - `pub const Context = T` — the handler's context *pointee* type (`void` for none). +/// When `Context != void`, the caller provides `Config.handler_context: *T`, which is +/// passed to `init` as the second parameter. The pointer must outlive any handshake/ +/// connection that might call `init` or `onHandshakeFailed`. +/// +/// **Required handler methods** (comptime-enforced): +/// - `init(http.Request, if (Context == void) void else *Context) !Handler` — +/// factory method called during the HTTP upgrade, before the 101 response is sent. +/// Return an error to reject the connection. The handler pointer is stable for the +/// lifetime of the connection. +/// - `onMessage(*Handler, *Conn, types.Message)` — complete message +/// received (text or binary, after reassembly of fragments). +/// - `onWriteComplete(*Handler, *Conn)` — data send finished; the +/// caller's buffer may now be freed or reused. 
Also called on +/// disconnect if a data write was in-flight or pending. Required +/// because sends are zero-copy — the caller must not free or reuse +/// the buffer until this fires. Not called for `sendPing`/`sendPong` +/// (those copy into an internal queue; the caller can free immediately). +/// - `onClose(*Handler, *Conn)` — connection torn down (called exactly +/// once). The connection is released back to the pool after this +/// returns. +/// +/// **Optional handler methods** (detected via `@hasDecl`): +/// - `onHandshakeFailed(*Handler)` — called if the handshake fails after +/// `init` succeeds (e.g., write error, connection pool exhausted, +/// server shutdown). Use to clean up resources allocated in `init`. +/// Neither `onOpen` nor `onClose` will fire when this is called. +/// There is no connection parameter — only `*Handler` is available. +/// - `onOpen(*Handler, *Conn)` — protocol phase started, connection is +/// ready to send. Called after successful `init`, before reading +/// messages. +/// - `onPing(*Handler, *Conn, []const u8)` — ping received. When +/// declared, the handler is responsible for sending pong via +/// `conn.sendPong()`, which enqueues each pong individually into the +/// control queue. When absent, the library auto-pongs using "latest +/// wins" semantics: if multiple pings arrive before a pong can be +/// sent, only the most recent ping's payload is used (permitted by +/// RFC 6455 §5.5.3). Implement `onPing` if you need to respond to +/// every ping. +/// - `onPong(*Handler, *Conn, []const u8)` — unsolicited or solicited +/// pong received. Useful for latency measurement with application- +/// level ping/pong. +/// +/// Connection tracking is delegated to the user via handler callbacks. +pub fn Server( + comptime Handler: type, + comptime read_buf_size: usize, +) type { + comptime { + if (!@hasDecl(Handler, "Context")) + @compileError("Handler must declare `pub const Context = T`"); + if (!@hasDecl(Handler, "init")) + @compileError("Handler must declare init(request, context) !Handler"); + if (!@hasDecl(Handler, "onMessage")) + @compileError("Handler must declare an onMessage method"); + if (!@hasDecl(Handler, "onWriteComplete")) + @compileError("Handler must declare an onWriteComplete method (sends are zero-copy)"); + if (!@hasDecl(Handler, "onClose")) + @compileError("Handler must declare an onClose method"); + } + + return struct { + allocator: std.mem.Allocator, + loop: *xev.Loop, + config: Config, + listen_socket: xev.TCP, + accept_completion: xev.Completion, + handshake_pool: HandshakePool, + connection_pool: ConnectionPool, + shutting_down: bool, + listen_socket_closed: bool, + /// Connections in WebSocket phase; walked during shutdown to issue close frames. + active_connections: ConnectionList, + /// Timer for periodic shutdown drain check. + shutdown_timer: xev.Timer, + shutdown_timer_completion: xev.Completion, + /// For closing listen socket. + listen_close_completion: xev.Completion, + /// Absolute nanoTimestamp after which shutdown times out. 
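+        /// Compared against `std.time.nanoTimestamp()` by the periodic drain
+        /// timer; remains 0 until `shutdown()` is called.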
+ shutdown_deadline: i128, + shutdown_userdata: ?*anyopaque, + + const ServerSelf = @This(); + + const log = std.log.scoped(.server); + + const HandlerContext = if (Handler.Context != void) *Handler.Context else void; + + const HsImpl = server_hs.Handshake(ServerSelf, Handler, read_buf_size); + const ConnImpl = server_conn.Connection(ServerSelf, Handler, read_buf_size); + const ConnectionList = ConnImpl.List; + const HandshakePool = slot_pool.SlotPool(HsImpl); + const ConnectionPool = slot_pool.SlotPool(ConnImpl); + + pub const ShutdownResult = enum { clean, timed_out }; + + /// The per-connection WebSocket type exposed to handler callbacks. + /// Provides `sendText`, `sendBinary`, `sendPing`, `sendPong`, and `close` methods. + pub const Conn = ConnImpl; + + pub const Config = struct { + /// Address to bind and listen on. + address: std.net.Address, + /// Typed context passed to `Handler.init`. `*T` when + /// `Handler.Context` is non-void; `void` (`{}`) otherwise. + handler_context: HandlerContext, + /// TCP listen backlog. + tcp_accept_backlog: u31 = 128, + /// Maximum total size of a reassembled fragmented message. + max_message_size: usize = 16 * 1024 * 1024, + /// Number of handshake slots to pre-allocate in the pool. + initial_handshake_pool_size: usize = 16, + /// Number of connection slots to pre-allocate in the pool. + initial_connection_pool_size: usize = 64, + /// Maximum number of concurrent handshakes. Null means unlimited. + max_handshakes: ?usize = null, + /// Maximum number of concurrent connections. Null means unlimited. + max_connections: ?usize = null, + /// Idle timeout in ms. Server sends close (going_away) if no data + /// received for this long. null = disabled (default). + idle_timeout_ms: ?u32 = null, + /// Close handshake timeout in ms. Force disconnect if peer doesn't + /// respond to our close frame within this duration. Default: 5000. + close_timeout_ms: u32 = 5_000, + }; + + /// Create a server: opens, binds, and listens on the configured address. + /// Pre-heats the memory pools with the configured initial sizes. + /// The event loop is not started — call `accept()` then run the loop. + pub fn init(allocator: std.mem.Allocator, loop: *xev.Loop, config: Config) !ServerSelf { + // The kqueue and epoll backends require a thread pool to be set on + // the loop, otherwise they cannot perform socket close operations. 
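+            // A minimal setup that satisfies this assertion might look like
+            // the following (a sketch; the option and type names assume the
+            // vendored libxev revision, so verify against it):
+            //
+            //     var thread_pool = xev.ThreadPool.init(.{});
+            //     defer thread_pool.deinit();
+            //     var loop = try xev.Loop.init(.{ .thread_pool = &thread_pool });
+            //     defer loop.deinit();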
+            if (comptime @hasField(xev.Loop, "thread_pool")) {
+                std.debug.assert(loop.thread_pool != null);
+            }
+
+            const listen_socket = try xev.TCP.init(config.address);
+            errdefer std.posix.close(listen_socket.fd);
+
+            try listen_socket.bind(config.address);
+            try listen_socket.listen(config.tcp_accept_backlog);
+
+            // Create and preheat the pools; each may be capped via
+            // `max_handshakes` / `max_connections` (null = unlimited).
+            var hs_pool = HandshakePool.init(allocator, config.max_handshakes);
+            errdefer hs_pool.deinit();
+            try hs_pool.preheat(config.initial_handshake_pool_size);
+
+            var conn_pool = ConnectionPool.init(allocator, config.max_connections);
+            errdefer conn_pool.deinit();
+            try conn_pool.preheat(config.initial_connection_pool_size);
+
+            return .{
+                .allocator = allocator,
+                .loop = loop,
+                .config = config,
+                .listen_socket = listen_socket,
+                .accept_completion = .{},
+                .handshake_pool = hs_pool,
+                .connection_pool = conn_pool,
+                .shutting_down = false,
+                .listen_socket_closed = false,
+                .active_connections = .{},
+                .shutdown_timer = undefined,
+                .shutdown_timer_completion = .{},
+                .listen_close_completion = .{},
+                .shutdown_deadline = 0,
+                .shutdown_userdata = null,
+            };
+        }
+
+        /// Close the listen socket and clean up memory pools.
+        /// Does not affect active connections (they continue until closed).
+        pub fn deinit(self: *ServerSelf) void {
+            if (!self.listen_socket_closed) {
+                std.posix.close(self.listen_socket.fd);
+            }
+            self.handshake_pool.deinit();
+            self.connection_pool.deinit();
+        }
+
+        /// Start the asynchronous accept loop. Each accepted connection goes through
+        /// the HTTP upgrade handshake, then transitions to the WebSocket protocol phase.
+        /// The accept loop re-arms automatically after each connection.
+        pub fn accept(self: *ServerSelf) void {
+            self.listen_socket.accept(
+                self.loop,
+                &self.accept_completion,
+                ServerSelf,
+                self,
+                acceptCallback,
+            );
+        }
+
+        fn acceptCallback(
+            self_opt: ?*ServerSelf,
+            _: *xev.Loop,
+            completion: *xev.Completion,
+            result: xev.AcceptError!xev.TCP,
+        ) xev.CallbackAction {
+            // NOTE: this callback always returns .disarm, even when it simply
+            // calls self.accept() again, because of a bug in the libxev kqueue
+            // backend when returning .rearm.
+            const self = self_opt orelse return .disarm;
+
+            if (self.shutting_down) {
+                const client_socket = result catch return .disarm;
+                log.debug("accept rejected: shutting down", .{});
+                client_socket.close(self.loop, completion, ServerSelf, self, onRejectCloseComplete);
+                return .disarm;
+            }
+
+            const client_socket = result catch |err| {
+                log.debug("accept failed: {}", .{err});
+                // Accept failed — re-register to keep listening.
+                self.accept();
+                return .disarm;
+            };
+
+            if (!self.setupConnection(client_socket)) {
+                // Pool exhausted: close the socket asynchronously, then resume accepting.
+                client_socket.close(
+                    self.loop,
+                    completion,
+                    ServerSelf,
+                    self,
+                    onRejectCloseComplete,
+                );
+                return .disarm;
+            }
+
+            // Re-register to accept the next connection.
+ self.accept(); + return .disarm; + } + + fn setupConnection(self: *ServerSelf, client_socket: xev.TCP) bool { + // Acquire handshake slot from pool + const hs = self.handshake_pool.create() catch { + log.debug("setupConnection: handshake pool exhausted", .{}); + return false; + }; + + // Initialize and start the handshake + hs.init(client_socket, self); + hs.start(); + return true; + } + + fn onRejectCloseComplete( + self_opt: ?*ServerSelf, + _: *xev.Loop, + _: *xev.Completion, + _: xev.TCP, + result: xev.CloseError!void, + ) xev.CallbackAction { + result catch |err| log.debug("rejected connection close failed: {}", .{err}); + if (self_opt) |self| { + if (!self.shutting_down) self.accept(); + } + return .disarm; + } + + fn onListenSocketCloseComplete( + self_opt: ?*ServerSelf, + _: *xev.Loop, + _: *xev.Completion, + _: xev.TCP, + result: xev.CloseError!void, + ) xev.CallbackAction { + result catch |err| log.debug("listen socket close failed: {}", .{err}); + if (self_opt) |self| { + log.debug("listen socket closed", .{}); + self.listen_socket_closed = true; + } + return .disarm; + } + + fn isFullyDrained(self: *const ServerSelf) bool { + return self.connection_pool.active_count == 0 and + self.handshake_pool.active_count == 0 and + self.listen_socket_closed; + } + + /// Initiate graceful shutdown: stop accepting connections, close all active + /// WebSocket connections with a going_away close frame, and invoke the callback + /// when fully drained or after the timeout expires. + pub fn shutdown( + self: *ServerSelf, + max_wait_ms: u64, + comptime Context: type, + context: ?*Context, + comptime onComplete: fn (?*Context, ShutdownResult) void, + ) void { + log.debug("shutdown: max_wait_ms={d}, active_connections={d}, active_handshakes={d}", .{ + max_wait_ms, + self.connection_pool.active_count, + self.handshake_pool.active_count, + }); + self.shutting_down = true; + self.shutdown_userdata = @ptrCast(context); + self.shutdown_deadline = std.time.nanoTimestamp() + + @as(i128, max_wait_ms) * std.time.ns_per_ms; + + // Close the listen socket to stop accepting new connections + self.listen_socket.close( + self.loop, + &self.listen_close_completion, + ServerSelf, + self, + onListenSocketCloseComplete, + ); + + // Close all active WebSocket connections + var it = self.active_connections.first; + while (it) |node| { + it = node.next; + node.data.close(.going_away, ""); + } + + // Check if already drained (no connections, listen close may fire synchronously) + if (self.isFullyDrained()) { + onComplete(@ptrCast(@alignCast(self.shutdown_userdata)), .clean); + return; + } + + // Start periodic timer to check for drain completion + const S = struct { + fn onTimer( + s_opt: ?*ServerSelf, + _: *xev.Loop, + _: *xev.Completion, + result: xev.Timer.RunError!void, + ) xev.CallbackAction { + result catch |err| log.err("shutdown timer failed: {}", .{err}); + const s = s_opt orelse return .disarm; + if (s.isFullyDrained()) { + log.debug("shutdown: fully drained", .{}); + onComplete(@ptrCast(@alignCast(s.shutdown_userdata)), .clean); + return .disarm; + } + if (std.time.nanoTimestamp() >= s.shutdown_deadline) { + log.debug("shutdown: timed out, connections={d}, handshakes={d}", .{ + s.connection_pool.active_count, + s.handshake_pool.active_count, + }); + onComplete(@ptrCast(@alignCast(s.shutdown_userdata)), .timed_out); + return .disarm; + } + s.shutdown_timer.run( + s.loop, + &s.shutdown_timer_completion, + 10, + ServerSelf, + s, + onTimer, + ); + return .disarm; + } + }; + + self.shutdown_timer = 
xev.Timer.init() catch |err| { + log.err("shutdown: failed to create drain timer (should never happen): {}", .{err}); + return; + }; + self.shutdown_timer.run( + self.loop, + &self.shutdown_timer_completion, + 10, + ServerSelf, + self, + S.onTimer, + ); + } + }; +} diff --git a/src/rpc/webzockets/src/server/slot_pool.zig b/src/rpc/webzockets/src/server/slot_pool.zig new file mode 100644 index 0000000000..6ec5465ce9 --- /dev/null +++ b/src/rpc/webzockets/src/server/slot_pool.zig @@ -0,0 +1,49 @@ +const std = @import("std"); + +/// Wrapper around MemoryPool that ensures reset() is called before returning to pool. +/// Optionally enforces a maximum number of active slots (independent of preheat count). +pub fn SlotPool(comptime Slot: type) type { + return struct { + inner: std.heap.MemoryPool(Slot), + /// Number of slots currently checked out (not in pool). + active_count: usize, + /// Upper bound on active_count; null means unlimited. + max_size: ?usize, + + const SlotPoolSelf = @This(); + + pub fn init(allocator: std.mem.Allocator, max_size: ?usize) SlotPoolSelf { + return .{ + .inner = std.heap.MemoryPool(Slot).init(allocator), + .active_count = 0, + .max_size = max_size, + }; + } + + pub fn deinit(self: *SlotPoolSelf) void { + self.inner.deinit(); + } + + pub fn preheat(self: *SlotPoolSelf, count: usize) !void { + try self.inner.preheat(count); + } + + pub fn create(self: *SlotPoolSelf) !*Slot { + if (self.max_size) |max| { + if (self.active_count >= max) { + return error.PoolExhausted; + } + } + const slot = try self.inner.create(); + self.active_count += 1; + return slot; + } + + /// Release a slot back to the pool, calling reset() first to clean up state. + pub fn release(self: *SlotPoolSelf, slot: *Slot) void { + slot.reset(); + self.inner.destroy(slot); + self.active_count -= 1; + } + }; +} diff --git a/src/rpc/webzockets/src/types.zig b/src/rpc/webzockets/src/types.zig new file mode 100644 index 0000000000..b29366f516 --- /dev/null +++ b/src/rpc/webzockets/src/types.zig @@ -0,0 +1,271 @@ +const std = @import("std"); + +/// Whether this endpoint acts as a WebSocket client or server. +/// Affects frame validation (mask requirements) and header construction. +pub const Role = enum { client, server }; + +/// WebSocket frame opcodes per RFC 6455 Section 5.2. +pub const Opcode = enum(u4) { + continuation = 0x0, + text = 0x1, + binary = 0x2, + // 0x3-0x7 reserved for non-control frames + close = 0x8, + ping = 0x9, + pong = 0xA, + // 0xB-0xF reserved for control frames + + /// Returns true for control opcodes (close, ping, pong). + /// Control frames have opcodes where the high bit of the 4-bit field is set (>= 0x8). + pub fn isControl(self: Opcode) bool { + return @intFromEnum(self) >= 0x8; + } +}; + +/// High-level message type delivered to user callbacks. +pub const Message = struct { + type: Type, + data: []const u8, + + pub const Type = enum { + text, + binary, + close, + ping, + pong, + }; +}; + +/// Per-connection state machine states (WebSocket protocol phase only). +pub const ConnectionState = enum { + open, + closing, + closed, +}; + +/// Handshake state machine states. +pub const HandshakeState = enum { + reading, + writing, + completed, + failed, +}; + +/// WebSocket close status codes per RFC 6455 Section 7.4.1. 
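+/// Only the assigned codes this library sends or special-cases are
+/// enumerated; raw u16 codes read off the wire are validated with
+/// `isValidCode` instead.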
+pub const CloseCode = enum(u16) { + normal = 1000, + going_away = 1001, + protocol_error = 1002, + unsupported_data = 1003, + // 1004 reserved + no_status = 1005, + abnormal = 1006, + invalid_payload = 1007, + policy_violation = 1008, + message_too_big = 1009, + mandatory_extension = 1010, + internal_error = 1011, + // 1015 reserved (TLS handshake failure) + + /// Returns true if the code is valid for use in a close frame sent on the wire. + /// Codes 1005, 1006, and 1015 are designated for use in APIs/logs only, never sent. + pub fn isValidForWire(self: CloseCode) bool { + const code = @intFromEnum(self); + return code != 1005 and code != 1006; + } + + /// Validates a raw u16 close code per RFC 6455 Section 7.4. + /// Valid ranges: 1000-1003, 1007-1011, 3000-4999. + pub fn isValidCode(code: u16) bool { + if (code < 1000) return false; + if (code == 1004 or code == 1005 or code == 1006) return false; + if (code >= 1012 and code <= 2999) return false; + if (code >= 5000) return false; + return true; + } + + /// Returns the close code as a 2-byte big-endian value, ready to use as a + /// close frame payload on the wire. + pub fn payloadBytes(self: CloseCode) [2]u8 { + var payload: [2]u8 = undefined; + std.mem.writeInt(u16, &payload, @intFromEnum(self), .big); + return payload; + } +}; + +/// Errors that can occur during WebSocket handshake. +pub const HandshakeError = error{ + InvalidMethod, + InvalidHttpVersion, + MissingUpgradeHeader, + MissingConnectionHeader, + MissingWebSocketKey, + MissingWebSocketVersion, + UnsupportedWebSocketVersion, + MalformedRequest, + MalformedResponse, + InvalidStatusCode, + MissingAcceptHeader, + InvalidAcceptKey, +}; + +/// Errors that can occur on a WebSocket connection. +pub const ConnectionError = error{ + ConnectionClosed, + InvalidState, + InvalidCloseCode, + MessageTooBig, + UnexpectedContinuation, + NestedFragment, + WriteError, + ReadError, + WriteBusy, + QueueFull, + ControlFrameTooBig, +}; + +/// Result of validating a close frame payload. +pub const ClosePayloadValidation = union(enum) { + /// The original payload is well-formed and can be echoed back as-is. + valid_payload: []const u8, + /// The payload was invalid; send this code in the close response instead. + close_code: CloseCode, +}; + +/// Validate a close frame payload per RFC 6455 Section 7.4. +/// Returns the original payload (if valid) or a CloseCode to send in response. +/// +/// Invalid cases: +/// - Payload of 1 byte (must be 0 or >=2) +/// - Close code outside valid ranges (see `CloseCode.isValidCode`) +/// - Reason text (bytes 2+) contains invalid UTF-8 +pub fn validateClosePayload(payload: []const u8) ClosePayloadValidation { + // Empty payload is valid (no status code) + if (payload.len == 0) return .{ .valid_payload = payload }; + + // 1 byte is invalid — code requires 2 bytes + if (payload.len == 1) { + return .{ .close_code = .protocol_error }; + } + + // Extract and validate the close code + const code = std.mem.readInt(u16, payload[0..2], .big); + if (!CloseCode.isValidCode(code)) { + return .{ .close_code = .protocol_error }; + } + + // Validate UTF-8 in reason text (if present) + if (payload.len > 2) { + const reason = payload[2..]; + if (!std.unicode.utf8ValidateSlice(reason)) { + return .{ .close_code = .invalid_payload }; + } + } + + // Valid — echo the original payload + return .{ .valid_payload = payload }; +} + +/// Wrapper around `std.Random.DefaultCsprng` used by WebSocket clients for +/// mask-key and handshake-key generation. 
Provides only the `fill()` method +/// needed by the library, decoupling the public API from the concrete CSPRNG type. +/// +/// Not thread-safe — use only from the `loop.run()` thread and do not share +/// across loops/threads. The pointer must remain stable for the lifetime of +/// any `ClientConnection` using it. +pub const ClientMaskPRNG = struct { + inner: std.Random.DefaultCsprng, + + /// Seed length required by the underlying CSPRNG. Callers should fill a + /// `[secret_seed_length]u8` buffer with cryptographically-random bytes + /// (e.g. `std.posix.getrandom` or `std.crypto.random.bytes`) before + /// passing it to `init`. + pub const secret_seed_length = std.Random.DefaultCsprng.secret_seed_length; + + /// Initialize from a seed. The seed must be cryptographically random + /// (e.g., from `std.crypto.random.bytes`). + pub fn init(seed: [secret_seed_length]u8) ClientMaskPRNG { + return .{ .inner = std.Random.DefaultCsprng.init(seed) }; + } + + /// Fill a buffer with cryptographically-strong random bytes. + pub fn fill(self: *ClientMaskPRNG, buf: []u8) void { + self.inner.fill(buf); + } +}; + +const testing = std.testing; + +test "Opcode: isControl returns true for control frames" { + try testing.expect(Opcode.close.isControl()); + try testing.expect(Opcode.ping.isControl()); + try testing.expect(Opcode.pong.isControl()); +} + +test "Opcode: isControl returns false for data frames" { + try testing.expect(!Opcode.continuation.isControl()); + try testing.expect(!Opcode.text.isControl()); + try testing.expect(!Opcode.binary.isControl()); +} + +test "Opcode: enum values match RFC 6455" { + try testing.expectEqual(@as(u4, 0x0), @intFromEnum(Opcode.continuation)); + try testing.expectEqual(@as(u4, 0x1), @intFromEnum(Opcode.text)); + try testing.expectEqual(@as(u4, 0x2), @intFromEnum(Opcode.binary)); + try testing.expectEqual(@as(u4, 0x8), @intFromEnum(Opcode.close)); + try testing.expectEqual(@as(u4, 0x9), @intFromEnum(Opcode.ping)); + try testing.expectEqual(@as(u4, 0xA), @intFromEnum(Opcode.pong)); +} + +test "CloseCode: standard codes have correct values" { + try testing.expectEqual(@as(u16, 1000), @intFromEnum(CloseCode.normal)); + try testing.expectEqual(@as(u16, 1001), @intFromEnum(CloseCode.going_away)); + try testing.expectEqual(@as(u16, 1002), @intFromEnum(CloseCode.protocol_error)); + try testing.expectEqual(@as(u16, 1003), @intFromEnum(CloseCode.unsupported_data)); + try testing.expectEqual(@as(u16, 1005), @intFromEnum(CloseCode.no_status)); + try testing.expectEqual(@as(u16, 1006), @intFromEnum(CloseCode.abnormal)); + try testing.expectEqual(@as(u16, 1007), @intFromEnum(CloseCode.invalid_payload)); + try testing.expectEqual(@as(u16, 1008), @intFromEnum(CloseCode.policy_violation)); + try testing.expectEqual(@as(u16, 1009), @intFromEnum(CloseCode.message_too_big)); + try testing.expectEqual(@as(u16, 1010), @intFromEnum(CloseCode.mandatory_extension)); + try testing.expectEqual(@as(u16, 1011), @intFromEnum(CloseCode.internal_error)); +} + +test "CloseCode: isValidForWire rejects API-only codes" { + try testing.expect(!CloseCode.no_status.isValidForWire()); + try testing.expect(!CloseCode.abnormal.isValidForWire()); + try testing.expect(CloseCode.normal.isValidForWire()); + try testing.expect(CloseCode.protocol_error.isValidForWire()); +} + +test "CloseCode: isValidCode validates raw codes" { + // Valid standard codes + try testing.expect(CloseCode.isValidCode(1000)); + try testing.expect(CloseCode.isValidCode(1001)); + try testing.expect(CloseCode.isValidCode(1002)); + try 
testing.expect(CloseCode.isValidCode(1003)); + try testing.expect(CloseCode.isValidCode(1007)); + try testing.expect(CloseCode.isValidCode(1011)); + + // Invalid: reserved or out of range + try testing.expect(!CloseCode.isValidCode(0)); + try testing.expect(!CloseCode.isValidCode(999)); + try testing.expect(!CloseCode.isValidCode(1004)); + try testing.expect(!CloseCode.isValidCode(1005)); + try testing.expect(!CloseCode.isValidCode(1006)); + try testing.expect(!CloseCode.isValidCode(1012)); + try testing.expect(!CloseCode.isValidCode(2999)); + try testing.expect(!CloseCode.isValidCode(5000)); + + // Valid registered and private ranges + try testing.expect(CloseCode.isValidCode(3000)); + try testing.expect(CloseCode.isValidCode(3999)); + try testing.expect(CloseCode.isValidCode(4000)); + try testing.expect(CloseCode.isValidCode(4999)); +} + +test "CloseCode: payloadBytes encodes as big-endian" { + try testing.expectEqual([2]u8{ 0x03, 0xE8 }, CloseCode.normal.payloadBytes()); // 1000 + try testing.expectEqual([2]u8{ 0x03, 0xEA }, CloseCode.protocol_error.payloadBytes()); // 1002 + try testing.expectEqual([2]u8{ 0x03, 0xF1 }, CloseCode.message_too_big.payloadBytes()); // 1009 +}
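+
+// Illustrative test (a sketch for clarity): exercises each outcome class of
+// `validateClosePayload` documented above. Close codes are hand-encoded as
+// big-endian bytes — 0x03E8 is 1000 (normal), 0x03ED is 1005 (no_status).
+test "validateClosePayload: echoes valid payloads, re-codes invalid ones" {
+    // Empty payload: valid (a close frame may omit the status code).
+    try testing.expect(validateClosePayload("") == .valid_payload);
+
+    // One byte: a status code needs two bytes, so this is a protocol error.
+    try testing.expectEqual(
+        CloseCode.protocol_error,
+        validateClosePayload("\x01").close_code,
+    );
+
+    // Code 1000 with a UTF-8 reason: valid, echoed back as-is.
+    try testing.expect(validateClosePayload("\x03\xE8bye") == .valid_payload);
+
+    // Code 1000 with an invalid UTF-8 reason: answered with invalid_payload.
+    try testing.expectEqual(
+        CloseCode.invalid_payload,
+        validateClosePayload("\x03\xE8\xFF").close_code,
+    );
+
+    // Reserved code 1005 on the wire: protocol error.
+    try testing.expectEqual(
+        CloseCode.protocol_error,
+        validateClosePayload("\x03\xED").close_code,
+    );
+}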