Skip to content

Commit 3d74772

Browse files
authored
Make Zlib.Compressor/Decompressor classes (#1769)
Motivation: z_stream stores a pointer to itself in its internal state which it checks against in inflate/deflate. As we hold these within structs, and call through to C functions which take a pointer to a z_stream, this address can change as the struct is copied about. This results in errors when calling deflate/inflate. Modifications: - Hold a pointer to the z_stream Result: Harder to misue compressor/decompressor
1 parent f7f3d2d commit 3d74772

File tree

2 files changed

+78
-98
lines changed

2 files changed

+78
-98
lines changed

Sources/GRPCHTTP2Core/Compression/Zlib.swift

Lines changed: 73 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -37,29 +37,18 @@ enum Zlib {
3737
extension Zlib {
3838
/// Creates a new compressor for the given compression format.
3939
///
40-
/// This compressor is only suitable for compressing whole messages at a time. Callers
41-
/// must ``initialize()`` the compressor before using it.
40+
/// This compressor is only suitable for compressing whole messages at a time.
4241
struct Compressor {
43-
private var stream: z_stream
42+
// TODO: Make this ~Copyable when 5.9 is the lowest supported Swift version.
43+
44+
private var stream: UnsafeMutablePointer<z_stream>
4445
private let method: Method
45-
private var isInitialized = false
4646

4747
init(method: Method) {
4848
self.method = method
49-
self.stream = z_stream()
50-
}
51-
52-
/// Initialize the compressor.
53-
mutating func initialize() {
54-
precondition(!self.isInitialized)
49+
self.stream = .allocate(capacity: 1)
50+
self.stream.initialize(to: z_stream())
5551
self.stream.deflateInit(windowBits: self.method.windowBits)
56-
self.isInitialized = true
57-
}
58-
59-
static func initialized(_ method: Method) -> Self {
60-
var compressor = Compressor(method: method)
61-
compressor.initialize()
62-
return compressor
6352
}
6453

6554
/// Compresses the data in `input` into the `output` buffer.
@@ -68,77 +57,73 @@ extension Zlib {
6857
/// - Parameter output: The `ByteBuffer` into which the compressed message should be written.
6958
/// - Returns: The number of bytes written into the `output` buffer.
7059
@discardableResult
71-
mutating func compress(_ input: [UInt8], into output: inout ByteBuffer) throws -> Int {
72-
precondition(self.isInitialized)
60+
func compress(_ input: [UInt8], into output: inout ByteBuffer) throws -> Int {
7361
defer { self.reset() }
7462
let upperBound = self.stream.deflateBound(inputBytes: input.count)
7563
return try self.stream.deflate(input, into: &output, upperBound: upperBound)
7664
}
7765

7866
/// Resets compression state.
79-
private mutating func reset() {
67+
private func reset() {
8068
do {
8169
try self.stream.deflateReset()
8270
} catch {
8371
self.end()
84-
self.stream = z_stream()
72+
self.stream.initialize(to: z_stream())
8573
self.stream.deflateInit(windowBits: self.method.windowBits)
8674
}
8775
}
8876

8977
/// Deallocates any resources allocated by Zlib.
90-
mutating func end() {
78+
func end() {
9179
self.stream.deflateEnd()
80+
self.stream.deallocate()
9281
}
9382
}
9483
}
9584

9685
extension Zlib {
9786
/// Creates a new decompressor for the given compression format.
9887
///
99-
/// This decompressor is only suitable for compressing whole messages at a time. Callers
100-
/// must ``initialize()`` the decompressor before using it.
88+
/// This decompressor is only suitable for compressing whole messages at a time.
10189
struct Decompressor {
102-
private var stream: z_stream
90+
// TODO: Make this ~Copyable when 5.9 is the lowest supported Swift version.
91+
92+
private var stream: UnsafeMutablePointer<z_stream>
10393
private let method: Method
104-
private var isInitialized = false
10594

10695
init(method: Method) {
10796
self.method = method
108-
self.stream = z_stream()
109-
}
110-
111-
mutating func initialize() {
112-
precondition(!self.isInitialized)
97+
self.stream = UnsafeMutablePointer.allocate(capacity: 1)
98+
self.stream.initialize(to: z_stream())
11399
self.stream.inflateInit(windowBits: self.method.windowBits)
114-
self.isInitialized = true
115100
}
116101

117102
/// Returns the decompressed bytes from ``input``.
118103
///
119104
/// - Parameters:
120105
/// - input: The buffer read compressed bytes from.
121106
/// - limit: The largest size a decompressed payload may be.
122-
mutating func decompress(_ input: inout ByteBuffer, limit: Int) throws -> [UInt8] {
123-
precondition(self.isInitialized)
107+
func decompress(_ input: inout ByteBuffer, limit: Int) throws -> [UInt8] {
124108
defer { self.reset() }
125109
return try self.stream.inflate(input: &input, limit: limit)
126110
}
127111

128112
/// Resets decompression state.
129-
private mutating func reset() {
113+
private func reset() {
130114
do {
131115
try self.stream.inflateReset()
132116
} catch {
133117
self.end()
134-
self.stream = z_stream()
118+
self.stream.initialize(to: z_stream())
135119
self.stream.inflateInit(windowBits: self.method.windowBits)
136120
}
137121
}
138122

139123
/// Deallocates any resources allocated by Zlib.
140-
mutating func end() {
124+
func end() {
141125
self.stream.inflateEnd()
126+
self.stream.deallocate()
142127
}
143128
}
144129
}
@@ -155,13 +140,13 @@ struct ZlibError: Error, Hashable {
155140
}
156141
}
157142

158-
extension z_stream {
159-
mutating func inflateInit(windowBits: Int32) {
160-
self.zfree = nil
161-
self.zalloc = nil
162-
self.opaque = nil
143+
extension UnsafeMutablePointer<z_stream> {
144+
func inflateInit(windowBits: Int32) {
145+
self.pointee.zfree = nil
146+
self.pointee.zalloc = nil
147+
self.pointee.opaque = nil
163148

164-
let rc = CGRPCZlib_inflateInit2(&self, windowBits)
149+
let rc = CGRPCZlib_inflateInit2(self, windowBits)
165150
// Possible return codes:
166151
// - Z_OK
167152
// - Z_MEM_ERROR: not enough memory
@@ -171,8 +156,8 @@ extension z_stream {
171156
precondition(rc == Z_OK, "inflateInit2 failed with error (\(rc)) \(self.lastError ?? "")")
172157
}
173158

174-
mutating func inflateReset() throws {
175-
let rc = CGRPCZlib_inflateReset(&self)
159+
func inflateReset() throws {
160+
let rc = CGRPCZlib_inflateReset(self)
176161

177162
// Possible return codes:
178163
// - Z_OK
@@ -187,17 +172,17 @@ extension z_stream {
187172
}
188173
}
189174

190-
mutating func inflateEnd() {
191-
_ = CGRPCZlib_inflateEnd(&self)
175+
func inflateEnd() {
176+
_ = CGRPCZlib_inflateEnd(self)
192177
}
193178

194-
mutating func deflateInit(windowBits: Int32) {
195-
self.zfree = nil
196-
self.zalloc = nil
197-
self.opaque = nil
179+
func deflateInit(windowBits: Int32) {
180+
self.pointee.zfree = nil
181+
self.pointee.zalloc = nil
182+
self.pointee.opaque = nil
198183

199184
let rc = CGRPCZlib_deflateInit2(
200-
&self,
185+
self,
201186
Z_DEFAULT_COMPRESSION, // compression level
202187
Z_DEFLATED, // compression method (this must be Z_DEFLATED)
203188
windowBits, // window size, i.e. deflate/gzip
@@ -215,8 +200,8 @@ extension z_stream {
215200
precondition(rc == Z_OK, "deflateInit2 failed with error (\(rc)) \(self.lastError ?? "")")
216201
}
217202

218-
mutating func deflateReset() throws {
219-
let rc = CGRPCZlib_deflateReset(&self)
203+
func deflateReset() throws {
204+
let rc = CGRPCZlib_deflateReset(self)
220205

221206
// Possible return codes:
222207
// - Z_OK
@@ -231,87 +216,87 @@ extension z_stream {
231216
}
232217
}
233218

234-
mutating func deflateEnd() {
235-
_ = CGRPCZlib_deflateEnd(&self)
219+
func deflateEnd() {
220+
_ = CGRPCZlib_deflateEnd(self)
236221
}
237222

238-
mutating func deflateBound(inputBytes: Int) -> Int {
239-
let bound = CGRPCZlib_deflateBound(&self, UInt(inputBytes))
223+
func deflateBound(inputBytes: Int) -> Int {
224+
let bound = CGRPCZlib_deflateBound(self, UInt(inputBytes))
240225
return Int(bound)
241226
}
242227

243-
mutating func setNextInputBuffer(_ buffer: UnsafeMutableBufferPointer<UInt8>) {
228+
func setNextInputBuffer(_ buffer: UnsafeMutableBufferPointer<UInt8>) {
244229
if let baseAddress = buffer.baseAddress {
245-
self.next_in = baseAddress
246-
self.avail_in = UInt32(buffer.count)
230+
self.pointee.next_in = baseAddress
231+
self.pointee.avail_in = UInt32(buffer.count)
247232
} else {
248-
self.next_in = nil
249-
self.avail_in = 0
233+
self.pointee.next_in = nil
234+
self.pointee.avail_in = 0
250235
}
251236
}
252237

253-
mutating func setNextInputBuffer(_ buffer: UnsafeMutableRawBufferPointer?) {
238+
func setNextInputBuffer(_ buffer: UnsafeMutableRawBufferPointer?) {
254239
if let buffer = buffer, let baseAddress = buffer.baseAddress {
255-
self.next_in = CGRPCZlib_castVoidToBytefPointer(baseAddress)
256-
self.avail_in = UInt32(buffer.count)
240+
self.pointee.next_in = CGRPCZlib_castVoidToBytefPointer(baseAddress)
241+
self.pointee.avail_in = UInt32(buffer.count)
257242
} else {
258-
self.next_in = nil
259-
self.avail_in = 0
243+
self.pointee.next_in = nil
244+
self.pointee.avail_in = 0
260245
}
261246
}
262247

263-
mutating func setNextOutputBuffer(_ buffer: UnsafeMutableBufferPointer<UInt8>) {
248+
func setNextOutputBuffer(_ buffer: UnsafeMutableBufferPointer<UInt8>) {
264249
if let baseAddress = buffer.baseAddress {
265-
self.next_out = baseAddress
266-
self.avail_out = UInt32(buffer.count)
250+
self.pointee.next_out = baseAddress
251+
self.pointee.avail_out = UInt32(buffer.count)
267252
} else {
268-
self.next_out = nil
269-
self.avail_out = 0
253+
self.pointee.next_out = nil
254+
self.pointee.avail_out = 0
270255
}
271256
}
272257

273-
mutating func setNextOutputBuffer(_ buffer: UnsafeMutableRawBufferPointer?) {
258+
func setNextOutputBuffer(_ buffer: UnsafeMutableRawBufferPointer?) {
274259
if let buffer = buffer, let baseAddress = buffer.baseAddress {
275-
self.next_out = CGRPCZlib_castVoidToBytefPointer(baseAddress)
276-
self.avail_out = UInt32(buffer.count)
260+
self.pointee.next_out = CGRPCZlib_castVoidToBytefPointer(baseAddress)
261+
self.pointee.avail_out = UInt32(buffer.count)
277262
} else {
278-
self.next_out = nil
279-
self.avail_out = 0
263+
self.pointee.next_out = nil
264+
self.pointee.avail_out = 0
280265
}
281266
}
282267

283268
/// Number of bytes available to read `self.nextInputBuffer`. See also: `z_stream.avail_in`.
284269
var availableInputBytes: Int {
285270
get {
286-
Int(self.avail_in)
271+
Int(self.pointee.avail_in)
287272
}
288273
set {
289-
self.avail_in = UInt32(newValue)
274+
self.pointee.avail_in = UInt32(newValue)
290275
}
291276
}
292277

293278
/// The remaining writable space in `nextOutputBuffer`. See also: `z_stream.avail_out`.
294279
var availableOutputBytes: Int {
295280
get {
296-
Int(self.avail_out)
281+
Int(self.pointee.avail_out)
297282
}
298283
set {
299-
self.avail_out = UInt32(newValue)
284+
self.pointee.avail_out = UInt32(newValue)
300285
}
301286
}
302287

303288
/// The total number of bytes written to the output buffer. See also: `z_stream.total_out`.
304289
var totalOutputBytes: Int {
305-
Int(self.total_out)
290+
Int(self.pointee.total_out)
306291
}
307292

308293
/// The last error message that zlib wrote. No message is guaranteed on error, however, `nil` is
309294
/// guaranteed if there is no error. See also `z_stream.msg`.
310295
var lastError: String? {
311-
self.msg.map { String(cString: $0) }
296+
self.pointee.msg.map { String(cString: $0) }
312297
}
313298

314-
mutating func inflate(input: inout ByteBuffer, limit: Int) throws -> [UInt8] {
299+
func inflate(input: inout ByteBuffer, limit: Int) throws -> [UInt8] {
315300
return try input.readWithUnsafeMutableReadableBytes { inputPointer in
316301
self.setNextInputBuffer(inputPointer)
317302
defer {
@@ -342,7 +327,7 @@ extension z_stream {
342327
//
343328
// Note that Z_OK is not okay here since we always flush with Z_FINISH and therefore
344329
// use Z_STREAM_END as our success criteria.
345-
let rc = CGRPCZlib_inflate(&self, Z_FINISH)
330+
let rc = CGRPCZlib_inflate(self, Z_FINISH)
346331
switch rc {
347332
case Z_STREAM_END:
348333
finished = true
@@ -377,7 +362,7 @@ extension z_stream {
377362
}
378363
}
379364

380-
mutating func deflate(
365+
func deflate(
381366
_ input: [UInt8],
382367
into output: inout ByteBuffer,
383368
upperBound: Int
@@ -394,7 +379,7 @@ extension z_stream {
394379
return try output.writeWithUnsafeMutableBytes(minimumWritableBytes: upperBound) { output in
395380
self.setNextOutputBuffer(output)
396381

397-
let rc = CGRPCZlib_deflate(&self, Z_FINISH)
382+
let rc = CGRPCZlib_deflate(self, Z_FINISH)
398383

399384
// Possible return codes:
400385
// - Z_OK: some progress has been made

Tests/GRPCHTTP2CoreTests/Server/Compression/ZlibTests.swift

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,7 @@ final class ZlibTests: XCTestCase {
3131
"""
3232

3333
private func compress(_ input: [UInt8], method: Zlib.Method) throws -> ByteBuffer {
34-
var compressor = Zlib.Compressor(method: method)
35-
compressor.initialize()
34+
let compressor = Zlib.Compressor(method: method)
3635
defer { compressor.end() }
3736

3837
var buffer = ByteBuffer()
@@ -45,8 +44,7 @@ final class ZlibTests: XCTestCase {
4544
method: Zlib.Method,
4645
limit: Int = .max
4746
) throws -> [UInt8] {
48-
var decompressor = Zlib.Decompressor(method: method)
49-
decompressor.initialize()
47+
let decompressor = Zlib.Decompressor(method: method)
5048
defer { decompressor.end() }
5149

5250
var input = input
@@ -69,8 +67,7 @@ final class ZlibTests: XCTestCase {
6967

7068
func testRepeatedCompresses() throws {
7169
let original = Array(self.text.utf8)
72-
var compressor = Zlib.Compressor(method: .deflate)
73-
compressor.initialize()
70+
let compressor = Zlib.Compressor(method: .deflate)
7471
defer { compressor.end() }
7572

7673
var compressed = ByteBuffer()
@@ -86,8 +83,7 @@ final class ZlibTests: XCTestCase {
8683

8784
func testRepeatedDecompresses() throws {
8885
let original = Array(self.text.utf8)
89-
var decompressor = Zlib.Decompressor(method: .deflate)
90-
decompressor.initialize()
86+
let decompressor = Zlib.Decompressor(method: .deflate)
9187
defer { decompressor.end() }
9288

9389
let compressed = try self.compress(original, method: .deflate)
@@ -123,8 +119,7 @@ final class ZlibTests: XCTestCase {
123119
}
124120

125121
func testCompressAppendsToBuffer() throws {
126-
var compressor = Zlib.Compressor(method: .deflate)
127-
compressor.initialize()
122+
let compressor = Zlib.Compressor(method: .deflate)
128123
defer { compressor.end() }
129124

130125
var buffer = ByteBuffer()

0 commit comments

Comments
 (0)