1+ from __future__ import annotations
2+
13import gzip
24import io
35import struct
46
7+ from typing_extensions import Buffer
8+
# Field values of the xerial snappy stream header, in the same order as the
# struct codes in _XERIAL_V1_FORMAT (one value per code).
_XERIAL_V1_HEADER = (-126, b"S", b"N", b"A", b"P", b"P", b"Y", 0, 1, 1)
# struct codes for the header fields; packed with "!" this is 16 bytes.
_XERIAL_V1_FORMAT = "bccccccBii"
# 1 MiB.
ZSTD_MAX_OUTPUT_SIZE = 1024 * 1024
1216 cramjam = None
1317
1418
def has_gzip() -> bool:
    """Report whether the gzip codec is available.

    Always returns True.
    """
    return True
1721
1822
def has_snappy() -> bool:
    """Report whether the snappy codec is available.

    Returns True when the module-level ``cramjam`` name is not None.
    """
    return cramjam is not None
2125
2226
def has_zstd() -> bool:
    """Report whether the zstd codec is available.

    Returns True when the module-level ``cramjam`` name is not None.
    """
    return cramjam is not None
2529
2630
def has_lz4() -> bool:
    """Report whether the LZ4 codec is available.

    Returns True when the module-level ``cramjam`` name is not None.
    """
    return cramjam is not None
2933
3034
31- def gzip_encode (payload , compresslevel = None ):
35+ def gzip_encode (payload : Buffer , compresslevel : int | None = None ) -> bytes :
3236 if not compresslevel :
3337 compresslevel = 9
3438
@@ -45,7 +49,7 @@ def gzip_encode(payload, compresslevel=None):
4549 return buf .getvalue ()
4650
4751
48- def gzip_decode (payload ) :
52+ def gzip_decode (payload : Buffer ) -> bytes :
4953 buf = io .BytesIO (payload )
5054
5155 # Gzip context manager introduced in python 2.7
@@ -57,7 +61,9 @@ def gzip_decode(payload):
5761 gzipper .close ()
5862
5963
60- def snappy_encode (payload , xerial_compatible = True , xerial_blocksize = 32 * 1024 ):
64+ def snappy_encode (
65+ payload : Buffer , xerial_compatible : bool = True , xerial_blocksize : int = 32 * 1024
66+ ) -> bytes :
6167 """Encodes the given data with snappy compression.
6268
6369 If xerial_compatible is set then the stream is encoded in a fashion
@@ -93,12 +99,9 @@ def snappy_encode(payload, xerial_compatible=True, xerial_blocksize=32 * 1024):
9399 for fmt , dat in zip (_XERIAL_V1_FORMAT , _XERIAL_V1_HEADER ):
94100 out .write (struct .pack ("!" + fmt , dat ))
95101
96- # Chunk through buffers to avoid creating intermediate slice copies
97- def chunker (payload , i , size ):
98- return memoryview (payload )[i : size + i ]
99-
102+ payload = memoryview (payload )
100103 for chunk in (
101- chunker ( payload , i , xerial_blocksize )
104+ payload [ i : i + xerial_blocksize ]
102105 for i in range (0 , len (payload ), xerial_blocksize )
103106 ):
104107 block = cramjam .snappy .compress_raw (chunk )
@@ -109,7 +112,7 @@ def chunker(payload, i, size):
109112 return out .getvalue ()
110113
111114
112- def _detect_xerial_stream (payload ) :
115+ def _detect_xerial_stream (payload : Buffer ) -> bool :
113116 """Detects if the data given might have been encoded with the blocking mode
114117 of the xerial snappy library.
115118
@@ -131,20 +134,21 @@ def _detect_xerial_stream(payload):
131134 1.
132135 """
133136
137+ payload = memoryview (payload )
134138 if len (payload ) > 16 :
135- header = struct .unpack ("!" + _XERIAL_V1_FORMAT , memoryview ( payload ) [:16 ])
139+ header = struct .unpack ("!" + _XERIAL_V1_FORMAT , payload [:16 ])
136140 return header == _XERIAL_V1_HEADER
137141 return False
138142
139143
140- def snappy_decode (payload ) :
144+ def snappy_decode (payload : Buffer ) -> bytes :
141145 if not has_snappy ():
142146 raise NotImplementedError ("Snappy codec is not available" )
143147
144148 if _detect_xerial_stream (payload ):
145149 # TODO ? Should become a fileobj ?
146150 out = io .BytesIO ()
147- byt = payload [16 :]
151+ byt = memoryview ( payload ) [16 :]
148152 length = len (byt )
149153 cursor = 0
150154
@@ -162,7 +166,7 @@ def snappy_decode(payload):
162166 return bytes (cramjam .snappy .decompress_raw (payload ))
163167
164168
165- def lz4_encode (payload , level = 9 ) :
169+ def lz4_encode (payload : Buffer , level : int = 9 ) -> bytes :
166170 # level=9 is used by default by broker itself
167171 # https://cwiki.apache.org/confluence/display/KAFKA/KIP-390%3A+Support+Compression+Level
168172 if not has_lz4 ():
@@ -177,14 +181,14 @@ def lz4_encode(payload, level=9):
177181 return bytes (compressor .finish ())
178182
179183
def lz4_decode(payload: Buffer) -> bytes:
    """Decompress an LZ4-compressed payload.

    Args:
        payload: The compressed data, passed to ``cramjam.lz4.decompress``.

    Returns:
        The decompressed data, converted to ``bytes``.

    Raises:
        NotImplementedError: If ``has_lz4()`` reports the codec is unavailable.
    """
    if not has_lz4():
        raise NotImplementedError("LZ4 codec is not available")

    return bytes(cramjam.lz4.decompress(payload))
185189
186190
187- def zstd_encode (payload , level = None ):
191+ def zstd_encode (payload : Buffer , level : int | None = None ) -> bytes :
188192 if not has_zstd ():
189193 raise NotImplementedError ("Zstd codec is not available" )
190194
@@ -196,7 +200,7 @@ def zstd_encode(payload, level=None):
196200 return bytes (cramjam .zstd .compress (payload , level = level ))
197201
198202
199- def zstd_decode (payload ) :
203+ def zstd_decode (payload : Buffer ) -> bytes :
200204 if not has_zstd ():
201205 raise NotImplementedError ("Zstd codec is not available" )
202206
0 commit comments