Skip to content

Commit 8b15475

Browse files
[Zstandard] Decompress even when discord doesn't encode size information
1 parent 0ace5f8 commit 8b15475

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

discord/utils.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,19 +81,20 @@
8181
else:
8282
HAS_ORJSON = True
8383

84+
_ZSTD_SOURCE: Literal['zstandard', 'compression.zstd'] | None = None
85+
8486
try:
8587
from zstandard import ZstdDecompressor # type: ignore
8688

87-
_HAS_ZSTD = True
89+
_ZSTD_SOURCE = 'zstandard'
8890
except ImportError:
8991
try:
9092
from compression.zstd import ZstdDecompressor # type: ignore
93+
94+
_ZSTD_SOURCE = 'compression.zstd'
9195
except ImportError:
9296
import zlib
9397

94-
_HAS_ZSTD = False
95-
else:
96-
_HAS_ZSTD = True
9798

9899
__all__ = (
99100
'oauth_url',
@@ -1432,7 +1433,7 @@ def _human_join(seq: Sequence[str], /, *, delimiter: str = ', ', final: str = 'o
14321433
return delimiter.join(seq[:-1]) + f' {final} {seq[-1]}'
14331434

14341435

1435-
if _HAS_ZSTD:
1436+
if _ZSTD_SOURCE is not None:
14361437

14371438
class _ZstdDecompressionContext:
14381439
__slots__ = ('decompressor',)
@@ -1441,6 +1442,12 @@ class _ZstdDecompressionContext:
14411442

14421443
def __init__(self) -> None:
14431444
self.decompressor = ZstdDecompressor()
1445+
if _ZSTD_SOURCE == 'zstandard':
1446+
# The default API for zstandard requires a size hint when
1447+
# the size is not included in the zstandard frame.
1448+
# This constructs an instance of zstandard.ZstdDecompressionObj
1449+
# which dynamically allocates a buffer, matching stdlib module's behavior.
1450+
self.decompressor = self.decompressor.decompressobj()
14441451

14451452
def decompress(self, data: bytes, /) -> str | None:
14461453
# Each WS message is a complete gateway message

0 commit comments

Comments
 (0)