Skip to content

Commit 3be703a

Browse files
committed
feat: decompress reactoring
1 parent e399267 commit 3be703a

File tree

3 files changed

+54
-38
lines changed

3 files changed

+54
-38
lines changed

ext/zstdruby/common.h

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ static size_t zstd_compress(ZSTD_CCtx* const ctx, ZSTD_outBuffer* output, ZSTD_i
3434
{
3535
#ifdef HAVE_RUBY_THREAD_H
3636
struct compress_params params = { ctx, output, input, endOp };
37-
rb_thread_call_without_gvl(compress_wrapper, &params, RUBY_UBF_IO, NULL);
37+
rb_thread_call_without_gvl(compress_wrapper, &params, NULL, NULL);
3838
return params.ret;
3939
#else
4040
return ZSTD_compressStream2(ctx, output, input, endOp);
@@ -69,6 +69,31 @@ static void set_compress_params(ZSTD_CCtx* const ctx, VALUE level_from_args, VAL
6969
}
7070
}
7171

72+
struct decompress_params {
73+
ZSTD_DCtx* dctx;
74+
ZSTD_outBuffer* output;
75+
ZSTD_inBuffer* input;
76+
size_t ret;
77+
};
78+
79+
static void* decompress_wrapper(void* args)
80+
{
81+
struct decompress_params* params = args;
82+
params->ret = ZSTD_decompressStream(params->dctx, params->output, params->input);
83+
return NULL;
84+
}
85+
86+
static size_t zstd_decompress(ZSTD_DCtx* const dctx, ZSTD_outBuffer* output, ZSTD_inBuffer* input)
87+
{
88+
#ifdef HAVE_RUBY_THREAD_H
89+
struct decompress_params params = { dctx, output, input };
90+
rb_thread_call_without_gvl(decompress_wrapper, &params, NULL, NULL);
91+
return params.ret;
92+
#else
93+
return ZSTD_decompressStream(dctx, output, input);
94+
#endif
95+
}
96+
7297
static void set_decompress_params(ZSTD_DCtx* const dctx, VALUE kwargs)
7398
{
7499
ID kwargs_keys[1];

ext/zstdruby/streaming_decompress.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ rb_streaming_decompress_decompress(VALUE obj, VALUE src)
104104
VALUE result = rb_str_new(0, 0);
105105
while (input.pos < input.size) {
106106
ZSTD_outBuffer output = { (void*)output_data, sd->buf_size, 0 };
107-
size_t const ret = ZSTD_decompressStream(sd->dctx, &output, &input);
107+
size_t const ret = zstd_decompress(sd->dctx, &output, &input);
108108
if (ZSTD_isError(ret)) {
109109
rb_raise(rb_eRuntimeError, "decompress error error code: %s", ZSTD_getErrorName(ret));
110110
}

ext/zstdruby/zstdruby.c

Lines changed: 27 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ static VALUE rb_compress(int argc, VALUE *argv, VALUE self)
2626
char* input_data = RSTRING_PTR(input_value);
2727
size_t input_size = RSTRING_LEN(input_value);
2828
ZSTD_inBuffer input = { input_data, input_size, 0 };
29+
// ZSTD_compressBound causes SEGV under multi-thread
2930
size_t max_compressed_size = ZSTD_compressBound(input_size);
3031
VALUE buf = rb_str_new(NULL, max_compressed_size);
3132
char* output_data = RSTRING_PTR(buf);
@@ -87,19 +88,8 @@ static VALUE rb_compress_using_dict(int argc, VALUE *argv, VALUE self)
8788
}
8889

8990

90-
static VALUE decompress_buffered(const char* input_data, size_t input_size)
91+
static VALUE decompress_buffered(ZSTD_DCtx* dctx, const char* input_data, size_t input_size)
9192
{
92-
ZSTD_DStream* const dstream = ZSTD_createDStream();
93-
if (dstream == NULL) {
94-
rb_raise(rb_eRuntimeError, "%s", "ZSTD_createDStream failed");
95-
}
96-
97-
size_t initResult = ZSTD_initDStream(dstream);
98-
if (ZSTD_isError(initResult)) {
99-
ZSTD_freeDStream(dstream);
100-
rb_raise(rb_eRuntimeError, "%s: %s", "ZSTD_initDStream failed", ZSTD_getErrorName(initResult));
101-
}
102-
10393
VALUE output_string = rb_str_new(NULL, 0);
10494
ZSTD_outBuffer output = { NULL, 0, 0 };
10595

@@ -109,15 +99,14 @@ static VALUE decompress_buffered(const char* input_data, size_t input_size)
10999
rb_str_resize(output_string, output.size);
110100
output.dst = RSTRING_PTR(output_string);
111101

112-
size_t readHint = ZSTD_decompressStream(dstream, &output, &input);
113-
if (ZSTD_isError(readHint)) {
114-
ZSTD_freeDStream(dstream);
115-
rb_raise(rb_eRuntimeError, "%s: %s", "ZSTD_decompressStream failed", ZSTD_getErrorName(readHint));
102+
size_t ret = ZSTD_decompressStream(dctx, &output, &input);
103+
if (ZSTD_isError(ret)) {
104+
ZSTD_freeDCtx(dctx);
105+
rb_raise(rb_eRuntimeError, "%s: %s", "ZSTD_decompressStream failed", ZSTD_getErrorName(ret));
116106
}
117107
}
118-
119-
ZSTD_freeDStream(dstream);
120108
rb_str_resize(output_string, output.pos);
109+
ZSTD_freeDCtx(dctx);
121110
return output_string;
122111
}
123112

@@ -129,6 +118,11 @@ static VALUE rb_decompress(int argc, VALUE *argv, VALUE self)
129118
StringValue(input_value);
130119
char* input_data = RSTRING_PTR(input_value);
131120
size_t input_size = RSTRING_LEN(input_value);
121+
ZSTD_DCtx* const dctx = ZSTD_createDCtx();
122+
if (dctx == NULL) {
123+
rb_raise(rb_eRuntimeError, "%s", "ZSTD_createDCtx failed");
124+
}
125+
set_decompress_params(dctx, kwargs);
132126

133127
unsigned long long const uncompressed_size = ZSTD_getFrameContentSize(input_data, input_size);
134128
if (uncompressed_size == ZSTD_CONTENTSIZE_ERROR) {
@@ -137,15 +131,9 @@ static VALUE rb_decompress(int argc, VALUE *argv, VALUE self)
137131
// ZSTD_decompressStream may be called multiple times when ZSTD_CONTENTSIZE_UNKNOWN, causing slowness.
138132
// Therefore, we will not standardize on ZSTD_decompressStream
139133
if (uncompressed_size == ZSTD_CONTENTSIZE_UNKNOWN) {
140-
return decompress_buffered(input_data, input_size);
134+
return decompress_buffered(dctx, input_data, input_size);
141135
}
142136

143-
ZSTD_DCtx* const dctx = ZSTD_createDCtx();
144-
if (dctx == NULL) {
145-
rb_raise(rb_eRuntimeError, "%s", "ZSTD_createDCtx failed");
146-
}
147-
set_decompress_params(dctx, kwargs);
148-
149137
VALUE output = rb_str_new(NULL, uncompressed_size);
150138
char* output_data = RSTRING_PTR(output);
151139

@@ -167,35 +155,38 @@ static VALUE rb_decompress_using_dict(int argc, VALUE *argv, VALUE self)
167155
StringValue(input_value);
168156
char* input_data = RSTRING_PTR(input_value);
169157
size_t input_size = RSTRING_LEN(input_value);
170-
unsigned long long const uncompressed_size = ZSTD_getFrameContentSize(input_data, input_size);
171-
if (uncompressed_size == ZSTD_CONTENTSIZE_ERROR) {
172-
rb_raise(rb_eRuntimeError, "%s: %s", "not compressed by zstd", ZSTD_getErrorName(uncompressed_size));
173-
}
174-
if (uncompressed_size == ZSTD_CONTENTSIZE_UNKNOWN) {
175-
return decompress_buffered(input_data, input_size);
176-
}
177-
VALUE output = rb_str_new(NULL, uncompressed_size);
178-
char* output_data = RSTRING_PTR(output);
179158

180159
char* dict_buffer = RSTRING_PTR(dict);
181160
size_t dict_size = RSTRING_LEN(dict);
182161
ZSTD_DDict* const ddict = ZSTD_createDDict(dict_buffer, dict_size);
183162
if (ddict == NULL) {
184163
rb_raise(rb_eRuntimeError, "%s", "ZSTD_createDDict failed");
185164
}
186-
187165
unsigned const expected_dict_id = ZSTD_getDictID_fromDDict(ddict);
188166
unsigned const actual_dict_id = ZSTD_getDictID_fromFrame(input_data, input_size);
189167
if (expected_dict_id != actual_dict_id) {
190168
ZSTD_freeDDict(ddict);
191-
rb_raise(rb_eRuntimeError, "%s: %s", "DictID mismatch", ZSTD_getErrorName(uncompressed_size));
169+
rb_raise(rb_eRuntimeError, "DictID mismatch");
192170
}
193171

194172
ZSTD_DCtx* const ctx = ZSTD_createDCtx();
195173
if (ctx == NULL) {
196174
ZSTD_freeDDict(ddict);
197175
rb_raise(rb_eRuntimeError, "%s", "ZSTD_createDCtx failed");
198176
}
177+
178+
unsigned long long const uncompressed_size = ZSTD_getFrameContentSize(input_data, input_size);
179+
if (uncompressed_size == ZSTD_CONTENTSIZE_ERROR) {
180+
ZSTD_freeDDict(ddict);
181+
ZSTD_freeDCtx(ctx);
182+
rb_raise(rb_eRuntimeError, "%s: %s", "not compressed by zstd", ZSTD_getErrorName(uncompressed_size));
183+
}
184+
if (uncompressed_size == ZSTD_CONTENTSIZE_UNKNOWN) {
185+
return decompress_buffered(ctx, input_data, input_size);
186+
}
187+
188+
VALUE output = rb_str_new(NULL, uncompressed_size);
189+
char* output_data = RSTRING_PTR(output);
199190
size_t const decompress_size = ZSTD_decompress_usingDDict(ctx, output_data, uncompressed_size, input_data, input_size, ddict);
200191
if (ZSTD_isError(decompress_size)) {
201192
ZSTD_freeDDict(ddict);

0 commit comments

Comments
 (0)