Skip to content

Commit 756d662

Browse files
committed
Ensure completions are executed on the currently connected client
For the time being, given that we only allow 1 client at a time, I took a shortcut to implement this. The server has an incrementing "current_client_id" which is part of every completion. On completion callback, we just check if its client_id is still equal to the server's current_client_id.
1 parent 09505db commit 756d662

File tree

1 file changed

+87
-42
lines changed

1 file changed

+87
-42
lines changed

src/server.zig

Lines changed: 87 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ pub const Client = ClientT(*Server, CDP);
5252
const Server = struct {
5353
allocator: Allocator,
5454
loop: *jsruntime.Loop,
55+
current_client_id: usize = 0,
5556

5657
// internal fields
5758
listener: posix.socket_t,
@@ -64,10 +65,9 @@ const Server = struct {
6465
// a memory poor for our Clietns
6566
client_pool: std.heap.MemoryPool(Client),
6667

67-
timeout_completion_pool: std.heap.MemoryPool(Completion),
68+
completion_state_pool: std.heap.MemoryPool(CompletionState),
6869

6970
// I/O fields
70-
conn_completion: Completion,
7171
close_completion: Completion,
7272
accept_completion: Completion,
7373

@@ -77,7 +77,7 @@ const Server = struct {
7777
fn deinit(self: *Server) void {
7878
self.send_pool.deinit();
7979
self.client_pool.deinit();
80-
self.timeout_completion_pool.deinit();
80+
self.completion_state_pool.deinit();
8181
self.allocator.free(self.json_version_response);
8282
}
8383

@@ -99,40 +99,37 @@ const Server = struct {
9999
) void {
100100
std.debug.assert(self.client == null);
101101
std.debug.assert(completion == &self.accept_completion);
102-
103-
const socket = result catch |err| {
102+
self.doCallbackAccept(result) catch |err| {
104103
log.err("accept error: {any}", .{err});
105104
self.queueAccept();
106-
return;
107105
};
106+
}
108107

109-
const client = self.client_pool.create() catch |err| {
110-
log.err("failed to create client: {any}", .{err});
111-
posix.close(socket);
112-
return;
113-
};
108+
fn doCallbackAccept(
109+
self: *Server,
110+
result: AcceptError!posix.socket_t,
111+
) !void {
112+
const socket = try result;
113+
const client = try self.client_pool.create();
114114
errdefer self.client_pool.destroy(client);
115115

116+
self.current_client_id += 1;
116117
client.* = Client.init(socket, self);
117118

118119
self.client = client;
119120

120121
log.info("client connected", .{});
121-
self.queueRead();
122-
self.queueTimeout();
122+
try self.queueRead();
123+
try self.queueTimeout();
123124
}
124125

125-
fn queueTimeout(self: *Server) void {
126-
const completion = self.timeout_completion_pool.create() catch |err| {
127-
log.err("failed to create timeout completion: {any}", .{err});
128-
return;
129-
};
130-
126+
fn queueTimeout(self: *Server) !void {
127+
const cs = try self.createCompletionState();
131128
self.loop.io.timeout(
132129
*Server,
133130
self,
134131
callbackTimeout,
135-
completion,
132+
&cs.completion,
136133
TimeoutCheck,
137134
);
138135
}
@@ -142,7 +139,16 @@ const Server = struct {
142139
completion: *Completion,
143140
result: TimeoutError!void,
144141
) void {
145-
self.timeout_completion_pool.destroy(completion);
142+
const cs: *CompletionState = @alignCast(
143+
@fieldParentPtr("completion", completion),
144+
);
145+
defer self.completion_state_pool.destroy(cs);
146+
147+
if (cs.client_id != self.current_client_id) {
148+
// completion for a previously-connected client
149+
return;
150+
}
151+
146152
const client = self.client orelse return;
147153

148154
if (result) |_| {
@@ -160,28 +166,39 @@ const Server = struct {
160166
// very unlikely IO timeout error.
161167
// AKA: we don't requeue this if the connection timed out and we
162168
// closed the connection.s
163-
self.queueTimeout();
169+
self.queueTimeout() catch |err| {
170+
log.err("queueTimeout error: {any}", .{err});
171+
};
164172
}
165173

166-
fn queueRead(self: *Server) void {
167-
if (self.client) |client| {
168-
self.loop.io.recv(
169-
*Server,
170-
self,
171-
callbackRead,
172-
&self.conn_completion,
173-
client.socket,
174-
client.readBuf(),
175-
);
176-
}
174+
fn queueRead(self: *Server) !void {
175+
var client = self.client orelse return;
176+
177+
const cs = try self.createCompletionState();
178+
self.loop.io.recv(
179+
*Server,
180+
self,
181+
callbackRead,
182+
&cs.completion,
183+
client.socket,
184+
client.readBuf(),
185+
);
177186
}
178187

179188
fn callbackRead(
180189
self: *Server,
181190
completion: *Completion,
182191
result: RecvError!usize,
183192
) void {
184-
std.debug.assert(completion == &self.conn_completion);
193+
const cs: *CompletionState = @alignCast(
194+
@fieldParentPtr("completion", completion),
195+
);
196+
defer self.completion_state_pool.destroy(cs);
197+
198+
if (cs.client_id != self.current_client_id) {
199+
// completion for a previously-connected client
200+
return;
201+
}
185202

186203
var client = self.client orelse return;
187204

@@ -205,7 +222,10 @@ const Server = struct {
205222

206223
// if more == false, the client is disconnecting
207224
if (more) {
208-
self.queueRead();
225+
self.queueRead() catch |err| {
226+
log.err("queueRead error: {any}", .{err});
227+
client.close(null);
228+
};
209229
}
210230
}
211231

@@ -218,12 +238,15 @@ const Server = struct {
218238
const sd = try self.send_pool.create();
219239
errdefer self.send_pool.destroy(sd);
220240

241+
const cs = try self.createCompletionState();
242+
errdefer self.completion_state_pool.destroy(cs);
243+
221244
sd.* = .{
222245
.unsent = data,
223246
.server = self,
224247
.socket = socket,
225-
.completion = undefined,
226248
.arena = arena,
249+
.completion_state = cs,
227250
};
228251
sd.queueSend();
229252
}
@@ -246,6 +269,18 @@ const Server = struct {
246269
std.debug.assert(completion == &self.close_completion);
247270
self.queueAccept();
248271
}
272+
273+
fn createCompletionState(self: *Server) !*CompletionState {
274+
var cs = try self.completion_state_pool.create();
275+
cs.client_id = self.current_client_id;
276+
cs.completion = undefined;
277+
return cs;
278+
}
279+
};
280+
281+
const CompletionState = struct {
282+
client_id: usize,
283+
completion: Completion,
249284
};
250285

251286
// I/O Send
@@ -259,17 +294,19 @@ const Send = struct { // Any unsent data we have.
259294
unsent: []const u8,
260295

261296
server: *Server,
262-
completion: Completion,
263297
socket: posix.socket_t,
298+
completion_state: *CompletionState,
264299

265300
// If we need to free anything when we're done
266301
arena: ?ArenaAllocator,
267302

268303
fn deinit(self: *Send) void {
269-
var server = self.server;
270304
if (self.arena) |arena| {
271305
arena.deinit();
272306
}
307+
308+
var server = self.server;
309+
server.completion_state_pool.destroy(self.completion_state);
273310
server.send_pool.destroy(self);
274311
}
275312

@@ -278,16 +315,25 @@ const Send = struct { // Any unsent data we have.
278315
*Send,
279316
self,
280317
sendCallback,
281-
&self.completion,
318+
&self.completion_state.completion,
282319
self.socket,
283320
self.unsent,
284321
);
285322
}
286323

287324
fn sendCallback(self: *Send, _: *Completion, result: SendError!usize) void {
325+
const server = self.server;
326+
const cs = self.completion_state;
327+
328+
if (cs.client_id != server.current_client_id) {
329+
// completion for a previously-connected client
330+
self.deinit();
331+
return;
332+
}
333+
288334
const sent = result catch |err| {
289335
log.info("send error: {any}", .{err});
290-
if (self.server.client) |client| {
336+
if (server.client) |client| {
291337
client.close(null);
292338
}
293339
self.deinit();
@@ -1011,13 +1057,12 @@ pub fn run(
10111057
.timeout = timeout,
10121058
.listener = listener,
10131059
.allocator = allocator,
1014-
.conn_completion = undefined,
10151060
.close_completion = undefined,
10161061
.accept_completion = undefined,
10171062
.json_version_response = json_version_response,
10181063
.send_pool = std.heap.MemoryPool(Send).init(allocator),
10191064
.client_pool = std.heap.MemoryPool(Client).init(allocator),
1020-
.timeout_completion_pool = std.heap.MemoryPool(Completion).init(allocator),
1065+
.completion_state_pool = std.heap.MemoryPool(CompletionState).init(allocator),
10211066
};
10221067
defer server.deinit();
10231068

0 commit comments

Comments
 (0)