Skip to content

Commit 63a4066

Browse files
committed
fix: compression crash
Zstd.compress does not consume buffer for large bytes
1 parent aeb79df commit 63a4066

File tree

3 files changed

+18
-21
lines changed

3 files changed

+18
-21
lines changed

ext/zstdruby/common.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,29 +44,29 @@ static void set_compress_params(ZSTD_CCtx* const ctx, VALUE level_from_args, VAL
4444
}
4545
}
4646

47-
struct compress_params {
47+
struct stream_compress_params {
4848
ZSTD_CCtx* ctx;
4949
ZSTD_outBuffer* output;
5050
ZSTD_inBuffer* input;
5151
ZSTD_EndDirective endOp;
5252
size_t ret;
5353
};
5454

55-
static void* compress_wrapper(void* args)
55+
static void* stream_compress_wrapper(void* args)
5656
{
57-
struct compress_params* params = args;
57+
struct stream_compress_params* params = args;
5858
params->ret = ZSTD_compressStream2(params->ctx, params->output, params->input, params->endOp);
5959
return NULL;
6060
}
6161

62-
static size_t zstd_compress(ZSTD_CCtx* const ctx, ZSTD_outBuffer* output, ZSTD_inBuffer* input, ZSTD_EndDirective endOp, bool gvl)
62+
static size_t zstd_stream_compress(ZSTD_CCtx* const ctx, ZSTD_outBuffer* output, ZSTD_inBuffer* input, ZSTD_EndDirective endOp, bool gvl)
6363
{
6464
#ifdef HAVE_RUBY_THREAD_H
6565
if (gvl) {
6666
return ZSTD_compressStream2(ctx, output, input, endOp);
6767
} else {
68-
struct compress_params params = { ctx, output, input, endOp };
69-
rb_thread_call_without_gvl(compress_wrapper, &params, NULL, NULL);
68+
struct stream_compress_params params = { ctx, output, input, endOp };
69+
rb_thread_call_without_gvl(stream_compress_wrapper, &params, NULL, NULL);
7070
return params.ret;
7171
}
7272
#else

ext/zstdruby/streaming_compress.c

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ no_compress(struct streaming_compress_t* sc, ZSTD_EndDirective endOp)
106106
do {
107107
ZSTD_outBuffer output = { (void*)output_data, sc->buf_size, 0 };
108108

109-
size_t const ret = zstd_compress(sc->ctx, &output, &input, endOp, false);
109+
size_t const ret = zstd_stream_compress(sc->ctx, &output, &input, endOp, false);
110110
if (ZSTD_isError(ret)) {
111111
rb_raise(rb_eRuntimeError, "flush error error code: %s", ZSTD_getErrorName(ret));
112112
}
@@ -130,7 +130,7 @@ rb_streaming_compress_compress(VALUE obj, VALUE src)
130130
VALUE result = rb_str_new(0, 0);
131131
while (input.pos < input.size) {
132132
ZSTD_outBuffer output = { (void*)output_data, sc->buf_size, 0 };
133-
size_t const ret = zstd_compress(sc->ctx, &output, &input, ZSTD_e_continue, false);
133+
size_t const ret = zstd_stream_compress(sc->ctx, &output, &input, ZSTD_e_continue, false);
134134
if (ZSTD_isError(ret)) {
135135
rb_raise(rb_eRuntimeError, "compress error error code: %s", ZSTD_getErrorName(ret));
136136
}
@@ -157,7 +157,7 @@ rb_streaming_compress_write(int argc, VALUE *argv, VALUE obj)
157157

158158
while (input.pos < input.size) {
159159
ZSTD_outBuffer output = { (void*)output_data, sc->buf_size, 0 };
160-
size_t const ret = zstd_compress(sc->ctx, &output, &input, ZSTD_e_continue, false);
160+
size_t const ret = zstd_stream_compress(sc->ctx, &output, &input, ZSTD_e_continue, false);
161161
if (ZSTD_isError(ret)) {
162162
rb_raise(rb_eRuntimeError, "compress error error code: %s", ZSTD_getErrorName(ret));
163163
}

ext/zstdruby/zstdruby.c

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,22 +25,19 @@ static VALUE rb_compress(int argc, VALUE *argv, VALUE self)
2525
StringValue(input_value);
2626
char* input_data = RSTRING_PTR(input_value);
2727
size_t input_size = RSTRING_LEN(input_value);
28-
ZSTD_inBuffer input = { input_data, input_size, 0 };
29-
// ZSTD_compressBound causes SEGV under multi-thread
30-
size_t max_compressed_size = ZSTD_compressBound(input_size);
31-
VALUE buf = rb_str_new(NULL, max_compressed_size);
32-
char* output_data = RSTRING_PTR(buf);
33-
ZSTD_outBuffer output = { (void*)output_data, max_compressed_size, 0 };
3428

35-
size_t const ret = zstd_compress(ctx, &output, &input, ZSTD_e_end, true);
29+
size_t const max_compressed_size = ZSTD_compressBound(input_size);
30+
VALUE output = rb_str_new(NULL, max_compressed_size);
31+
const char* output_data = RSTRING_PTR(output);
32+
33+
size_t const ret = ZSTD_compress2(ctx,(void*)output_data, max_compressed_size, (void*)input_data, input_size);
3634
if (ZSTD_isError(ret)) {
37-
ZSTD_freeCCtx(ctx);
38-
rb_raise(rb_eRuntimeError, "%s: %s", "compress failed", ZSTD_getErrorName(ret));
35+
rb_raise(rb_eRuntimeError, "compress error error code: %s", ZSTD_getErrorName(ret));
3936
}
40-
VALUE result = rb_str_new(0, 0);
41-
rb_str_cat(result, output.dst, output.pos);
37+
rb_str_resize(output, ret);
38+
4239
ZSTD_freeCCtx(ctx);
43-
return result;
40+
return output;
4441
}
4542

4643
static VALUE rb_compress_using_dict(int argc, VALUE *argv, VALUE self)

0 commit comments

Comments
 (0)