Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions build.zig.zon
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
.fingerprint = 0xda130f3af836cea0,
.dependencies = .{
.tls = .{
.url = "https://github.com/ianic/tls.zig/archive/8250aa9184fbad99983b32411bbe1a5d2fd6f4b7.tar.gz",
.hash = "tls-0.1.0-ER2e0pU3BQB-UD2_s90uvppceH_h4KZxtHCrCct8L054",
.url = "https://github.com/ianic/tls.zig/archive/55845f755d9e2e821458ea55693f85c737cd0c7a.tar.gz",
.hash = "tls-0.1.0-ER2e0m43BQAshi8ixj1qf3w2u2lqKtXtkrxUJ4AGZDcl",
},
.tigerbeetle_io = .{
.url = "https://github.com/lightpanda-io/tigerbeetle-io/archive/61d9652f1a957b7f4db723ea6aa0ce9635e840ce.tar.gz",
Expand Down
158 changes: 112 additions & 46 deletions src/http/client.zig
Original file line number Diff line number Diff line change
Expand Up @@ -322,11 +322,19 @@ const Connection = struct {

const TLSClient = union(enum) {
blocking: tls.Connection(std.net.Stream),
blocking_tlsproxy: struct {
proxy: tls.Connection(std.net.Stream), // Note, self-referential field. Proxy should be pinned in memory.
destination: tls.Connection(*tls.Connection(std.net.Stream)),
},
nonblocking: tls.nonblock.Connection,

fn close(self: *TLSClient) void {
switch (self.*) {
.blocking => |*tls_client| tls_client.close() catch {},
.blocking_tlsproxy => |*tls_in_tls| {
tls_in_tls.destination.close() catch {};
tls_in_tls.proxy.close() catch {};
},
.nonblocking => {},
}
}
Expand Down Expand Up @@ -375,16 +383,17 @@ pub const Request = struct {
// List of request headers
headers: std.ArrayListUnmanaged(std.http.Header),

// whether or not we expect this connection to be secure
_secure: bool,

// whether or not we should keep the underlying socket open and and usable
// for other requests
_keepalive: bool,

// extracted from request_uri
_request_port: u16,
_request_host: []const u8,
// Whether or not we expect this connection to be secure, connection may still be secure due to proxy
_request_secure: bool,
// Whether or not we expect the SIMPLE/CONNECT proxy connection to be secure
_proxy_secure: bool,

// extracted from connect_uri
_connect_port: u16,
Expand Down Expand Up @@ -470,11 +479,12 @@ pub const Request = struct {
.method = method,
.notification = null,
.arena = state.arena.allocator(),
._secure = decomposed.secure,
._connect_host = decomposed.connect_host,
._connect_port = decomposed.connect_port,
._proxy_secure = decomposed.proxy_secure,
._request_host = decomposed.request_host,
._request_port = decomposed.request_port,
._request_secure = decomposed.request_secure,
._state = state,
._client = client,
._aborter = null,
Expand Down Expand Up @@ -506,12 +516,13 @@ pub const Request = struct {
}

const DecomposedURL = struct {
secure: bool,
connect_port: u16,
connect_host: []const u8,
connect_uri: *const std.Uri,
proxy_secure: bool,
request_port: u16,
request_host: []const u8,
request_secure: bool,
};
fn decomposeURL(client: *const Client, uri: *const Uri) !DecomposedURL {
if (uri.host == null) {
Expand All @@ -526,27 +537,31 @@ pub const Request = struct {
connect_host = proxy.host.?.percent_encoded;
}

const is_connect_proxy = client.isConnectProxy();

var secure: bool = undefined;
const scheme = if (is_connect_proxy) uri.scheme else connect_uri.scheme;
if (std.ascii.eqlIgnoreCase(scheme, "https")) {
secure = true;
} else if (std.ascii.eqlIgnoreCase(scheme, "http")) {
secure = false;
var request_secure: bool = undefined;
if (std.ascii.eqlIgnoreCase(uri.scheme, "https")) {
request_secure = true;
} else if (std.ascii.eqlIgnoreCase(uri.scheme, "http")) {
request_secure = false;
} else {
return error.UnsupportedUriScheme;
}
const request_port: u16 = uri.port orelse if (secure) 443 else 80;
const connect_port: u16 = connect_uri.port orelse (if (is_connect_proxy) 80 else request_port);
const proxy_secure = client.http_proxy != null and std.ascii.eqlIgnoreCase(client.http_proxy.?.scheme, "https");

const request_port: u16 = uri.port orelse if (request_secure) 443 else 80;
const connect_port: u16 = connect_uri.port orelse blk: {
if (client.isConnectProxy()) {
if (proxy_secure) break :blk 443 else break :blk 80;
} else break :blk request_port;
};

return .{
.secure = secure,
.connect_port = connect_port,
.connect_host = connect_host,
.connect_uri = connect_uri,
.proxy_secure = proxy_secure,
.request_port = request_port,
.request_host = request_host,
.request_secure = request_secure,
};
}

Expand Down Expand Up @@ -655,19 +670,50 @@ pub const Request = struct {
};
self._connection = connection;

const tls_config = tls.config.Client{
.host = self._request_host,
.root_ca = self._client.root_ca,
.insecure_skip_verify = self._tls_verify_host == false,
// .key_log_callback = tls.config.key_log.callback,
};

// proxy
const is_connect_proxy = self._client.isConnectProxy();

if (is_connect_proxy) {
try SyncHandler.connect(self);
}
var proxy_conn: SyncHandler.Conn = .{ .plain = self._connection.?.socket };

if (self._proxy_secure) {
// Create an underlying TLS stream with the proxy
var proxy_tls_config = tls_config;
proxy_tls_config.host = self._connect_host;
var proxy_conn_tls = try tls.client(std.net.Stream{ .handle = socket }, proxy_tls_config);
proxy_conn = .{ .tls = &proxy_conn_tls };
}

if (self._secure) {
// Connect to the proxy
try SyncHandler.connect(self, &proxy_conn);

if (self._proxy_secure) {
if (self._request_secure) {
// If secure endpoint, create the main TLS stream encapsulated into the TLS stream proxy
self._connection.?.tls = .{
.blocking_tlsproxy = .{
.proxy = proxy_conn.tls.*,
.destination = undefined,
},
};
const proxy = &self._connection.?.tls.?.blocking_tlsproxy.proxy;
self._connection.?.tls.?.blocking_tlsproxy.destination = try tls.client(proxy, tls_config);
} else {
// Otherwise, just use the TLS stream proxy
self._connection.?.tls = .{ .blocking = proxy_conn.tls.* };
}
}
}
if (self._request_secure and !self._proxy_secure) {
self._connection.?.tls = .{
.blocking = try tls.client(std.net.Stream{ .handle = socket }, .{
.host = if (is_connect_proxy) self._request_host else self._connect_host,
.root_ca = self._client.root_ca,
.insecure_skip_verify = self._tls_verify_host == false,
// .key_log_callback = tls.config.key_log.callback,
}),
.blocking = try tls.client(std.net.Stream{ .handle = socket }, tls_config),
};
}

Expand Down Expand Up @@ -744,7 +790,8 @@ pub const Request = struct {
.conn = .{ .handler = async_handler, .protocol = .{ .plain = {} } },
};

if (self._secure) {
if (self._client.isConnectProxy() and self._proxy_secure) log.warn(.http, "ASYNC TLS CONNECT no impl.", .{});
if (self._request_secure) {
if (self._connection_from_keepalive) {
// If the connection came from the keepalive pool, than we already
// have a TLS Connection.
Expand All @@ -753,7 +800,7 @@ pub const Request = struct {
std.debug.assert(connection.tls == null);
async_handler.conn.protocol = .{
.handshake = tls.nonblock.Client.init(.{
.host = if (self._client.isConnectProxy()) self._request_host else self._connect_host,
.host = if (self._client.isConnectProxy()) self._request_host else self._connect_host, // looks wrong
.root_ca = self._client.root_ca,
.insecure_skip_verify = self._tls_verify_host == false,
.key_log_callback = tls.config.key_log.callback,
Expand Down Expand Up @@ -833,9 +880,10 @@ pub const Request = struct {
const decomposed = try decomposeURL(self._client, self.request_uri);
self.connect_uri = decomposed.connect_uri;
self._request_host = decomposed.request_host;
self._request_secure = decomposed.request_secure;
self._connect_host = decomposed.connect_host;
self._connect_port = decomposed.connect_port;
self._secure = decomposed.secure;
self._proxy_secure = decomposed.proxy_secure;
self._keepalive = false;
self._redirect_count = redirect_count + 1;

Expand Down Expand Up @@ -883,7 +931,9 @@ pub const Request = struct {
return null;
}

return self._client.connection_manager.get(self._secure, self._connect_host, self._connect_port, blocking);
// A simple http proxy to an https destination is made into tls by the proxy, we see it as a plain connection
const expect_tls = self._proxy_secure or (self._request_secure and !self._client.isSimpleProxy());
return self._client.connection_manager.get(expect_tls, self._connect_host, self._connect_port, blocking);
}

fn createSocket(self: *Request, blocking: bool) !struct { posix.socket_t, std.net.Address } {
Expand Down Expand Up @@ -1721,7 +1771,15 @@ const SyncHandler = struct {
var conn: Conn = blk: {
const c = request._connection.?;
if (c.tls) |*tls_client| {
break :blk .{ .tls = &tls_client.blocking };
switch (tls_client.*) {
.nonblocking => unreachable,
.blocking => |*blocking| {
break :blk .{ .tls = blocking };
},
.blocking_tlsproxy => |*blocking_tlsproxy| {
break :blk .{ .tls_in_tls = &blocking_tlsproxy.destination };
},
}
}
break :blk .{ .plain = c.socket };
};
Expand Down Expand Up @@ -1804,11 +1862,9 @@ const SyncHandler = struct {

// Unfortunately, this is called from the Request doSendSync since we need
// to do this before setting up our TLS connection.
fn connect(request: *Request) !void {
const socket = request._connection.?.socket;

fn connect(request: *Request, conn: *Conn) !void {
const header = try request.buildConnectHeader();
try Conn.writeAll(socket, header);
try conn.writeAll(header);

var pos: usize = 0;
var reader = request.newReader();
Expand All @@ -1819,7 +1875,7 @@ const SyncHandler = struct {
// we only send CONNECT requests on newly established connections
// and maybeRetryOrErr is only for connections that might have been
// closed while being kept-alive
const n = try posix.read(socket, read_buf[pos..]);
const n = try conn.read(read_buf[pos..]);
if (n == 0) {
return error.ConnectionResetByPeer;
}
Expand All @@ -1831,6 +1887,7 @@ const SyncHandler = struct {

// we don't have enough data yet.
}
return;
}

fn maybeRetryOrErr(self: *SyncHandler, err: anyerror) !Response {
Expand Down Expand Up @@ -1880,12 +1937,13 @@ const SyncHandler = struct {
}

const Conn = union(enum) {
tls_in_tls: *tls.Connection(*tls.Connection(std.net.Stream)),
tls: *tls.Connection(std.net.Stream),
plain: posix.socket_t,

fn sendRequest(self: *Conn, header: []const u8, body: ?[]const u8) !void {
switch (self.*) {
.tls => |tls_client| {
inline .tls, .tls_in_tls => |tls_client| {
try tls_client.writeAll(header);
if (body) |b| {
try tls_client.writeAll(b);
Expand All @@ -1899,14 +1957,15 @@ const SyncHandler = struct {
};
return writeAllIOVec(socket, &vec);
}
return writeAll(socket, header);
return self.writeAll(header);
},
}
}

fn read(self: *Conn, buf: []u8) !usize {
const n = switch (self.*) {
.tls => |tls_client| try tls_client.read(buf),
.tls_in_tls => |tls_client| try tls_client.read(buf),
.plain => |socket| try posix.read(socket, buf),
};
if (n == 0) {
Expand All @@ -1915,6 +1974,19 @@ const SyncHandler = struct {
return n;
}

fn writeAll(self: *Conn, data: []const u8) !void {
switch (self.*) {
.tls => |tls_client| try tls_client.writeAll(data),
.tls_in_tls => |tls_client| try tls_client.writeAll(data),
.plain => |socket| {
var i: usize = 0;
while (i < data.len) {
i += try posix.write(socket, data[i..]);
}
},
}
}

fn writeAllIOVec(socket: posix.socket_t, vec: []posix.iovec_const) !void {
var i: usize = 0;
while (true) {
Expand All @@ -1930,13 +2002,6 @@ const SyncHandler = struct {
vec[i].len -= n;
}
}

fn writeAll(socket: posix.socket_t, data: []const u8) !void {
var i: usize = 0;
while (i < data.len) {
i += try posix.write(socket, data[i..]);
}
}
};

// We don't ask for encoding, but some providers (CloudFront!!)
Expand Down Expand Up @@ -2081,6 +2146,7 @@ const Reader = struct {
if (result.done == false) {
// CONNECT responses should not have a body. If the header is
// done, then the entire response should be done.
log.info(.http_client, "InvalidConnectResponse", .{ .status = self.response.status, .unprocessed = result.unprocessed });
return error.InvalidConnectResponse;
}

Expand Down Expand Up @@ -2907,14 +2973,14 @@ const ConnectionManager = struct {
self.connection_pool.deinit();
}

fn get(self: *ConnectionManager, secure: bool, host: []const u8, port: u16, blocking: bool) ?*Connection {
fn get(self: *ConnectionManager, expect_tls: bool, host: []const u8, port: u16, blocking: bool) ?*Connection {
self.mutex.lock();
defer self.mutex.unlock();

var node = self.idle.first;
while (node) |n| {
const connection = n.data;
if (std.ascii.eqlIgnoreCase(connection.host, host) and connection.port == port and connection.blocking == blocking and ((connection.tls == null) == !secure)) {
if (std.ascii.eqlIgnoreCase(connection.host, host) and connection.port == port and connection.blocking == blocking and ((connection.tls == null) == !expect_tls)) {
self.count -= 1;
self.idle.remove(n);
self.node_pool.destroy(n);
Expand Down