Skip to content

Commit 527579a

Browse files
authored
Merge pull request #720 from lightpanda-io/clean_xhr_shutdown
Clean Http Request Shutdown
2 parents 1869ef0 + e339ee3 commit 527579a

File tree

3 files changed

+98
-16
lines changed

3 files changed

+98
-16
lines changed

src/browser/page.zig

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,7 @@ pub const Page = struct {
491491
return arr.items;
492492
}
493493

494-
fn newHTTPRequest(self: *const Page, method: http.Request.Method, url: *const URL, opts: storage.cookie.LookupOpts) !http.Request {
494+
fn newHTTPRequest(self: *const Page, method: http.Request.Method, url: *const URL, opts: storage.cookie.LookupOpts) !*http.Request {
495495
// Don't use the state's request_factory here, since requests made by the
496496
// page (i.e. to load <scripts>) should not generate notifications.
497497
var request = try self.session.browser.http_client.request(method, &url.uri);

src/browser/xhr/xhr.zig

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ const XMLHttpRequestBodyInit = union(enum) {
7979
pub const XMLHttpRequest = struct {
8080
proto: XMLHttpRequestEventTarget = XMLHttpRequestEventTarget{},
8181
arena: Allocator,
82-
request: ?http.Request = null,
82+
request: ?*http.Request = null,
8383

8484
method: http.Request.Method,
8585
state: State,
@@ -252,6 +252,13 @@ pub const XMLHttpRequest = struct {
252252
};
253253
}
254254

255+
pub fn destructor(self: *XMLHttpRequest) void {
256+
if (self.request) |req| {
257+
req.abort();
258+
self.request = null;
259+
}
260+
}
261+
255262
pub fn reset(self: *XMLHttpRequest) void {
256263
self.url = null;
257264

@@ -417,7 +424,7 @@ pub const XMLHttpRequest = struct {
417424
self.send_flag = true;
418425

419426
self.request = try page.request_factory.create(self.method, &self.url.?.uri);
420-
var request = &self.request.?;
427+
var request = self.request.?;
421428
errdefer request.deinit();
422429

423430
for (self.headers.list.items) |hdr| {
@@ -452,6 +459,9 @@ pub const XMLHttpRequest = struct {
452459

453460
pub fn onHttpResponse(self: *XMLHttpRequest, progress_: anyerror!http.Progress) !void {
454461
const progress = progress_ catch |err| {
462+
// The request has been closed internally by the client, it isn't safe
463+
// for us to keep it around.
464+
self.request = null;
455465
self.onErr(err);
456466
return err;
457467
};
@@ -510,6 +520,10 @@ pub const XMLHttpRequest = struct {
510520
.status = progress.header.status,
511521
});
512522

523+
// Not that the request is done, the http/client will free the request
524+
// object. It isn't safe to keep it around.
525+
self.request = null;
526+
513527
self.state = .done;
514528
self.send_flag = false;
515529
self.dispatchEvt("readystatechange");
@@ -532,6 +546,7 @@ pub const XMLHttpRequest = struct {
532546

533547
pub fn _abort(self: *XMLHttpRequest) void {
534548
self.onErr(DOMError.Abort);
549+
self.destructor();
535550
}
536551

537552
pub fn get_responseType(self: *XMLHttpRequest) []const u8 {

src/http/client.zig

Lines changed: 80 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ pub const Client = struct {
5151
root_ca: tls.config.CertBundle,
5252
tls_verify_host: bool = true,
5353
connection_manager: ConnectionManager,
54+
request_pool: std.heap.MemoryPool(Request),
5455

5556
const Opts = struct {
5657
tls_verify_host: bool = true,
@@ -76,6 +77,7 @@ pub const Client = struct {
7677
.http_proxy = opts.http_proxy,
7778
.tls_verify_host = opts.tls_verify_host,
7879
.connection_manager = connection_manager,
80+
.request_pool = std.heap.MemoryPool(Request).init(allocator),
7981
};
8082
}
8183

@@ -86,17 +88,29 @@ pub const Client = struct {
8688
}
8789
self.state_pool.deinit(allocator);
8890
self.connection_manager.deinit();
91+
self.request_pool.deinit();
8992
}
9093

91-
pub fn request(self: *Client, method: Request.Method, uri: *const Uri) !Request {
94+
pub fn request(self: *Client, method: Request.Method, uri: *const Uri) !*Request {
9295
const state = self.state_pool.acquire();
9396

9497
errdefer {
9598
state.reset();
9699
self.state_pool.release(state);
97100
}
98101

99-
return Request.init(self, state, method, uri);
102+
// We need the request on the heap, because it can have a longer lifetime
103+
// than the code making the request. That sounds odd, but consider the
104+
// case of an XHR request: it can still be inflight (e.g. waiting for
105+
// the response) when the page gets unloaded. Once the page is unloaded
106+
// the page arena is reset and the XHR instance becomes invalid. If the
107+
// XHR instance owns the `Request`, we'd crash once an async callback
108+
// executes.
109+
const req = try self.request_pool.create();
110+
errdefer self.request_pool.destroy(req);
111+
112+
req.* = try Request.init(self, state, method, uri);
113+
return req;
100114
}
101115

102116
pub fn requestFactory(self: *Client, notification: ?*Notification) RequestFactory {
@@ -112,7 +126,7 @@ pub const RequestFactory = struct {
112126
client: *Client,
113127
notification: ?*Notification,
114128

115-
pub fn create(self: RequestFactory, method: Request.Method, uri: *const Uri) !Request {
129+
pub fn create(self: RequestFactory, method: Request.Method, uri: *const Uri) !*Request {
116130
var req = try self.client.request(method, uri);
117131
req.notification = self.notification;
118132
return req;
@@ -244,6 +258,17 @@ pub const Request = struct {
244258
// The notifier that we emit request notifications to, if any.
245259
notification: ?*Notification,
246260

261+
// Aborting an async request is complicated, as we need to wait until all
262+
// in-flight IO events are completed. Our AsyncHandler is a generic type
263+
// that we don't have the necessary type information for in the Request,
264+
// so we need to rely on anyopaque.
265+
_aborter: ?Aborter,
266+
267+
const Aborter = struct {
268+
ctx: *anyopaque,
269+
func: *const fn (*anyopaque) void,
270+
};
271+
247272
pub const Method = enum {
248273
GET,
249274
PUT,
@@ -282,6 +307,7 @@ pub const Request = struct {
282307
._request_host = decomposed.request_host,
283308
._state = state,
284309
._client = client,
310+
._aborter = null,
285311
._connection = null,
286312
._keepalive = false,
287313
._redirect_count = 0,
@@ -297,6 +323,15 @@ pub const Request = struct {
297323
self.releaseConnection();
298324
_ = self._state.reset();
299325
self._client.state_pool.release(self._state);
326+
self._client.request_pool.destroy(self);
327+
}
328+
329+
pub fn abort(self: *Request) void {
330+
const aborter = self._aborter orelse {
331+
self.deinit();
332+
return;
333+
};
334+
aborter.func(aborter.ctx);
300335
}
301336

302337
const DecomposedURL = struct {
@@ -544,6 +579,11 @@ pub const Request = struct {
544579
return async_handler.conn.connected();
545580
}
546581

582+
self._aborter = .{
583+
.ctx = async_handler,
584+
.func = AsyncHandlerT.abort,
585+
};
586+
547587
return loop.connect(
548588
AsyncHandlerT,
549589
async_handler,
@@ -732,13 +772,6 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
732772
// that we have valid, but unprocessed, data up to.
733773
read_pos: usize = 0,
734774

735-
// Depending on which version of TLS, there are different places during
736-
// the handshake that we want to start receiving from. We can't have
737-
// overlapping receives (works fine on MacOS (kqueue) but not Linux (
738-
// io_uring)). Using this boolean as a guard, to make sure we only have
739-
// 1 in-flight receive is easier than trying to understand TLS.
740-
is_receiving: bool = false,
741-
742775
// need a separate read and write buf because, with TLS, messages are
743776
// not strictly req->resp.
744777
write_buf: []u8,
@@ -775,6 +808,13 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
775808
// gzipped responses *cough*)
776809
full_body: ?std.ArrayListUnmanaged(u8) = null,
777810

811+
// Shutting down an async request requires that we wait for all inflight
812+
// IO to be completed. So we need to track what inflight requests we
813+
// have and whether or not we're shutting down
814+
shutdown: bool = false,
815+
pending_write: bool = false,
816+
pending_receive: bool = false,
817+
778818
const Self = @This();
779819
const SendQueue = std.DoublyLinkedList([]const u8);
780820

@@ -794,6 +834,12 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
794834
self.request.deinit();
795835
}
796836

837+
fn abort(ctx: *anyopaque) void {
838+
var self: *Self = @alignCast(@ptrCast(ctx));
839+
self.shutdown = true;
840+
self.maybeShutdown();
841+
}
842+
797843
fn connected(self: *Self, _: *IO.Completion, result: IO.ConnectError!void) void {
798844
result catch |err| return self.handleError("Connection failed", err);
799845
self.conn.connected() catch |err| {
@@ -815,6 +861,7 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
815861
return;
816862
}
817863

864+
self.pending_write = true;
818865
self.loop.send(
819866
Self,
820867
self,
@@ -828,6 +875,10 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
828875
}
829876

830877
fn sent(self: *Self, _: *IO.Completion, n_: IO.SendError!usize) void {
878+
self.pending_write = false;
879+
if (self.shutdown) {
880+
return self.maybeShutdown();
881+
}
831882
const n = n_ catch |err| {
832883
return self.handleError("Write error", err);
833884
};
@@ -845,6 +896,7 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
845896
}
846897

847898
if (next) |next_| {
899+
self.pending_write = true;
848900
// we still have data to send
849901
self.loop.send(
850902
Self,
@@ -869,11 +921,11 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
869921
// while handshaking and potentially while sending data. So we're always
870922
// receiving.
871923
fn receive(self: *Self) void {
872-
if (self.is_receiving) {
924+
if (self.pending_receive) {
873925
return;
874926
}
875927

876-
self.is_receiving = true;
928+
self.pending_receive = true;
877929
self.loop.recv(
878930
Self,
879931
self,
@@ -887,7 +939,11 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
887939
}
888940

889941
fn received(self: *Self, _: *IO.Completion, n_: IO.RecvError!usize) void {
890-
self.is_receiving = false;
942+
self.pending_receive = false;
943+
if (self.shutdown) {
944+
return self.maybeShutdown();
945+
}
946+
891947
const n = n_ catch |err| {
892948
return self.handleError("Read error", err);
893949
};
@@ -926,6 +982,17 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
926982
}
927983
}
928984

985+
fn maybeShutdown(self: *Self) void {
986+
std.debug.assert(self.shutdown);
987+
if (self.pending_write or self.pending_receive) {
988+
return;
989+
}
990+
991+
// Who knows what state we're in, safer to not try to re-use the connection
992+
self.request._keepalive = false;
993+
self.request.deinit();
994+
}
995+
929996
// If our socket came from the connection pool, it's possible that we're
930997
// failing because it's since timed out. If
931998
fn maybeRetryRequest(self: *Self) bool {

0 commit comments

Comments
 (0)