Skip to content

Commit 9ed49cf

Browse files
francisbouviersjorsdonkers
authored andcommitted
Handle TLS proxy, both for HTTP and HTTPS (tls in tls) endpoints
1 parent 1254b55 commit 9ed49cf

File tree

1 file changed

+87
-63
lines changed

1 file changed

+87
-63
lines changed

src/http/client.zig

Lines changed: 87 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,11 @@ pub const Client = struct {
239239
const proxy_type = self.proxy_type orelse return false;
240240
return proxy_type == .forward;
241241
}
242+
243+
fn isProxyTLS(self: *const Client) bool {
244+
const proxy = self.http_proxy orelse return false;
245+
return std.mem.eql(u8, proxy.scheme, "https");
246+
}
242247
};
243248

244249
const RequestOpts = struct {
@@ -331,10 +336,10 @@ const Connection = struct {
331336
fn close(self: *TLSClient) void {
332337
switch (self.*) {
333338
.blocking => |*tls_client| tls_client.close() catch {},
334-
.blocking_tls_in_tls => {}, // |*tls_in_tls| {
335-
// tls_in_tls.destination.close() catch {}; // Crashes
336-
// tls_in_tls.proxy.close() catch {};
337-
// },
339+
.blocking_tls_in_tls => |*tls_in_tls| {
340+
tls_in_tls.destination.close() catch {};
341+
tls_in_tls.proxy.close() catch {};
342+
},
338343
.nonblocking => {},
339344
}
340345
}
@@ -535,6 +540,7 @@ pub const Request = struct {
535540
}
536541

537542
const is_connect_proxy = client.isConnectProxy();
543+
const is_proxy_tls = client.isProxyTLS();
538544

539545
var secure: bool = undefined;
540546
const scheme = if (is_connect_proxy) uri.scheme else connect_uri.scheme;
@@ -546,7 +552,11 @@ pub const Request = struct {
546552
return error.UnsupportedUriScheme;
547553
}
548554
const request_port: u16 = uri.port orelse if (secure) 443 else 80;
549-
const connect_port: u16 = connect_uri.port orelse (if (is_connect_proxy) 80 else request_port);
555+
const connect_port: u16 = connect_uri.port orelse blk: {
556+
if (is_connect_proxy) {
557+
if (is_proxy_tls) break :blk 443 else break :blk 80;
558+
} else break :blk request_port;
559+
};
550560

551561
return .{
552562
.secure = secure,
@@ -663,36 +673,58 @@ pub const Request = struct {
663673
};
664674
self._connection = connection;
665675

676+
const tls_config = tls.config.Client{
677+
.host = self._request_host,
678+
.root_ca = self._client.root_ca,
679+
.insecure_skip_verify = self._tls_verify_host == false,
680+
// .key_log_callback = tls.config.key_log.callback,
681+
};
682+
683+
// proxy
666684
const is_connect_proxy = self._client.isConnectProxy();
685+
const is_proxy_tls = self._client.isProxyTLS();
686+
667687
if (is_connect_proxy) {
668-
var connect_connection = try SyncHandler.connect(self);
669-
if (self._secure) { // TODO separate _secure for proxy and desination
670-
const tls_in_tls = try tls.client(&connect_connection, .{
671-
.host = self._request_host,
672-
.root_ca = self._client.root_ca,
673-
.insecure_skip_verify = self._tls_verify_host == false,
674-
// .key_log_callback = tls.config.key_log.callback,
675-
});
676-
self._connection.?.tls = .{
677-
.blocking_tls_in_tls = .{
678-
.proxy = connect_connection,
679-
.destination = tls_in_tls,
680-
},
681-
};
688+
var proxy_conn: SyncHandler.Conn = .{ .plain = self._connection.?.socket };
689+
690+
if (is_proxy_tls) {
691+
692+
// create an underlying TLS stream with the proxy
693+
var proxy_tls_config = tls_config;
694+
proxy_tls_config.host = self._connect_host;
695+
var proxy_conn_tls = try tls.client(std.net.Stream{ .handle = socket }, proxy_tls_config);
696+
proxy_conn = .{ .tls = &proxy_conn_tls };
682697
}
683-
} else {
684-
if (self._secure) {
685-
self._connection.?.tls = .{
686-
.blocking = try tls.client(std.net.Stream{ .handle = socket }, .{
687-
.host = if (is_connect_proxy) self._request_host else self._connect_host,
688-
.root_ca = self._client.root_ca,
689-
.insecure_skip_verify = self._tls_verify_host == false,
690-
// .key_log_callback = tls.config.key_log.callback,
691-
}),
692-
};
698+
699+
// connect to the proxy
700+
try SyncHandler.connect(self, &proxy_conn);
701+
702+
if (is_proxy_tls) {
703+
if (self._secure) {
704+
705+
// if secure endpoint, create the main TLS stream
706+
// encapsulated into the TLS stream proxy
707+
const tls_in_tls = try tls.client(proxy_conn.tls, tls_config);
708+
self._connection.?.tls = .{
709+
.blocking_tls_in_tls = .{
710+
.proxy = proxy_conn.tls.*,
711+
.destination = tls_in_tls,
712+
},
713+
};
714+
} else {
715+
716+
// otherwise, just use the TLS stream proxy
717+
self._connection.?.tls = .{ .blocking = proxy_conn.tls.* };
718+
}
693719
}
694720
}
695721

722+
if (self._secure and !is_proxy_tls) {
723+
self._connection.?.tls = .{
724+
.blocking = try tls.client(std.net.Stream{ .handle = socket }, tls_config),
725+
};
726+
}
727+
696728
self._connection_from_keepalive = false;
697729
}
698730

@@ -1834,18 +1866,9 @@ const SyncHandler = struct {
18341866

18351867
// Unfortunately, this is called from the Request doSendSync since we need
18361868
// to do this before setting up our TLS connection.
1837-
fn connect(request: *Request) !tls.Connection(std.net.Stream) {
1838-
const socket = request._connection.?.socket;
1839-
1869+
fn connect(request: *Request, conn: *Conn) !void {
18401870
const header = try request.buildConnectHeader();
1841-
// try Conn.writeAll(socket, header);
1842-
var tls_client = try tls.client(std.net.Stream{ .handle = socket }, .{
1843-
.host = request._connect_host,
1844-
.root_ca = request._client.root_ca,
1845-
.insecure_skip_verify = request._tls_verify_host == false,
1846-
.key_log_callback = tls.config.key_log.callback,
1847-
});
1848-
try tls_client.writeAll(header);
1871+
try conn.writeAll(header);
18491872

18501873
var pos: usize = 0;
18511874
var reader = request.newReader();
@@ -1856,24 +1879,19 @@ const SyncHandler = struct {
18561879
// we only send CONNECT requests on newly established connections
18571880
// and maybeRetryOrErr is only for connections that might have been
18581881
// closed while being kept-alive
1859-
// const n = try posix.read(socket, read_buf[pos..]);
1860-
// const n = switch (self.*) {
1861-
// .tls => |tls_client| try tls_client.read(buf),
1862-
// .plain => |socket| try posix.read(socket, buf),
1863-
// };
1864-
const n = try tls_client.read(read_buf[pos..]);
1882+
const n = try conn.read(read_buf[pos..]);
18651883
if (n == 0) {
18661884
return error.ConnectionResetByPeer;
18671885
}
18681886
pos += n;
18691887
if (try reader.connectResponse(read_buf[0..pos])) {
18701888
// returns true if we have a successful connect response
1871-
return tls_client;
1889+
return;
18721890
}
18731891

18741892
// we don't have enough data yet.
18751893
}
1876-
return tls_client;
1894+
return;
18771895
}
18781896

18791897
fn maybeRetryOrErr(self: *SyncHandler, err: anyerror) !Response {
@@ -1929,16 +1947,16 @@ const SyncHandler = struct {
19291947

19301948
fn sendRequest(self: *Conn, header: []const u8, body: ?[]const u8) !void {
19311949
switch (self.*) {
1932-
.tls_in_tls => |tls_client| {
1933-
try tls_client.writeAll(header);
1950+
.tls => |_| {
1951+
try self.writeAll(header);
19341952
if (body) |b| {
1935-
try tls_client.writeAll(b);
1953+
try self.writeAll(b);
19361954
}
19371955
},
1938-
.tls => |tls_client| {
1939-
try tls_client.writeAll(header);
1956+
.tls_in_tls => |_| {
1957+
try self.writeAll(header);
19401958
if (body) |b| {
1941-
try tls_client.writeAll(b);
1959+
try self.writeAll(b);
19421960
}
19431961
},
19441962
.plain => |socket| {
@@ -1949,15 +1967,15 @@ const SyncHandler = struct {
19491967
};
19501968
return writeAllIOVec(socket, &vec);
19511969
}
1952-
return writeAll(socket, header);
1970+
return self.writeAll(header);
19531971
},
19541972
}
19551973
}
19561974

19571975
fn read(self: *Conn, buf: []u8) !usize {
19581976
const n = switch (self.*) {
1959-
.tls_in_tls => |tls_client| try tls_client.read(buf),
19601977
.tls => |tls_client| try tls_client.read(buf),
1978+
.tls_in_tls => |tls_client| try tls_client.read(buf),
19611979
.plain => |socket| try posix.read(socket, buf),
19621980
};
19631981
if (n == 0) {
@@ -1966,6 +1984,19 @@ const SyncHandler = struct {
19661984
return n;
19671985
}
19681986

1987+
fn writeAll(self: *Conn, data: []const u8) !void {
1988+
switch (self.*) {
1989+
.tls => |tls_client| try tls_client.writeAll(data),
1990+
.tls_in_tls => |tls_client| try tls_client.writeAll(data),
1991+
.plain => |socket| {
1992+
var i: usize = 0;
1993+
while (i < data.len) {
1994+
i += try posix.write(socket, data[i..]);
1995+
}
1996+
},
1997+
}
1998+
}
1999+
19692000
fn writeAllIOVec(socket: posix.socket_t, vec: []posix.iovec_const) !void {
19702001
var i: usize = 0;
19712002
while (true) {
@@ -1981,13 +2012,6 @@ const SyncHandler = struct {
19812012
vec[i].len -= n;
19822013
}
19832014
}
1984-
1985-
fn writeAll(socket: posix.socket_t, data: []const u8) !void {
1986-
var i: usize = 0;
1987-
while (i < data.len) {
1988-
i += try posix.write(socket, data[i..]);
1989-
}
1990-
}
19912015
};
19922016

19932017
// We don't ask for encoding, but some providers (CloudFront!!)

0 commit comments

Comments
 (0)