Skip to content

Commit 8f272d9

Browse files
authored
Merge pull request #116 from cosmo0920/fix-checksum-mismatch
Prevent GC compaction and unintended sweeps for preventing checksum mismatch
2 parents c888e6f + 856938d commit 8f272d9

File tree

3 files changed

+92
-54
lines changed

3 files changed

+92
-54
lines changed

ext/zstdruby/streaming_compress.c

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -105,13 +105,13 @@ static VALUE
105105
no_compress(struct streaming_compress_t* sc, ZSTD_EndDirective endOp)
106106
{
107107
ZSTD_inBuffer input = { NULL, 0, 0 };
108-
const char* output_data = RSTRING_PTR(sc->buf);
109108
VALUE result = rb_str_new(0, 0);
110109
size_t ret;
111110
do {
111+
const char* output_data = RSTRING_PTR(sc->buf);
112112
ZSTD_outBuffer output = { (void*)output_data, sc->buf_size, 0 };
113113

114-
size_t const ret = zstd_stream_compress(sc->ctx, &output, &input, endOp, false);
114+
ret = zstd_stream_compress(sc->ctx, &output, &input, endOp, false);
115115
if (ZSTD_isError(ret)) {
116116
rb_raise(rb_eRuntimeError, "flush error error code: %s", ZSTD_getErrorName(ret));
117117
}
@@ -131,9 +131,9 @@ rb_streaming_compress_compress(VALUE obj, VALUE src)
131131
struct streaming_compress_t* sc;
132132
TypedData_Get_Struct(obj, struct streaming_compress_t, &streaming_compress_type, sc);
133133

134-
const char* output_data = RSTRING_PTR(sc->buf);
135134
VALUE result = rb_str_new(0, 0);
136135
while (input.pos < input.size) {
136+
const char* output_data = RSTRING_PTR(sc->buf);
137137
ZSTD_outBuffer output = { (void*)output_data, sc->buf_size, 0 };
138138
size_t const ret = zstd_stream_compress(sc->ctx, &output, &input, ZSTD_e_continue, false);
139139
if (ZSTD_isError(ret)) {
@@ -150,7 +150,6 @@ rb_streaming_compress_write(int argc, VALUE *argv, VALUE obj)
150150
size_t total = 0;
151151
struct streaming_compress_t* sc;
152152
TypedData_Get_Struct(obj, struct streaming_compress_t, &streaming_compress_type, sc);
153-
const char* output_data = RSTRING_PTR(sc->buf);
154153

155154
while (argc-- > 0) {
156155
VALUE str = *argv++;
@@ -160,18 +159,20 @@ rb_streaming_compress_write(int argc, VALUE *argv, VALUE obj)
160159
ZSTD_inBuffer input = { input_data, input_size, 0 };
161160

162161
while (input.pos < input.size) {
162+
const char* output_data = RSTRING_PTR(sc->buf);
163163
ZSTD_outBuffer output = { (void*)output_data, sc->buf_size, 0 };
164164
size_t const ret = zstd_stream_compress(sc->ctx, &output, &input, ZSTD_e_continue, false);
165165
if (ZSTD_isError(ret)) {
166166
rb_raise(rb_eRuntimeError, "compress error error code: %s", ZSTD_getErrorName(ret));
167167
}
168-
/* collect produced bytes */
168+
/* Directly append to the pending buffer */
169169
if (output.pos > 0) {
170170
rb_str_cat(sc->pending, output.dst, output.pos);
171171
}
172-
total += RSTRING_LEN(str);
173172
}
173+
total += RSTRING_LEN(str);
174174
}
175+
175176
return SIZET2NUM(total);
176177
}
177178

@@ -202,9 +203,9 @@ rb_streaming_compress_flush(VALUE obj)
202203
struct streaming_compress_t* sc;
203204
TypedData_Get_Struct(obj, struct streaming_compress_t, &streaming_compress_type, sc);
204205
VALUE drained = no_compress(sc, ZSTD_e_flush);
205-
rb_str_cat(sc->pending, RSTRING_PTR(drained), RSTRING_LEN(drained));
206-
VALUE out = sc->pending;
207-
sc->pending = rb_str_new(0, 0);
206+
VALUE out = rb_str_dup(sc->pending);
207+
rb_str_cat(out, RSTRING_PTR(drained), RSTRING_LEN(drained));
208+
rb_str_resize(sc->pending, 0);
208209
return out;
209210
}
210211

@@ -214,9 +215,9 @@ rb_streaming_compress_finish(VALUE obj)
214215
struct streaming_compress_t* sc;
215216
TypedData_Get_Struct(obj, struct streaming_compress_t, &streaming_compress_type, sc);
216217
VALUE drained = no_compress(sc, ZSTD_e_end);
217-
rb_str_cat(sc->pending, RSTRING_PTR(drained), RSTRING_LEN(drained));
218-
VALUE out = sc->pending;
219-
sc->pending = rb_str_new(0, 0);
218+
VALUE out = rb_str_dup(sc->pending);
219+
rb_str_cat(out, RSTRING_PTR(drained), RSTRING_LEN(drained));
220+
rb_str_resize(sc->pending, 0);
220221
return out;
221222
}
222223

ext/zstdruby/streaming_decompress.c

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,15 +100,22 @@ rb_streaming_decompress_decompress(VALUE obj, VALUE src)
100100

101101
struct streaming_decompress_t* sd;
102102
TypedData_Get_Struct(obj, struct streaming_decompress_t, &streaming_decompress_type, sd);
103-
const char* output_data = RSTRING_PTR(sd->buf);
104103
VALUE result = rb_str_new(0, 0);
104+
105105
while (input.pos < input.size) {
106+
const char* output_data = RSTRING_PTR(sd->buf);
106107
ZSTD_outBuffer output = { (void*)output_data, sd->buf_size, 0 };
107108
size_t const ret = zstd_stream_decompress(sd->dctx, &output, &input, false);
109+
108110
if (ZSTD_isError(ret)) {
109111
rb_raise(rb_eRuntimeError, "decompress error error code: %s", ZSTD_getErrorName(ret));
110112
}
111-
rb_str_cat(result, output.dst, output.pos);
113+
if (output.pos > 0) {
114+
rb_str_cat(result, output.dst, output.pos);
115+
}
116+
if (ret == 0 && output.pos == 0) {
117+
break;
118+
}
112119
}
113120
return result;
114121
}

