Skip to content

Commit 1ba6838

Browse files
authored
Merge pull request ziglang#24740 from ziglang/http-plus-fixes
fetch, tls, and http fixes
2 parents d2f7792 + 8da645c commit 1ba6838

File tree

7 files changed

+172
-155
lines changed

7 files changed

+172
-155
lines changed

lib/std/Io/Reader.zig

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,7 @@ pub const VTable = struct {
2525
///
2626
/// Returns the number of bytes written, which will be at minimum `0` and
2727
/// at most `limit`. The number returned, including zero, does not indicate
28-
/// end of stream. `limit` is guaranteed to be at least as large as the
29-
/// buffer capacity of `w`, a value whose minimum size is determined by the
30-
/// stream implementation.
28+
/// end of stream.
3129
///
3230
/// The reader's internal logical seek position moves forward in accordance
3331
/// with the number of bytes returned from this function.

lib/std/crypto/tls/Client.zig

Lines changed: 45 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,6 @@ pub const ReadError = error{
6161
TlsUnexpectedMessage,
6262
TlsIllegalParameter,
6363
TlsSequenceOverflow,
64-
/// The buffer provided to the read function was not at least
65-
/// `min_buffer_len`.
66-
OutputBufferUndersize,
6764
};
6865

6966
pub const SslKeyLog = struct {
@@ -372,7 +369,8 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
372369
};
373370
P.AEAD.decrypt(cleartext, ciphertext, auth_tag, record_header, nonce, pv.server_handshake_key) catch
374371
return error.TlsBadRecordMac;
375-
cleartext_fragment_end += std.mem.trimEnd(u8, cleartext, "\x00").len;
372+
// TODO use scalar, non-slice version
373+
cleartext_fragment_end += mem.trimEnd(u8, cleartext, "\x00").len;
376374
},
377375
}
378376
read_seq += 1;
@@ -395,9 +393,9 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
395393
const cleartext_fragment_buf = cleartext_buf[cleartext_fragment_end..];
396394
if (message_len > cleartext_fragment_buf.len) return error.TlsRecordOverflow;
397395
const cleartext = cleartext_fragment_buf[0..message_len];
398-
const ad = std.mem.toBytes(big(read_seq)) ++
396+
const ad = mem.toBytes(big(read_seq)) ++
399397
record_header[0 .. 1 + 2] ++
400-
std.mem.toBytes(big(message_len));
398+
mem.toBytes(big(message_len));
401399
const record_iv = record_decoder.array(P.record_iv_length).*;
402400
const masked_read_seq = read_seq &
403401
comptime std.math.shl(u64, std.math.maxInt(u64), 8 * P.record_iv_length);
@@ -738,7 +736,7 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
738736
&.{ "server finished", &p.transcript_hash.finalResult() },
739737
P.verify_data_length,
740738
),
741-
.app_cipher = std.mem.bytesToValue(P.Tls_1_2, &key_block),
739+
.app_cipher = mem.bytesToValue(P.Tls_1_2, &key_block),
742740
} };
743741
const pv = &p.version.tls_1_2;
744742
const nonce: [P.AEAD.nonce_length]u8 = nonce: {
@@ -756,7 +754,7 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
756754
client_verify_cleartext.len ..][0..client_verify_cleartext.len],
757755
client_verify_msg[client_verify_msg.len - P.mac_length ..][0..P.mac_length],
758756
&client_verify_cleartext,
759-
std.mem.toBytes(big(write_seq)) ++ client_verify_msg[0 .. 1 + 2] ++ int(u16, client_verify_cleartext.len),
757+
mem.toBytes(big(write_seq)) ++ client_verify_msg[0 .. 1 + 2] ++ int(u16, client_verify_cleartext.len),
760758
nonce,
761759
pv.app_cipher.client_write_key,
762760
);
@@ -873,7 +871,10 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
873871
.input = input,
874872
.reader = .{
875873
.buffer = options.read_buffer,
876-
.vtable = &.{ .stream = stream },
874+
.vtable = &.{
875+
.stream = stream,
876+
.readVec = readVec,
877+
},
877878
.seek = 0,
878879
.end = 0,
879880
},
@@ -1017,7 +1018,7 @@ fn prepareCiphertextRecord(
10171018
const nonce = nonce: {
10181019
const V = @Vector(P.AEAD.nonce_length, u8);
10191020
const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
1020-
const operand: V = pad ++ std.mem.toBytes(big(c.write_seq));
1021+
const operand: V = pad ++ mem.toBytes(big(c.write_seq));
10211022
break :nonce @as(V, pv.client_iv) ^ operand;
10221023
};
10231024
P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, pv.client_key);
@@ -1048,7 +1049,7 @@ fn prepareCiphertextRecord(
10481049
record_header.* = .{@intFromEnum(inner_content_type)} ++
10491050
int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++
10501051
int(u16, P.record_iv_length + message_len + P.mac_length);
1051-
const ad = std.mem.toBytes(big(c.write_seq)) ++ record_header[0 .. 1 + 2] ++ int(u16, message_len);
1052+
const ad = mem.toBytes(big(c.write_seq)) ++ record_header[0 .. 1 + 2] ++ int(u16, message_len);
10521053
const record_iv = ciphertext_buf[ciphertext_end..][0..P.record_iv_length];
10531054
ciphertext_end += P.record_iv_length;
10541055
const nonce: [P.AEAD.nonce_length]u8 = nonce: {
@@ -1076,7 +1077,22 @@ pub fn eof(c: Client) bool {
10761077
}
10771078

10781079
fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize {
1080+
// This function writes exclusively to the buffer.
1081+
_ = w;
1082+
_ = limit;
1083+
const c: *Client = @alignCast(@fieldParentPtr("reader", r));
1084+
return readIndirect(c);
1085+
}
1086+
1087+
fn readVec(r: *Reader, data: [][]u8) Reader.Error!usize {
1088+
// This function writes exclusively to the buffer.
1089+
_ = data;
10791090
const c: *Client = @alignCast(@fieldParentPtr("reader", r));
1091+
return readIndirect(c);
1092+
}
1093+
1094+
fn readIndirect(c: *Client) Reader.Error!usize {
1095+
const r = &c.reader;
10801096
if (c.eof()) return error.EndOfStream;
10811097
const input = c.input;
10821098
// If at least one full encrypted record is not buffered, read once.
@@ -1108,8 +1124,13 @@ fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize
11081124
if (record_end > input.buffered().len) return 0;
11091125
}
11101126

1111-
var cleartext_stack_buffer: [max_ciphertext_len]u8 = undefined;
1112-
const cleartext, const inner_ct: tls.ContentType = cleartext: switch (c.application_cipher) {
1127+
if (r.seek == r.end) {
1128+
r.seek = 0;
1129+
r.end = 0;
1130+
}
1131+
const cleartext_buffer = r.buffer[r.end..];
1132+
1133+
const cleartext_len, const inner_ct: tls.ContentType = cleartext: switch (c.application_cipher) {
11131134
inline else => |*p| switch (c.tls_version) {
11141135
.tls_1_3 => {
11151136
const pv = &p.tls_1_3;
@@ -1121,23 +1142,24 @@ fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize
11211142
const nonce = nonce: {
11221143
const V = @Vector(P.AEAD.nonce_length, u8);
11231144
const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
1124-
const operand: V = pad ++ std.mem.toBytes(big(c.read_seq));
1145+
const operand: V = pad ++ mem.toBytes(big(c.read_seq));
11251146
break :nonce @as(V, pv.server_iv) ^ operand;
11261147
};
1127-
const cleartext = cleartext_stack_buffer[0..ciphertext.len];
1148+
const cleartext = cleartext_buffer[0..ciphertext.len];
11281149
P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_key) catch
11291150
return failRead(c, error.TlsBadRecordMac);
1151+
// TODO use scalar, non-slice version
11301152
const msg = mem.trimRight(u8, cleartext, "\x00");
1131-
break :cleartext .{ msg[0 .. msg.len - 1], @enumFromInt(msg[msg.len - 1]) };
1153+
break :cleartext .{ msg.len - 1, @enumFromInt(msg[msg.len - 1]) };
11321154
},
11331155
.tls_1_2 => {
11341156
const pv = &p.tls_1_2;
11351157
const P = @TypeOf(p.*);
11361158
const message_len: u16 = record_len - P.record_iv_length - P.mac_length;
11371159
const ad_header = input.take(tls.record_header_len) catch unreachable; // already peeked
1138-
const ad = std.mem.toBytes(big(c.read_seq)) ++
1160+
const ad = mem.toBytes(big(c.read_seq)) ++
11391161
ad_header[0 .. 1 + 2] ++
1140-
std.mem.toBytes(big(message_len));
1162+
mem.toBytes(big(message_len));
11411163
const record_iv = (input.takeArray(P.record_iv_length) catch unreachable).*; // already peeked
11421164
const masked_read_seq = c.read_seq &
11431165
comptime std.math.shl(u64, std.math.maxInt(u64), 8 * P.record_iv_length);
@@ -1149,14 +1171,15 @@ fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize
11491171
};
11501172
const ciphertext = input.take(message_len) catch unreachable; // already peeked
11511173
const auth_tag = (input.takeArray(P.mac_length) catch unreachable).*; // already peeked
1152-
const cleartext = cleartext_stack_buffer[0..ciphertext.len];
1174+
const cleartext = cleartext_buffer[0..ciphertext.len];
11531175
P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_write_key) catch
11541176
return failRead(c, error.TlsBadRecordMac);
1155-
break :cleartext .{ cleartext, ct };
1177+
break :cleartext .{ cleartext.len, ct };
11561178
},
11571179
else => unreachable,
11581180
},
11591181
};
1182+
const cleartext = cleartext_buffer[0..cleartext_len];
11601183
c.read_seq = std.math.add(u64, c.read_seq, 1) catch return failRead(c, error.TlsSequenceOverflow);
11611184
switch (inner_ct) {
11621185
.alert => {
@@ -1245,9 +1268,8 @@ fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize
12451268
return 0;
12461269
},
12471270
.application_data => {
1248-
if (@intFromEnum(limit) < cleartext.len) return failRead(c, error.OutputBufferUndersize);
1249-
try w.writeAll(cleartext);
1250-
return cleartext.len;
1271+
r.end += cleartext.len;
1272+
return 0;
12511273
},
12521274
else => return failRead(c, error.TlsUnexpectedMessage),
12531275
}

