Skip to content

Commit 856938d

Browse files
committed
Fix Zstd.decompress glitches
Signed-off-by: Hiroshi Hatake <[email protected]>
1 parent 10dfa82 commit 856938d

File tree

1 file changed

+70
-40
lines changed

1 file changed

+70
-40
lines changed

ext/zstdruby/zstdruby.c

Lines changed: 70 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -39,61 +39,91 @@ static VALUE rb_compress(int argc, VALUE *argv, VALUE self)
3939
return output;
4040
}
4141

42-
static VALUE decompress_buffered(ZSTD_DCtx* dctx, const char* input_data, size_t input_size)
43-
{
44-
ZSTD_inBuffer input = { input_data, input_size, 0 };
45-
VALUE result = rb_str_new(0, 0);
42+
static VALUE decode_one_frame(ZSTD_DCtx* dctx, const unsigned char* src, size_t size, VALUE kwargs) {
43+
VALUE out = rb_str_buf_new(0);
44+
size_t cap = ZSTD_DStreamOutSize();
45+
char *buf = ALLOC_N(char, cap);
46+
ZSTD_inBuffer in = (ZSTD_inBuffer){ src, size, 0 };
4647

47-
while (input.pos < input.size) {
48-
ZSTD_outBuffer output = { NULL, 0, 0 };
49-
output.size += ZSTD_DStreamOutSize();
50-
VALUE output_string = rb_str_new(NULL, output.size);
51-
output.dst = RSTRING_PTR(output_string);
48+
ZSTD_DCtx_reset(dctx, ZSTD_reset_session_only);
49+
set_decompress_params(dctx, kwargs);
5250

53-
size_t ret = zstd_stream_decompress(dctx, &output, &input, false);
51+
for (;;) {
52+
ZSTD_outBuffer o = (ZSTD_outBuffer){ buf, cap, 0 };
53+
size_t ret = ZSTD_decompressStream(dctx, &o, &in);
5454
if (ZSTD_isError(ret)) {
55-
ZSTD_freeDCtx(dctx);
56-
rb_raise(rb_eRuntimeError, "%s: %s", "ZSTD_decompressStream failed", ZSTD_getErrorName(ret));
55+
xfree(buf);
56+
rb_raise(rb_eRuntimeError, "ZSTD_decompressStream failed: %s", ZSTD_getErrorName(ret));
57+
}
58+
if (o.pos) {
59+
rb_str_cat(out, buf, o.pos);
60+
}
61+
if (ret == 0) {
62+
break;
5763
}
58-
rb_str_cat(result, output.dst, output.pos);
5964
}
60-
ZSTD_freeDCtx(dctx);
61-
return result;
65+
xfree(buf);
66+
return out;
67+
}
68+
69+
static VALUE decompress_buffered(ZSTD_DCtx* dctx, const char* data, size_t len) {
70+
return decode_one_frame(dctx, (const unsigned char*)data, len, Qnil);
6271
}
6372

6473
static VALUE rb_decompress(int argc, VALUE *argv, VALUE self)
6574
{
66-
VALUE input_value;
67-
VALUE kwargs;
75+
VALUE input_value, kwargs;
6876
rb_scan_args(argc, argv, "10:", &input_value, &kwargs);
6977
StringValue(input_value);
70-
char* input_data = RSTRING_PTR(input_value);
71-
size_t input_size = RSTRING_LEN(input_value);
72-
ZSTD_DCtx* const dctx = ZSTD_createDCtx();
73-
if (dctx == NULL) {
74-
rb_raise(rb_eRuntimeError, "%s", "ZSTD_createDCtx failed");
75-
}
76-
set_decompress_params(dctx, kwargs);
7778

78-
unsigned long long const uncompressed_size = ZSTD_getFrameContentSize(input_data, input_size);
79-
if (uncompressed_size == ZSTD_CONTENTSIZE_ERROR) {
80-
rb_raise(rb_eRuntimeError, "%s: %s", "not compressed by zstd", ZSTD_getErrorName(uncompressed_size));
81-
}
82-
// ZSTD_decompressStream may be called multiple times when ZSTD_CONTENTSIZE_UNKNOWN, causing slowness.
83-
// Therefore, we will not standardize on ZSTD_decompressStream
84-
if (uncompressed_size == ZSTD_CONTENTSIZE_UNKNOWN) {
85-
return decompress_buffered(dctx, input_data, input_size);
86-
}
79+
size_t in_size = RSTRING_LEN(input_value);
80+
const unsigned char *in_r = (const unsigned char *)RSTRING_PTR(input_value);
81+
unsigned char *in = ALLOC_N(unsigned char, in_size);
82+
memcpy(in, in_r, in_size);
83+
84+
size_t off = 0;
85+
const uint32_t ZSTD_MAGIC = 0xFD2FB528U;
86+
const uint32_t SKIP_LO = 0x184D2A50U; /* ...5F */
87+
88+
while (off + 4 <= in_size) {
89+
uint32_t magic = (uint32_t)in[off]
90+
| ((uint32_t)in[off+1] << 8)
91+
| ((uint32_t)in[off+2] << 16)
92+
| ((uint32_t)in[off+3] << 24);
93+
94+
if ((magic & 0xFFFFFFF0U) == (SKIP_LO & 0xFFFFFFF0U)) {
95+
if (off + 8 > in_size) break;
96+
uint32_t skipLen = (uint32_t)in[off+4]
97+
| ((uint32_t)in[off+5] << 8)
98+
| ((uint32_t)in[off+6] << 16)
99+
| ((uint32_t)in[off+7] << 24);
100+
size_t adv = (size_t)8 + (size_t)skipLen;
101+
if (off + adv > in_size) break;
102+
off += adv;
103+
continue;
104+
}
87105

88-
VALUE output = rb_str_new(NULL, uncompressed_size);
89-
char* output_data = RSTRING_PTR(output);
106+
if (magic == ZSTD_MAGIC) {
107+
ZSTD_DCtx *dctx = ZSTD_createDCtx();
108+
if (!dctx) {
109+
xfree(in);
110+
rb_raise(rb_eRuntimeError, "ZSTD_createDCtx failed");
111+
}
112+
113+
VALUE out = decode_one_frame(dctx, in + off, in_size - off, kwargs);
90114

91-
size_t const decompress_size = zstd_decompress(dctx, output_data, uncompressed_size, input_data, input_size, false);
92-
if (ZSTD_isError(decompress_size)) {
93-
rb_raise(rb_eRuntimeError, "%s: %s", "decompress error", ZSTD_getErrorName(decompress_size));
115+
ZSTD_freeDCtx(dctx);
116+
xfree(in);
117+
RB_GC_GUARD(input_value);
118+
return out;
119+
}
120+
121+
off += 1;
94122
}
95-
ZSTD_freeDCtx(dctx);
96-
return output;
123+
124+
xfree(in);
125+
RB_GC_GUARD(input_value);
126+
rb_raise(rb_eRuntimeError, "not a zstd frame (magic not found)");
97127
}
98128

99129
static void free_cdict(void *dict)

0 commit comments

Comments
 (0)