Skip to content

Commit 5edd613

Browse files
committed
async tls proxy wip
1 parent 878dbd8 commit 5edd613

File tree

1 file changed

+182
-38
lines changed

1 file changed

+182
-38
lines changed

src/http/client.zig

Lines changed: 182 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -331,12 +331,16 @@ const Connection = struct {
331331
socket: posix.socket_t,
332332

333333
const TLSClient = union(enum) {
334-
blocking: tls.Connection(std.net.Stream),
334+
blocking: tls.Connection(std.net.Stream), // Note can also be a tlsproxy if the destination is not secure
335335
blocking_tlsproxy: struct {
336336
proxy: tls.Connection(std.net.Stream), // Note, self-referential field. Proxy should be pinned in memory.
337337
destination: tls.Connection(*tls.Connection(std.net.Stream)),
338338
},
339339
nonblocking: tls.nonblock.Connection,
340+
nonblocking_tlsproxy: struct {
341+
proxy: tls.nonblock.Connection,
342+
destination: tls.nonblock.Connection,
343+
},
340344

341345
fn close(self: *TLSClient) void {
342346
switch (self.*) {
@@ -345,7 +349,7 @@ const Connection = struct {
345349
tls_in_tls.destination.close() catch {};
346350
tls_in_tls.proxy.close() catch {};
347351
},
348-
.nonblocking => {},
352+
.nonblocking, .nonblocking_tlsproxy => {},
349353
}
350354
}
351355
};
@@ -797,36 +801,18 @@ pub const Request = struct {
797801
.handler = handler,
798802
.read_buf = state.read_buf,
799803
.write_buf = state.write_buf,
804+
.write_connect_buf = state.write_connect_buf,
800805
.reader = self.newReader(),
801806
.socket = connection.socket,
802807
.conn = .{ .handler = async_handler, .protocol = .{ .plain = {} } },
803808
};
804809

805-
if (self._client.isConnectProxy() and self._proxy_secure) {
806-
log.warn(.http, "not implemented", .{ .feature = "async tls connect" });
807-
}
808-
809-
const is_proxy = self._client.isProxy();
810-
if ((is_proxy and self._proxy_secure) or (!self._client.isForwardProxy() and self._request_secure)) {
811-
if (self._connection_from_keepalive) {
812-
// If the connection came from the keepalive pool, than we already
813-
// have a TLS Connection.
814-
async_handler.conn.protocol = .{ .encrypted = .{ .conn = &connection.tls.?.nonblocking } };
815-
} else {
816-
std.debug.assert(connection.tls == null);
817-
async_handler.conn.protocol = .{
818-
.handshake = tls.nonblock.Client.init(.{
819-
.host = if (self._client.isConnectProxy()) self._request_host else self._connect_host, // looks wrong
820-
.root_ca = self._client.root_ca,
821-
.insecure_skip_verify = self._tls_verify_host == false,
822-
.key_log_callback = tls.config.key_log.callback,
823-
}),
824-
};
825-
}
826-
}
827-
828-
if (self._connection_from_keepalive) {
829-
// we're already connected
810+
if (self._connection_from_keepalive and
811+
((self._client.isProxy() and self._proxy_secure) or (!self._client.isForwardProxy() and self._request_secure)))
812+
{
813+
// If the connection came from the keepalive pool, than we already have a TLS Connection.
814+
async_handler.conn.protocol = .{ .encrypted = .{ .conn = &connection.tls.?.nonblocking } };
815+
// and we're already connected
830816
async_handler.pending_connect = false;
831817
return async_handler.conn.connected();
832818
}
@@ -1079,6 +1065,7 @@ fn AsyncHandler(comptime H: type) type {
10791065
// need a separate read and write buf because, with TLS, messages are
10801066
// not strictly req->resp.
10811067
write_buf: []u8,
1068+
write_connect_buf: []u8,
10821069

10831070
socket: posix.socket_t,
10841071
read_completion: IO.Completion = undefined,
@@ -1095,7 +1082,7 @@ fn AsyncHandler(comptime H: type) type {
10951082
send_queue: SendQueue = .{},
10961083

10971084
// Used to help us know if we're writing the header or the body;
1098-
state: SendState = .handshake,
1085+
state: SendState = .init,
10991086

11001087
// Abstraction over TLS and plain text socket, this is a version of
11011088
// the request._connection (which is a *Connection) that is async-specific.
@@ -1124,7 +1111,9 @@ fn AsyncHandler(comptime H: type) type {
11241111
const SendQueue = std.DoublyLinkedList([]const u8);
11251112

11261113
const SendState = enum {
1127-
connect,
1114+
init,
1115+
connect_handshake,
1116+
connect_header,
11281117
handshake,
11291118
header,
11301119
body,
@@ -1148,6 +1137,16 @@ fn AsyncHandler(comptime H: type) type {
11481137
self.maybeShutdown();
11491138
}
11501139

1140+
/// Shift unused part of the buffer to the beginning.
1141+
/// Returns write position for the next write into buffer.
1142+
/// Unused part is at the end of the buffer.
1143+
fn shiftUnused(buf: []u8, unused: []const u8) usize {
1144+
if (unused.len == 0) return 0;
1145+
if (unused.ptr == buf.ptr) return unused.len;
1146+
std.mem.copyForwards(u8, buf, unused);
1147+
return unused.len;
1148+
}
1149+
11511150
fn connected(self: *Self, _: *IO.Completion, result: IO.ConnectError!void) void {
11521151
self.pending_connect = false;
11531152
if (self.shutdown) {
@@ -1156,16 +1155,81 @@ fn AsyncHandler(comptime H: type) type {
11561155

11571156
result catch |err| return self.handleError("Connection failed", err);
11581157

1159-
if (self.request.shouldProxyConnect()) {
1160-
self.state = .connect;
1161-
const header = self.request.buildConnectHeader() catch |err| {
1158+
const request = self.request;
1159+
if (request._request_secure) {
1160+
std.debug.assert(request._connection.?.tls == null);
1161+
self.conn.protocol = .{
1162+
.handshake = tls.nonblock.Client.init(.{
1163+
.host = request._request_host,
1164+
.root_ca = request._client.root_ca,
1165+
.insecure_skip_verify = request._tls_verify_host == false,
1166+
.key_log_callback = tls.config.key_log.callback,
1167+
}),
1168+
};
1169+
}
1170+
1171+
if (request.shouldProxyConnect()) {
1172+
if (request._proxy_secure) {
1173+
// If we're using a CONNECT proxy, we need to do a TLS handshake, before sending the CONNECT request
1174+
self.state = .connect_handshake;
1175+
const tls_config = tls.config.Client{
1176+
.host = request._connect_host,
1177+
.root_ca = request._client.root_ca,
1178+
.insecure_skip_verify = request._tls_verify_host == false,
1179+
.key_log_callback = tls.config.key_log.callback,
1180+
};
1181+
self.conn.connect_protocol = .{ .handshake = tls.nonblock.Client.init(tls_config) };
1182+
const handshake = &self.conn.connect_protocol.?.handshake;
1183+
1184+
var recv_buf: [tls.max_ciphertext_record_len]u8 = undefined;
1185+
var send_buf: [tls.max_ciphertext_record_len]u8 = undefined;
1186+
var recv_pos: usize = 0;
1187+
while (true) { // run handshake until done
1188+
const res = handshake.run(recv_buf[0..recv_pos], &send_buf) catch |err| {
1189+
return self.handleError("TLS handshake error", err);
1190+
};
1191+
if (res.send.len > 0) {
1192+
var i: usize = 0;
1193+
while (i < res.send.len) {
1194+
i += posix.write(self.socket, res.send[i..]) catch |err| {
1195+
return self.handleError("TLS handshake write error", err);
1196+
};
1197+
}
1198+
}
1199+
recv_pos = shiftUnused(&recv_buf, res.unused_recv);
1200+
if (handshake.done()) break;
1201+
while (true) {
1202+
recv_pos += posix.read(self.socket, recv_buf[recv_pos..]) catch |err| {
1203+
if (err == error.WouldBlock) continue; // Retry on WouldBlock
1204+
return self.handleError("TLS handshake read error", err);
1205+
};
1206+
break;
1207+
}
1208+
}
1209+
1210+
const encrypted = tls.nonblock.Connection.init(handshake.cipher().?); // steal it from blocking connection if it exists?
1211+
std.debug.assert(request._connection.?.tls == null);
1212+
request._connection.?.tls = .{ .nonblocking = encrypted }; // TODO considering storing it in nonblocking_tlsproxy directly if request_secure
1213+
self.conn.connect_protocol = .{
1214+
.encrypted = .{
1215+
.conn = &request._connection.?.tls.?.nonblocking,
1216+
},
1217+
};
1218+
}
1219+
// else {
1220+
// If we're using a plain proxy, we just send the CONNECT request
1221+
self.state = .connect_header;
1222+
const header = request.buildConnectHeader() catch |err| {
11621223
return self.handleError("Failed to build CONNECT header", err);
11631224
};
11641225
self.send(header);
11651226
self.receive();
11661227
return;
1228+
// }
11671229
}
11681230

1231+
// There is no CONNECT proxy so we can proceed with the TLS handshake or header if plain
1232+
self.state = .handshake;
11691233
self.conn.connected() catch |err| {
11701234
self.handleError("connected handler error", err);
11711235
};
@@ -1178,6 +1242,18 @@ fn AsyncHandler(comptime H: type) type {
11781242
};
11791243

11801244
node.data = data;
1245+
1246+
if (self.conn.connect_protocol) |*connect_protocol| {
1247+
// Encrypt data with the cypher
1248+
const res = connect_protocol.encrypted.conn.encrypt(data, self.conn.handler.write_connect_buf) catch |err| {
1249+
self.handleError("TLS proxy encrypt error", err);
1250+
return;
1251+
};
1252+
connect_protocol.encrypted.unsent = res.unused_cleartext;
1253+
node.data = res.ciphertext;
1254+
if (res.unused_cleartext.len > 0) log.warn(.http_client, "TLS encrypt unused data", .{});
1255+
}
1256+
11811257
self.send_queue.append(node);
11821258
if (self.send_queue.len > 1) {
11831259
// if we already had a message in the queue, then our send loop
@@ -1191,7 +1267,7 @@ fn AsyncHandler(comptime H: type) type {
11911267
self,
11921268
&self.send_completion,
11931269
sent,
1194-
self.socket,
1270+
self.socket, // For TLS CONNECT should this be a TLS CLient?, no data should already be encrypted if that is needed
11951271
node.data,
11961272
) catch |err| {
11971273
self.handleError("loop send error", err);
@@ -1235,7 +1311,7 @@ fn AsyncHandler(comptime H: type) type {
12351311
return;
12361312
}
12371313

1238-
if (self.state == .connect) {
1314+
if (self.state == .connect_handshake or self.state == .connect_header) {
12391315
// We're in a proxy CONNECT flow. There's nothing for us to
12401316
// do except for wait for the response.
12411317
return;
@@ -1286,7 +1362,60 @@ fn AsyncHandler(comptime H: type) type {
12861362

12871363
const data = self.read_buf[0 .. self.read_pos + n];
12881364

1289-
if (self.state == .connect) {
1365+
if (self.state == .connect_handshake) {
1366+
// TODO send/rec
1367+
}
1368+
1369+
if (self.state == .connect_header) {
1370+
blk: {
1371+
if (self.conn.connect_protocol) |*connect_protocol| {
1372+
const res = connect_protocol.encrypted.conn.decrypt(data, data) catch |err| {
1373+
return self.handleError("TLS proxy decrypt error", err);
1374+
};
1375+
1376+
if (res.ciphertext_pos == 0) {
1377+
// no part of the encrypted data was consumed
1378+
// no cleartext data should have been generated
1379+
std.debug.assert(res.cleartext.len == 0);
1380+
1381+
// our next read needs to append more data to
1382+
// the existing data
1383+
self.read_pos = data.len;
1384+
return if (res.closed) break :blk else self.receive();
1385+
}
1386+
1387+
if (res.cleartext.len > 0) {
1388+
// status = self.processData(res.cleartext);
1389+
break :blk; // we assume we can read the header in one go
1390+
}
1391+
1392+
if (res.closed) break :blk;
1393+
1394+
const unused = res.unused_ciphertext;
1395+
if (unused.len == 0) {
1396+
// all of data was used up, our next read can use
1397+
// the whole read buffer.
1398+
self.read_pos = 0;
1399+
return self.receive();
1400+
}
1401+
1402+
// We used some of the data, but have some leftover
1403+
// (i.e. there was 1+ full records AND an incomplete
1404+
// record). We need to maintain the "leftover" data
1405+
// for subsequent reads.
1406+
1407+
// Remember that our read_buf is the MAX possible TLS
1408+
// record size. So as long as we make sure that the start
1409+
// of a record is at read_buf[0], we know that we'll
1410+
// always have enough space for 1 record.
1411+
std.mem.copyForwards(u8, self.read_buf, unused);
1412+
self.read_pos = unused.len;
1413+
1414+
// an incomplete record means there must be more data
1415+
return self.receive();
1416+
}
1417+
}
1418+
12901419
const success = self.reader.connectResponse(data) catch |err| {
12911420
return self.handleError("Invalid CONNECT response", err);
12921421
};
@@ -1538,6 +1667,7 @@ fn AsyncHandler(comptime H: type) type {
15381667
const Conn = struct {
15391668
handler: *Self,
15401669
protocol: Protocol,
1670+
connect_protocol: ?Protocol = null,
15411671

15421672
const Encrypted = struct {
15431673
conn: *tls.nonblock.Connection,
@@ -1594,6 +1724,12 @@ fn AsyncHandler(comptime H: type) type {
15941724

15951725
fn received(self: *Conn, data: []u8) !ProcessStatus {
15961726
const handler = self.handler;
1727+
1728+
if (self.connect_protocol) |*connect_protocol| {
1729+
_ = try connect_protocol.encrypted.conn.decrypt(data, data);
1730+
// TODO needsmore
1731+
}
1732+
15971733
switch (self.protocol) {
15981734
.plain => {
15991735
std.debug.assert(handler.state == .body);
@@ -1686,7 +1822,7 @@ fn AsyncHandler(comptime H: type) type {
16861822
const handler = self.handler;
16871823
switch (self.protocol) {
16881824
.plain => switch (handler.state) {
1689-
.handshake, .connect => unreachable,
1825+
.init, .handshake, .connect_header, .connect_handshake => unreachable,
16901826
.header => return self.sendBody(),
16911827
.body => {},
16921828
},
@@ -1695,7 +1831,7 @@ fn AsyncHandler(comptime H: type) type {
16951831
return self.send(encrypted.unsent);
16961832
}
16971833
switch (handler.state) {
1698-
.handshake, .connect => unreachable,
1834+
.init, .handshake, .connect_header, .connect_handshake => unreachable,
16991835
.header => return self.sendBody(),
17001836
.body => {},
17011837
}
@@ -1762,6 +1898,9 @@ fn AsyncHandler(comptime H: type) type {
17621898
.encrypted => |*encrypted| {
17631899
const res = try encrypted.conn.encrypt(data, handler.write_buf);
17641900
encrypted.unsent = res.unused_cleartext;
1901+
1902+
// TODO encrypt with CONNECT tls if any
1903+
17651904
return handler.send(res.ciphertext);
17661905
},
17671906
.handshake => {
@@ -1788,7 +1927,7 @@ const SyncHandler = struct {
17881927
const c = request._connection.?;
17891928
if (c.tls) |*tls_client| {
17901929
switch (tls_client.*) {
1791-
.nonblocking => unreachable,
1930+
.nonblocking, .nonblocking_tlsproxy => unreachable,
17921931
.blocking => |*blocking| {
17931932
break :blk .{ .tls = blocking };
17941933
},
@@ -2804,6 +2943,7 @@ const State = struct {
28042943
// write_buf, even though HTTP is req -> resp, it's for TLS, which has
28052944
// bidirectional data.
28062945
write_buf: []u8,
2946+
write_connect_buf: []u8,
28072947

28082948
// Used for keeping any unparsed header line until more data is received
28092949
// At most, this represents 1 line in the header.
@@ -2822,6 +2962,8 @@ const State = struct {
28222962

28232963
const write_buf = try allocator.alloc(u8, buf_size);
28242964
errdefer allocator.free(write_buf);
2965+
const write_connect_buf = try allocator.alloc(u8, buf_size * 2);
2966+
errdefer allocator.free(write_connect_buf);
28252967

28262968
const header_buf = try allocator.alloc(u8, header_size);
28272969
errdefer allocator.free(header_buf);
@@ -2830,6 +2972,7 @@ const State = struct {
28302972
.peek_buf = peek_buf,
28312973
.read_buf = read_buf,
28322974
.write_buf = write_buf,
2975+
.write_connect_buf = write_connect_buf,
28332976
.header_buf = header_buf,
28342977
.arena = std.heap.ArenaAllocator.init(allocator),
28352978
};
@@ -2844,6 +2987,7 @@ const State = struct {
28442987
allocator.free(self.peek_buf);
28452988
allocator.free(self.read_buf);
28462989
allocator.free(self.write_buf);
2990+
allocator.free(self.write_connect_buf);
28472991
allocator.free(self.header_buf);
28482992
self.arena.deinit();
28492993
}

0 commit comments

Comments
 (0)