@@ -51,6 +51,7 @@ pub const Client = struct {
5151 root_ca : tls.config.CertBundle ,
5252 tls_verify_host : bool = true ,
5353 connection_manager : ConnectionManager ,
54+ request_pool : std .heap .MemoryPool (Request ),
5455
5556 const Opts = struct {
5657 tls_verify_host : bool = true ,
@@ -76,6 +77,7 @@ pub const Client = struct {
7677 .http_proxy = opts .http_proxy ,
7778 .tls_verify_host = opts .tls_verify_host ,
7879 .connection_manager = connection_manager ,
80+ .request_pool = std .heap .MemoryPool (Request ).init (allocator ),
7981 };
8082 }
8183
@@ -86,17 +88,29 @@ pub const Client = struct {
8688 }
8789 self .state_pool .deinit (allocator );
8890 self .connection_manager .deinit ();
91+ self .request_pool .deinit ();
8992 }
9093
91- pub fn request (self : * Client , method : Request.Method , uri : * const Uri ) ! Request {
94+ pub fn request (self : * Client , method : Request.Method , uri : * const Uri ) ! * Request {
9295 const state = self .state_pool .acquire ();
9396
9497 errdefer {
9598 state .reset ();
9699 self .state_pool .release (state );
97100 }
98101
99- return Request .init (self , state , method , uri );
102+ // We need the request on the heap, because it can have a longer lifetime
103+ // than the code making the request. That sounds odd, but consider the
104+ // case of an XHR request: it can still be inflight (e.g. waiting for
105+ // the response) when the page gets unloaded. Once the page is unloaded
106+ // the page arena is reset and the XHR instance becomes invalid. If the
107+ // XHR instance owns the `Request`, we'd crash once an async callback
108+ // executes.
109+ const req = try self .request_pool .create ();
110+ errdefer self .request_pool .destroy (req );
111+
112+ req .* = try Request .init (self , state , method , uri );
113+ return req ;
100114 }
101115
102116 pub fn requestFactory (self : * Client , notification : ? * Notification ) RequestFactory {
@@ -112,7 +126,7 @@ pub const RequestFactory = struct {
112126 client : * Client ,
113127 notification : ? * Notification ,
114128
115- pub fn create (self : RequestFactory , method : Request.Method , uri : * const Uri ) ! Request {
129+ pub fn create (self : RequestFactory , method : Request.Method , uri : * const Uri ) ! * Request {
116130 var req = try self .client .request (method , uri );
117131 req .notification = self .notification ;
118132 return req ;
@@ -244,6 +258,17 @@ pub const Request = struct {
244258 // The notifier that we emit request notifications to, if any.
245259 notification : ? * Notification ,
246260
261+ // Aborting an async request is complicated, as we need to wait until all
262+ // in-flight IO events are completed. Our AsyncHandler is a generic type
263+ // that we don't have the necessary type information for in the Request,
264+ // so we need to rely on anyopaque.
265+ _aborter : ? Aborter ,
266+
267+ const Aborter = struct {
268+ ctx : * anyopaque ,
269+ func : * const fn (* anyopaque ) void ,
270+ };
271+
247272 pub const Method = enum {
248273 GET ,
249274 PUT ,
@@ -282,6 +307,7 @@ pub const Request = struct {
282307 ._request_host = decomposed .request_host ,
283308 ._state = state ,
284309 ._client = client ,
310+ ._aborter = null ,
285311 ._connection = null ,
286312 ._keepalive = false ,
287313 ._redirect_count = 0 ,
@@ -297,6 +323,15 @@ pub const Request = struct {
297323 self .releaseConnection ();
298324 _ = self ._state .reset ();
299325 self ._client .state_pool .release (self ._state );
326+ self ._client .request_pool .destroy (self );
327+ }
328+
329+ pub fn abort (self : * Request ) void {
330+ const aborter = self ._aborter orelse {
331+ self .deinit ();
332+ return ;
333+ };
334+ aborter .func (aborter .ctx );
300335 }
301336
302337 const DecomposedURL = struct {
@@ -544,6 +579,11 @@ pub const Request = struct {
544579 return async_handler .conn .connected ();
545580 }
546581
582+ self ._aborter = .{
583+ .ctx = async_handler ,
584+ .func = AsyncHandlerT .abort ,
585+ };
586+
547587 return loop .connect (
548588 AsyncHandlerT ,
549589 async_handler ,
@@ -732,13 +772,6 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
732772 // that we have valid, but unprocessed, data up to.
733773 read_pos : usize = 0 ,
734774
735- // Depending on which version of TLS, there are different places during
736- // the handshake that we want to start receiving from. We can't have
737- // overlapping receives (works fine on MacOS (kqueue) but not Linux (
738- // io_uring)). Using this boolean as a guard, to make sure we only have
739- // 1 in-flight receive is easier than trying to understand TLS.
740- is_receiving : bool = false ,
741-
742775 // need a separate read and write buf because, with TLS, messages are
743776 // not strictly req->resp.
744777 write_buf : []u8 ,
@@ -775,6 +808,13 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
775808 // gzipped responses *cough*)
776809 full_body : ? std .ArrayListUnmanaged (u8 ) = null ,
777810
811+ // Shutting down an async request requires that we wait for all inflight
812+ // IO to be completed. So we need to track what inflight requests we
813+ // have and whether or not we're shutting down
814+ shutdown : bool = false ,
815+ pending_write : bool = false ,
816+ pending_receive : bool = false ,
817+
778818 const Self = @This ();
779819 const SendQueue = std .DoublyLinkedList ([]const u8 );
780820
@@ -794,6 +834,12 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
794834 self .request .deinit ();
795835 }
796836
837+ fn abort (ctx : * anyopaque ) void {
838+ var self : * Self = @alignCast (@ptrCast (ctx ));
839+ self .shutdown = true ;
840+ self .maybeShutdown ();
841+ }
842+
797843 fn connected (self : * Self , _ : * IO.Completion , result : IO .ConnectError ! void ) void {
798844 result catch | err | return self .handleError ("Connection failed" , err );
799845 self .conn .connected () catch | err | {
@@ -815,6 +861,7 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
815861 return ;
816862 }
817863
864+ self .pending_write = true ;
818865 self .loop .send (
819866 Self ,
820867 self ,
@@ -828,6 +875,10 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
828875 }
829876
830877 fn sent (self : * Self , _ : * IO.Completion , n_ : IO .SendError ! usize ) void {
878+ self .pending_write = false ;
879+ if (self .shutdown ) {
880+ return self .maybeShutdown ();
881+ }
831882 const n = n_ catch | err | {
832883 return self .handleError ("Write error" , err );
833884 };
@@ -845,6 +896,7 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
845896 }
846897
847898 if (next ) | next_ | {
899+ self .pending_write = true ;
848900 // we still have data to send
849901 self .loop .send (
850902 Self ,
@@ -869,11 +921,11 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
869921 // while handshaking and potentially while sending data. So we're always
870922 // receiving.
871923 fn receive (self : * Self ) void {
872- if (self .is_receiving ) {
924+ if (self .pending_receive ) {
873925 return ;
874926 }
875927
876- self .is_receiving = true ;
928+ self .pending_receive = true ;
877929 self .loop .recv (
878930 Self ,
879931 self ,
@@ -887,7 +939,11 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
887939 }
888940
889941 fn received (self : * Self , _ : * IO.Completion , n_ : IO .RecvError ! usize ) void {
890- self .is_receiving = false ;
942+ self .pending_receive = false ;
943+ if (self .shutdown ) {
944+ return self .maybeShutdown ();
945+ }
946+
891947 const n = n_ catch | err | {
892948 return self .handleError ("Read error" , err );
893949 };
@@ -926,6 +982,17 @@ fn AsyncHandler(comptime H: type, comptime L: type) type {
926982 }
927983 }
928984
985+ fn maybeShutdown (self : * Self ) void {
986+ std .debug .assert (self .shutdown );
987+ if (self .pending_write or self .pending_receive ) {
988+ return ;
989+ }
990+
991+ // Who knows what state we're in, safer to not try to re-use the connection
992+ self .request ._keepalive = false ;
993+ self .request .deinit ();
994+ }
995+
929996 // If our socket came from the connection pool, it's possible that we're
930997 // failing because it's since timed out. If
931998 fn maybeRetryRequest (self : * Self ) bool {
0 commit comments