ext/zstdruby/zstdruby.c

Lines changed: 70 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -39,61 +39,91 @@ static VALUE rb_compress(int argc, VALUE *argv, VALUE self)
3939
return output;
4040
}
4141

42-
static VALUE decompress_buffered(ZSTD_DCtx* dctx, const char* input_data, size_t input_size)
43-
{
44-
ZSTD_inBuffer input = { input_data, input_size, 0 };
45-
VALUE result = rb_str_new(0, 0);
42+
static VALUE decode_one_frame(ZSTD_DCtx* dctx, const unsigned char* src, size_t size, VALUE kwargs) {
43+
VALUE out = rb_str_buf_new(0);
44+
size_t cap = ZSTD_DStreamOutSize();
45+
char *buf = ALLOC_N(char, cap);
46+
ZSTD_inBuffer in = (ZSTD_inBuffer){ src, size, 0 };
4647

47-
while (input.pos < input.size) {
48-
ZSTD_outBuffer output = { NULL, 0, 0 };
49-
output.size += ZSTD_DStreamOutSize();
50-
VALUE output_string = rb_str_new(NULL, output.size);
51-
output.dst = RSTRING_PTR(output_string);
48+
ZSTD_DCtx_reset(dctx, ZSTD_reset_session_only);
49+
set_decompress_params(dctx, kwargs);
5250

53-
size_t ret = zstd_stream_decompress(dctx, &output, &input, false);
51+
for (;;) {
52+
ZSTD_outBuffer o = (ZSTD_outBuffer){ buf, cap, 0 };
53+
size_t ret = ZSTD_decompressStream(dctx, &o, &in);
5454
if (ZSTD_isError(ret)) {
55-
ZSTD_freeDCtx(dctx);
56-
rb_raise(rb_eRuntimeError, "%s: %s", "ZSTD_decompressStream failed", ZSTD_getErrorName(ret));
55+
xfree(buf);
56+
rb_raise(rb_eRuntimeError, "ZSTD_decompressStream failed: %s", ZSTD_getErrorName(ret));
57+
}
58+
if (o.pos) {
59+
rb_str_cat(out, buf, o.pos);
60+
}
61+
if (ret == 0) {
62+
break;
5763
}
58-
rb_str_cat(result, output.dst, output.pos);
5964
}
60-
ZSTD_freeDCtx(dctx);
61-
return result;
65+
xfree(buf);
66+
return out;
67+
}
68+
69+
static VALUE decompress_buffered(ZSTD_DCtx* dctx, const char* data, size_t len) {
70+
return decode_one_frame(dctx, (const unsigned char*)data, len, Qnil);
6271
}
6372

6473
static VALUE rb_decompress(int argc, VALUE *argv, VALUE self)
6574
{
66-
VALUE input_value;
67-
VALUE kwargs;
75+
VALUE input_value, kwargs;
6876
rb_scan_args(argc, argv, "10:", &input_value, &kwargs);
6977
StringValue(input_value);
70-
char* input_data = RSTRING_PTR(input_value);
71-
size_t input_size = RSTRING_LEN(input_value);
72-
ZSTD_DCtx* const dctx = ZSTD_createDCtx();
73-
if (dctx == NULL) {
74-
rb_raise(rb_eRuntimeError, "%s", "ZSTD_createDCtx failed");
75-
}
76-
set_decompress_params(dctx, kwargs);
7778

78-
unsigned long long const uncompressed_size = ZSTD_getFrameContentSize(input_data, input_size);
79-
if (uncompressed_size == ZSTD_CONTENTSIZE_ERROR) {
80-
rb_raise(rb_eRuntimeError, "%s: %s", "not compressed by zstd", ZSTD_getErrorName(uncompressed_size));
81-
}
82-
// ZSTD_decompressStream may be called multiple times when ZSTD_CONTENTSIZE_UNKNOWN, causing slowness.
83-
// Therefore, we will not standardize on ZSTD_decompressStream
84-
if (uncompressed_size == ZSTD_CONTENTSIZE_UNKNOWN) {
85-
return decompress_buffered(dctx, input_data, input_size);
86-
}
79+
size_t in_size = RSTRING_LEN(input_value);
80+
const unsigned char *in_r = (const unsigned char *)RSTRING_PTR(input_value);
81+
unsigned char *in = ALLOC_N(unsigned char, in_size);
82+
memcpy(in, in_r, in_size);
83+
84+
size_t off = 0;
85+
const uint32_t ZSTD_MAGIC = 0xFD2FB528U;
86+
const uint32_t SKIP_LO = 0x184D2A50U; /* ...5F */
87+
88+
while (off + 4 <= in_size) {
89+
uint32_t magic = (uint32_t)in[off]
90+
| ((uint32_t)in[off+1] << 8)
91+
| ((uint32_t)in[off+2] << 16)
92+
| ((uint32_t)in[off+3] << 24);
93+
94+
if ((magic & 0xFFFFFFF0U) == (SKIP_LO & 0xFFFFFFF0U)) {
95+
if (off + 8 > in_size) break;
96+
uint32_t skipLen = (uint32_t)in[off+4]
97+
| ((uint32_t)in[off+5] << 8)
98+
| ((uint32_t)in[off+6] << 16)
99+
| ((uint32_t)in[off+7] << 24);
100+
size_t adv = (size_t)8 + (size_t)skipLen;
101+
if (off + adv > in_size) break;
102+
off += adv;
103+
continue;
104+
}
87105

88-
VALUE output = rb_str_new(NULL, uncompressed_size);
89-
char* output_data = RSTRING_PTR(output);
106+
if (magic == ZSTD_MAGIC) {
107+
ZSTD_DCtx *dctx = ZSTD_createDCtx();
108+
if (!dctx) {
109+
xfree(in);
110+
rb_raise(rb_eRuntimeError, "ZSTD_createDCtx failed");
111+
}
112+
113+
VALUE out = decode_one_frame(dctx, in + off, in_size - off, kwargs);
90114

91-
size_t const decompress_size = zstd_decompress(dctx, output_data, uncompressed_size, input_data, input_size, false);
92-
if (ZSTD_isError(decompress_size)) {
93-
rb_raise(rb_eRuntimeError, "%s: %s", "decompress error", ZSTD_getErrorName(decompress_size));
115+
ZSTD_freeDCtx(dctx);
116+
xfree(in);
117+
RB_GC_GUARD(input_value);
118+
return out;
119+
}
120+
121+
off += 1;
94122
}
95-
ZSTD_freeDCtx(dctx);
96-
return output;
123+
124+
xfree(in);
125+
RB_GC_GUARD(input_value);
126+
rb_raise(rb_eRuntimeError, "not a zstd frame (magic not found)");
97127
}
98128

99129
static void free_cdict(void *dict)

0 commit comments

Comments
 (0)