Skip to content

Commit 83cd274

Browse files
committed
refactor(webzockets): pass mask keys by value in mask api
1 parent 744fc9a commit 83cd274

File tree

4 files changed

+21
-30
lines changed

4 files changed

+21
-30
lines changed

src/rpc/webzockets/src/client/connection.zig

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,7 @@ pub fn ClientConnection(
488488

489489
// Mask payload in-place in the caller's buffer
490490
if (payload.len > 0) {
491-
mask_mod.mask(&mask_key, payload);
491+
mask_mod.mask(mask_key, payload);
492492
}
493493
self.write.payload = payload;
494494

@@ -546,7 +546,7 @@ pub fn ClientConnection(
546546
if (payload.len > 0) {
547547
@memcpy(self.write.control_buf[header_len..][0..payload.len], payload);
548548
// Mask control payload in-place
549-
mask_mod.mask(&mask_key, self.write.control_buf[header_len..][0..payload.len]);
549+
mask_mod.mask(mask_key, self.write.control_buf[header_len..][0..payload.len]);
550550
}
551551
self.write.control_len = header_len + payload.len;
552552
self.write.in_flight = .{ .control = .{ .offset = 0 } };

src/rpc/webzockets/src/frame.zig

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ pub const FrameHeader = struct {
4040
/// No-op if the frame is not masked.
4141
pub fn unmaskPayload(self: FrameHeader, payload: []u8) void {
4242
if (self.masked) {
43-
mask_mod.mask(&self.mask_key, payload);
43+
mask_mod.mask(self.mask_key, payload);
4444
}
4545
}
4646

src/rpc/webzockets/src/mask.zig

Lines changed: 17 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,42 @@
11
const std = @import("std");
22

3-
/// Whether the current backend supports SIMD vector operations.
4-
const backend_supports_vectors = switch (@import("builtin").zig_backend) {
5-
.stage2_llvm, .stage2_c => true,
6-
else => false,
7-
};
8-
93
/// Applies XOR masking/unmasking to payload in-place using the 4-byte mask key.
104
/// The same function works for both masking and unmasking (XOR is its own inverse).
11-
/// Uses SIMD acceleration when available, with scalar fallback.
12-
pub fn mask(m: []const u8, payload: []u8) void {
13-
std.debug.assert(m.len == 4);
5+
pub fn mask(mask_key: [4]u8, payload: []u8) void {
146
var data = payload;
157

16-
if (!comptime backend_supports_vectors) return scalarMask(m, data);
17-
188
const vector_size = comptime std.simd.suggestVectorLength(u8) orelse @sizeOf(usize);
199
if (data.len >= vector_size) {
20-
const mask_vector: @Vector(vector_size, u8) = std.simd.repeat(vector_size, m[0..4].*);
10+
const mask_vector: @Vector(vector_size, u8) = std.simd.repeat(vector_size, mask_key);
2111
while (data.len >= vector_size) {
2212
const slice = data[0..vector_size];
2313
const masked: @Vector(vector_size, u8) = slice.*;
2414
slice.* = masked ^ mask_vector;
2515
data = data[vector_size..];
2616
}
2717
}
28-
scalarMask(m, data);
18+
19+
scalarMask(mask_key, data);
2920
}
3021

3122
/// Scalar byte-by-byte XOR mask fallback for remaining bytes or unsupported backends.
32-
fn scalarMask(m: []const u8, payload: []u8) void {
23+
fn scalarMask(mask_key: [4]u8, payload: []u8) void {
3324
@setRuntimeSafety(false);
3425
for (payload, 0..) |b, i| {
35-
payload[i] = b ^ m[i & 3];
26+
payload[i] = b ^ mask_key[i & 3];
3627
}
3728
}
3829

3930
const testing = std.testing;
4031

4132
test "mask: empty payload is no-op" {
4233
var buf = [_]u8{};
43-
mask(&.{ 0xAA, 0xBB, 0xCC, 0xDD }, &buf);
34+
mask(.{ 0xAA, 0xBB, 0xCC, 0xDD }, &buf);
4435
}
4536

4637
test "mask: single byte" {
4738
var buf = [_]u8{0x42};
48-
mask(&.{ 0xFF, 0x00, 0x00, 0x00 }, &buf);
39+
mask(.{ 0xFF, 0x00, 0x00, 0x00 }, &buf);
4940
try testing.expectEqual(@as(u8, 0x42 ^ 0xFF), buf[0]);
5041
}
5142

@@ -57,18 +48,18 @@ test "mask: round-trip produces original" {
5748
const key = [_]u8{ 0x37, 0xFA, 0x21, 0x3D };
5849

5950
// Mask
60-
mask(&key, &buf);
51+
mask(key, &buf);
6152
// Should differ from original
6253
try testing.expect(!std.mem.eql(u8, &buf, original));
6354
// Unmask (same operation)
64-
mask(&key, &buf);
55+
mask(key, &buf);
6556
try testing.expectEqualSlices(u8, original, &buf);
6657
}
6758

6859
test "mask: exact 4-byte alignment" {
6960
var buf = [_]u8{ 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08 };
7061
const key = [_]u8{ 0xFF, 0xFF, 0xFF, 0xFF };
71-
mask(&key, &buf);
62+
mask(key, &buf);
7263
const expected = [_]u8{ 0xFE, 0xFD, 0xFC, 0xFB, 0xFA, 0xF9, 0xF8, 0xF7 };
7364
try testing.expectEqualSlices(u8, &expected, &buf);
7465
}
@@ -94,22 +85,22 @@ test "mask: SIMD boundary sizes" {
9485
@memcpy(expected, buf);
9586

9687
// Mask then unmask should round-trip
97-
mask(&key, buf);
98-
mask(&key, buf);
88+
mask(key, buf);
89+
mask(key, buf);
9990
try testing.expectEqualSlices(u8, expected, buf);
10091
}
10192
}
10293

10394
test "mask: all-zero mask is identity" {
10495
const original = [_]u8{ 0xDE, 0xAD, 0xBE, 0xEF, 0xCA, 0xFE };
10596
var buf = original;
106-
mask(&.{ 0x00, 0x00, 0x00, 0x00 }, &buf);
97+
mask(.{ 0x00, 0x00, 0x00, 0x00 }, &buf);
10798
try testing.expectEqualSlices(u8, &original, &buf);
10899
}
109100

110101
test "mask: all-ones mask inverts all bits" {
111102
var buf = [_]u8{ 0x00, 0xFF, 0xAA, 0x55 };
112-
mask(&.{ 0xFF, 0xFF, 0xFF, 0xFF }, &buf);
103+
mask(.{ 0xFF, 0xFF, 0xFF, 0xFF }, &buf);
113104
try testing.expectEqualSlices(u8, &[_]u8{ 0xFF, 0x00, 0x55, 0xAA }, &buf);
114105
}
115106

@@ -120,15 +111,15 @@ test "mask: RFC 6455 example verification" {
120111
// masked = { 0x48^0x37, 0x65^0xfa, 0x6c^0x21, 0x6c^0x3d, 0x6f^0x37 }
121112
// = { 0x7f, 0x9f, 0x4d, 0x51, 0x58 }
122113
var buf = [_]u8{ 0x48, 0x65, 0x6c, 0x6c, 0x6f };
123-
mask(&.{ 0x37, 0xfa, 0x21, 0x3d }, &buf);
114+
mask(.{ 0x37, 0xfa, 0x21, 0x3d }, &buf);
124115
try testing.expectEqualSlices(u8, &[_]u8{ 0x7f, 0x9f, 0x4d, 0x51, 0x58 }, &buf);
125116
}
126117

127118
test "mask: fuzz round-trip" {
128119
const Context = struct {
129120
fn testOne(_: @This(), input: []const u8) anyerror!void {
130121
if (input.len < 4) return;
131-
const key = input[0..4];
122+
const key: [4]u8 = input[0..4].*;
132123
const buf = try testing.allocator.alloc(u8, input.len - 4);
133124
defer testing.allocator.free(buf);
134125
@memcpy(buf, input[4..]);

src/rpc/webzockets/src/reader.zig

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,7 @@ fn buildMaskedFrame(
512512
header_len += 4;
513513

514514
@memcpy(out[header_len..][0..payload.len], payload);
515-
mask_mod.mask(&mask_key, out[header_len..][0..payload.len]);
515+
mask_mod.mask(mask_key, out[header_len..][0..payload.len]);
516516

517517
return out[0 .. header_len + payload.len];
518518
}

0 commit comments

Comments
 (0)