Skip to content

Commit b062fd5

Browse files
committed
Make namedtuples dataclasses
1 parent 5964643 commit b062fd5

File tree

2 files changed

+28
-20
lines changed

2 files changed

+28
-20
lines changed

Lib/compression/zstd/__init__.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@
2727
"ZstdFile",
2828
)
2929

30-
from collections import namedtuple
31-
from enum import IntEnum
32-
from functools import lru_cache
30+
import enum
31+
import functools
32+
import dataclasses
3333

3434
from compression.zstd.zstdfile import ZstdFile, open
3535
from _zstd import *
@@ -43,14 +43,19 @@
4343
_finalize_dict = _zstd._finalize_dict
4444

4545

46-
# TODO(emmatyping): these should be dataclasses or some other class, not namedtuples
46+
@dataclasses.dataclass(frozen=True)
47+
class _CompressionLevelValues:
48+
default: int
49+
min: int
50+
max: int
4751

48-
# compressionLevel_values
49-
_nt_values = namedtuple("values", ["default", "min", "max"])
50-
compressionLevel_values = _nt_values(*_zstd._compressionLevel_values)
52+
compressionLevel_values = _CompressionLevelValues(*_zstd._compressionLevel_values)
5153

52-
53-
_nt_frame_info = namedtuple("frame_info", ["decompressed_size", "dictionary_id"])
54+
@dataclasses.dataclass(frozen=True)
55+
class FrameInfo:
56+
"""A dataclass storing information about a Zstandard frame."""
57+
decompressed_size: int
58+
dictionary_id: int
5459

5560

5661
def get_frame_info(frame_buffer):
@@ -61,18 +66,20 @@ def get_frame_info(frame_buffer):
6166
a frame, and needs to include at least the frame header (6 to
6267
18 bytes).
6368
64-
Return a two-items namedtuple: (decompressed_size, dictionary_id)
69+
Return a FrameInfo dataclass, which currently has two attributes
70+
71+
'decompressed_size' is the size in bytes of the data in the frame when
72+
decompressed.
6573
6674
If decompressed_size is None, decompressed size is unknown.
6775
68-
dictionary_id is a 32-bit unsigned integer value. 0 means dictionary ID was
76+
'dictionary_id' is a 32-bit unsigned integer value. 0 means dictionary ID was
6977
not recorded in the frame header, the frame may or may not need a dictionary
7078
to be decoded, and the ID of such a dictionary is not specified.
71-
72-
It's possible to append more items to the namedtuple in the future."""
79+
"""
7380

7481
ret_tuple = _zstd._get_frame_info(frame_buffer)
75-
return _nt_frame_info(*ret_tuple)
82+
return FrameInfo(*ret_tuple)
7683

7784

7885
def _nbytes(dat):
@@ -215,7 +222,7 @@ def __get__(self, *_, **__):
215222
raise NotImplementedError(msg)
216223

217224

218-
class CParameter(IntEnum):
225+
class CParameter(enum.IntEnum):
219226
"""Compression parameters"""
220227

221228
compressionLevel = _zstd._ZSTD_c_compressionLevel
@@ -243,26 +250,26 @@ class CParameter(IntEnum):
243250
jobSize = _zstd._ZSTD_c_jobSize
244251
overlapLog = _zstd._ZSTD_c_overlapLog
245252

246-
@lru_cache(maxsize=None)
253+
@functools.lru_cache(maxsize=None)
247254
def bounds(self):
248255
"""Return lower and upper bounds of a compression parameter, both inclusive."""
249256
# 1 means compression parameter
250257
return _zstd._get_param_bounds(1, self.value)
251258

252259

253-
class DParameter(IntEnum):
260+
class DParameter(enum.IntEnum):
254261
"""Decompression parameters"""
255262

256263
windowLogMax = _zstd._ZSTD_d_windowLogMax
257264

258-
@lru_cache(maxsize=None)
265+
@functools.lru_cache(maxsize=None)
259266
def bounds(self):
260267
"""Return lower and upper bounds of a decompression parameter, both inclusive."""
261268
# 0 means decompression parameter
262269
return _zstd._get_param_bounds(0, self.value)
263270

264271

265-
class Strategy(IntEnum):
272+
class Strategy(enum.IntEnum):
266273
"""Compression strategies, listed from fastest to strongest.
267274
268275
Note : new strategies _might_ be added in the future, only the order

Lib/test/test_zstd/test_core.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,8 @@ def test_roundtrip_default(self):
149149

150150
def test_roundtrip_level(self):
151151
raw_dat = THIS_FILE_BYTES[: len(THIS_FILE_BYTES) // 6]
152-
_default, minv, maxv = compressionLevel_values
152+
minv = compressionLevel_values.min
153+
maxv = compressionLevel_values.max
153154

154155
for level in range(max(-20, minv), maxv + 1):
155156
dat1 = compress(raw_dat, level)

0 commit comments

Comments
 (0)