lib/std/http.zig

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,14 @@ pub const ContentEncoding = enum {
292292
});
293293
return map.get(s);
294294
}
295+
296+
pub fn minBufferCapacity(ce: ContentEncoding) usize {
297+
return switch (ce) {
298+
.zstd => std.compress.zstd.default_window_len,
299+
.gzip, .deflate => std.compress.flate.max_window_len,
300+
.compress, .identity => 0,
301+
};
302+
}
295303
};
296304

297305
pub const Connection = enum {
@@ -412,7 +420,7 @@ pub const Reader = struct {
412420
/// * `interfaceDecompressing`
413421
pub fn bodyReader(
414422
reader: *Reader,
415-
buffer: []u8,
423+
transfer_buffer: []u8,
416424
transfer_encoding: TransferEncoding,
417425
content_length: ?u64,
418426
) *std.Io.Reader {
@@ -421,7 +429,7 @@ pub const Reader = struct {
421429
.chunked => {
422430
reader.state = .{ .body_remaining_chunk_len = .head };
423431
reader.interface = .{
424-
.buffer = buffer,
432+
.buffer = transfer_buffer,
425433
.seek = 0,
426434
.end = 0,
427435
.vtable = &.{
@@ -435,7 +443,7 @@ pub const Reader = struct {
435443
if (content_length) |len| {
436444
reader.state = .{ .body_remaining_content_length = len };
437445
reader.interface = .{
438-
.buffer = buffer,
446+
.buffer = transfer_buffer,
439447
.seek = 0,
440448
.end = 0,
441449
.vtable = &.{
@@ -460,11 +468,12 @@ pub const Reader = struct {
460468
/// * `interface`
461469
pub fn bodyReaderDecompressing(
462470
reader: *Reader,
471+
transfer_buffer: []u8,
463472
transfer_encoding: TransferEncoding,
464473
content_length: ?u64,
465474
content_encoding: ContentEncoding,
466-
decompressor: *Decompressor,
467-
decompression_buffer: []u8,
475+
decompress: *Decompress,
476+
decompress_buffer: []u8,
468477
) *std.Io.Reader {
469478
if (transfer_encoding == .none and content_length == null) {
470479
assert(reader.state == .received_head);
@@ -474,22 +483,22 @@ pub const Reader = struct {
474483
return reader.in;
475484
},
476485
.deflate => {
477-
decompressor.* = .{ .flate = .init(reader.in, .zlib, decompression_buffer) };
478-
return &decompressor.flate.reader;
486+
decompress.* = .{ .flate = .init(reader.in, .zlib, decompress_buffer) };
487+
return &decompress.flate.reader;
479488
},
480489
.gzip => {
481-
decompressor.* = .{ .flate = .init(reader.in, .gzip, decompression_buffer) };
482-
return &decompressor.flate.reader;
490+
decompress.* = .{ .flate = .init(reader.in, .gzip, decompress_buffer) };
491+
return &decompress.flate.reader;
483492
},
484493
.zstd => {
485-
decompressor.* = .{ .zstd = .init(reader.in, decompression_buffer, .{ .verify_checksum = false }) };
486-
return &decompressor.zstd.reader;
494+
decompress.* = .{ .zstd = .init(reader.in, decompress_buffer, .{ .verify_checksum = false }) };
495+
return &decompress.zstd.reader;
487496
},
488497
.compress => unreachable,
489498
}
490499
}
491-
const transfer_reader = bodyReader(reader, &.{}, transfer_encoding, content_length);
492-
return decompressor.init(transfer_reader, decompression_buffer, content_encoding);
500+
const transfer_reader = bodyReader(reader, transfer_buffer, transfer_encoding, content_length);
501+
return decompress.init(transfer_reader, decompress_buffer, content_encoding);
493502
}
494503

495504
fn contentLengthStream(
@@ -691,33 +700,33 @@ pub const Reader = struct {
691700
}
692701
};
693702

694-
pub const Decompressor = union(enum) {
703+
pub const Decompress = union(enum) {
695704
flate: std.compress.flate.Decompress,
696705
zstd: std.compress.zstd.Decompress,
697706
none: *std.Io.Reader,
698707

699708
pub fn init(
700-
decompressor: *Decompressor,
709+
decompress: *Decompress,
701710
transfer_reader: *std.Io.Reader,
702711
buffer: []u8,
703712
content_encoding: ContentEncoding,
704713
) *std.Io.Reader {
705714
switch (content_encoding) {
706715
.identity => {
707-
decompressor.* = .{ .none = transfer_reader };
716+
decompress.* = .{ .none = transfer_reader };
708717
return transfer_reader;
709718
},
710719
.deflate => {
711-
decompressor.* = .{ .flate = .init(transfer_reader, .zlib, buffer) };
712-
return &decompressor.flate.reader;
720+
decompress.* = .{ .flate = .init(transfer_reader, .zlib, buffer) };
721+
return &decompress.flate.reader;
713722
},
714723
.gzip => {
715-
decompressor.* = .{ .flate = .init(transfer_reader, .gzip, buffer) };
716-
return &decompressor.flate.reader;
724+
decompress.* = .{ .flate = .init(transfer_reader, .gzip, buffer) };
725+
return &decompress.flate.reader;
717726
},
718727
.zstd => {
719-
decompressor.* = .{ .zstd = .init(transfer_reader, buffer, .{ .verify_checksum = false }) };
720-
return &decompressor.zstd.reader;
728+
decompress.* = .{ .zstd = .init(transfer_reader, buffer, .{ .verify_checksum = false }) };
729+
return &decompress.zstd.reader;
721730
},
722731
.compress => unreachable,
723732
}
@@ -794,7 +803,7 @@ pub const BodyWriter = struct {
794803
}
795804

796805
/// When using content-length, asserts that the amount of data sent matches
797-
/// the value sent in the header, then flushes.
806+
/// the value sent in the header, then flushes `http_protocol_output`.
798807
///
799808
/// When using transfer-encoding: chunked, writes the end-of-stream message
800809
/// with empty trailers, then flushes the stream to the system. Asserts any
@@ -818,10 +827,13 @@ pub const BodyWriter = struct {
818827
///
819828
/// Respects the value of `isEliding` to omit all data after the headers.
820829
///
830+
/// Does not flush `http_protocol_output`, but does flush `writer`.
831+
///
821832
/// See also:
822833
/// * `end`
823834
/// * `endChunked`
824835
pub fn endUnflushed(w: *BodyWriter) Error!void {
836+
try w.writer.flush();
825837
switch (w.state) {
826838
.end => unreachable,
827839
.content_length => |len| {

0 commit comments

Comments
 (0)