Skip to content

Commit b5beb58

Browse files
committed
feat: add simple dictionary compression and decompression
1 parent 0f49eb2 commit b5beb58

File tree

7 files changed

+180
-34
lines changed

7 files changed

+180
-34
lines changed

ext/zstdruby/common.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,12 @@
44
#include "ruby.h"
55
#include "./libzstd/zstd.h"
66

7+
/*
 * Translate the optional Ruby-level compression level into a C int.
 *
 * compression_level_value: nil, or a Ruby value that NUM2INT can convert.
 * Returns ZSTD_CLEVEL_DEFAULT for nil, otherwise the result of NUM2INT.
 */
static int convert_compression_level(VALUE compression_level_value)
{
  return NIL_P(compression_level_value)
    ? ZSTD_CLEVEL_DEFAULT
    : NUM2INT(compression_level_value);
}
714

815
#endif /* ZSTD_RUBY_H */

ext/zstdruby/streaming_compress.c

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,7 @@ rb_streaming_compress_initialize(int argc, VALUE *argv, VALUE obj)
5353
{
5454
VALUE compression_level_value;
5555
rb_scan_args(argc, argv, "01", &compression_level_value);
56-
57-
int compression_level;
58-
if (NIL_P(compression_level_value)) {
59-
compression_level = ZSTD_CLEVEL_DEFAULT;
60-
} else {
61-
compression_level = NUM2INT(compression_level_value);
62-
}
56+
int compression_level = convert_compression_level(compression_level_value);
6357

6458
struct streaming_compress_t* sc;
6559
TypedData_Get_Struct(obj, struct streaming_compress_t, &streaming_compress_type, sc);

ext/zstdruby/zstdruby.c

Lines changed: 110 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8,41 +8,74 @@ static VALUE zstdVersion(VALUE self)
88
return INT2NUM(version);
99
}
1010

11-
static VALUE compress(int argc, VALUE *argv, VALUE self)
11+
/*
 * Zstd.compress(input, level = nil)
 *
 * Compresses `input` in one shot with ZSTD_compress and returns the result
 * as a new Ruby String. A nil level is mapped by convert_compression_level.
 * Raises RuntimeError when ZSTD_compress reports an error.
 */
static VALUE rb_compress(int argc, VALUE *argv, VALUE self)
{
  VALUE input_value, compression_level_value;
  rb_scan_args(argc, argv, "11", &input_value, &compression_level_value);
  int level = convert_compression_level(compression_level_value);

  StringValue(input_value);
  char* src = RSTRING_PTR(input_value);
  size_t src_len = RSTRING_LEN(input_value);

  // Allocate for the worst case, then shrink to the real size afterwards.
  size_t dst_capacity = ZSTD_compressBound(src_len);
  VALUE output = rb_str_new(NULL, dst_capacity);
  char* dst = RSTRING_PTR(output);

  size_t written = ZSTD_compress(dst, dst_capacity, src, src_len, level);
  if (ZSTD_isError(written)) {
    rb_raise(rb_eRuntimeError, "%s: %s", "compress failed", ZSTD_getErrorName(written));
  }

  rb_str_resize(output, written);
  return output;
}
34+
35+
/*
 * Zstd.compress_using_dict(input, dict, level = nil)
 *
 * Compresses `input` with the dictionary `dict` via ZSTD_compress_usingCDict
 * and returns the result as a new Ruby String.
 *
 * input: String (or an object StringValue can convert) to compress.
 * dict:  String (or an object StringValue can convert) holding the dictionary.
 * level: optional compression level; nil is mapped by convert_compression_level.
 *
 * Raises RuntimeError when the CDict or CCtx cannot be created or when
 * compression reports an error.
 */
static VALUE rb_compress_using_dict(int argc, VALUE *argv, VALUE self)
{
  VALUE input_value;
  VALUE dict;
  VALUE compression_level_value;
  rb_scan_args(argc, argv, "21", &input_value, &dict, &compression_level_value);
  int compression_level = convert_compression_level(compression_level_value);

  // Validate both arguments before touching their buffers; RSTRING_PTR and
  // RSTRING_LEN perform no type check of their own.
  StringValue(input_value);
  StringValue(dict);

  size_t input_size = RSTRING_LEN(input_value);
  size_t max_compressed_size = ZSTD_compressBound(input_size);

  // Allocate the Ruby output string before any native zstd object exists:
  // if the allocation raises, there is nothing to leak.
  VALUE output = rb_str_new(NULL, max_compressed_size);
  char* output_data = RSTRING_PTR(output);

  const char* input_data = RSTRING_PTR(input_value);
  const char* dict_buffer = RSTRING_PTR(dict);
  size_t dict_size = RSTRING_LEN(dict);

  ZSTD_CDict* const cdict = ZSTD_createCDict(dict_buffer, dict_size, compression_level);
  if (cdict == NULL) {
    rb_raise(rb_eRuntimeError, "%s", "ZSTD_createCDict failed");
  }
  ZSTD_CCtx* const ctx = ZSTD_createCCtx();
  if (ctx == NULL) {
    ZSTD_freeCDict(cdict);
    rb_raise(rb_eRuntimeError, "%s", "ZSTD_createCCtx failed");
  }

  size_t const compressed_size = ZSTD_compress_usingCDict(ctx, output_data, max_compressed_size,
                                                          input_data, input_size, cdict);

  // rb_raise does not return, so release the native objects first and keep
  // a single cleanup path for both the success and the error case.
  ZSTD_freeCCtx(ctx);
  ZSTD_freeCDict(cdict);

  if (ZSTD_isError(compressed_size)) {
    rb_raise(rb_eRuntimeError, "%s: %s", "compress failed", ZSTD_getErrorName(compressed_size));
  }

  rb_str_resize(output, compressed_size);
  return output;
}
4577

78+
4679
static VALUE decompress_buffered(const char* input_data, size_t input_size)
4780
{
4881
const size_t outputBufferSize = 4096;
@@ -58,7 +91,6 @@ static VALUE decompress_buffered(const char* input_data, size_t input_size)
5891
rb_raise(rb_eRuntimeError, "%s: %s", "ZSTD_initDStream failed", ZSTD_getErrorName(initResult));
5992
}
6093

61-
6294
VALUE output_string = rb_str_new(NULL, 0);
6395
ZSTD_outBuffer output = { NULL, 0, 0 };
6496

@@ -80,23 +112,24 @@ static VALUE decompress_buffered(const char* input_data, size_t input_size)
80112
return output_string;
81113
}
82114

83-
static VALUE decompress(VALUE self, VALUE input)
115+
static VALUE rb_decompress(VALUE self, VALUE input_value)
84116
{
85-
StringValue(input);
86-
const char* input_data = RSTRING_PTR(input);
87-
size_t input_size = RSTRING_LEN(input);
88-
89-
uint64_t uncompressed_size = ZSTD_getDecompressedSize(input_data, input_size);
117+
StringValue(input_value);
118+
char* input_data = RSTRING_PTR(input_value);
119+
size_t input_size = RSTRING_LEN(input_value);
90120

91-
if (uncompressed_size == 0) {
121+
unsigned long long const uncompressed_size = ZSTD_getFrameContentSize(input_data, input_size);
122+
if (uncompressed_size == ZSTD_CONTENTSIZE_ERROR) {
123+
rb_raise(rb_eRuntimeError, "%s: %s", "not compressed by zstd", ZSTD_getErrorName(uncompressed_size));
124+
}
125+
if (uncompressed_size == ZSTD_CONTENTSIZE_UNKNOWN) {
92126
return decompress_buffered(input_data, input_size);
93127
}
94128

95129
VALUE output = rb_str_new(NULL, uncompressed_size);
96130
char* output_data = RSTRING_PTR(output);
97-
98-
size_t decompress_size = ZSTD_decompress((void*)output_data, uncompressed_size,
99-
(const void*)input_data, input_size);
131+
size_t const decompress_size = ZSTD_decompress((void*)output_data, uncompressed_size,
132+
(void*)input_data, input_size);
100133

101134
if (ZSTD_isError(decompress_size)) {
102135
rb_raise(rb_eRuntimeError, "%s: %s", "decompress error", ZSTD_getErrorName(decompress_size));
@@ -105,10 +138,61 @@ static VALUE decompress(VALUE self, VALUE input)
105138
return output;
106139
}
107140

141+
/*
 * Zstd.decompress_using_dict(input, dict)
 *
 * Decompresses a zstd frame that was compressed with the dictionary `dict`
 * and returns the result as a new Ruby String.
 *
 * input: String (or an object StringValue can convert) holding the frame.
 * dict:  String (or an object StringValue can convert) holding the dictionary.
 *
 * Raises RuntimeError when the frame header cannot be parsed, when the
 * dictionary ID of the frame differs from the one of `dict`, when the
 * DDict or DCtx cannot be created, or when decompression reports an error.
 */
static VALUE rb_decompress_using_dict(int argc, VALUE *argv, VALUE self)
{
  VALUE input_value;
  VALUE dict;
  rb_scan_args(argc, argv, "20", &input_value, &dict);

  // Validate both arguments before touching their buffers; RSTRING_PTR and
  // RSTRING_LEN perform no type check of their own.
  StringValue(input_value);
  StringValue(dict);

  const char* input_data = RSTRING_PTR(input_value);
  size_t input_size = RSTRING_LEN(input_value);
  unsigned long long const uncompressed_size = ZSTD_getFrameContentSize(input_data, input_size);
  if (uncompressed_size == ZSTD_CONTENTSIZE_ERROR) {
    rb_raise(rb_eRuntimeError, "%s: %s", "not compressed by zstd", ZSTD_getErrorName(uncompressed_size));
  }
  if (uncompressed_size == ZSTD_CONTENTSIZE_UNKNOWN) {
    // NOTE(review): decompress_buffered takes no dictionary argument, so this
    // fallback does not use `dict` at all. Confirm whether frames without a
    // recorded content size need dictionary-aware streaming here.
    return decompress_buffered(input_data, input_size);
  }

  // Allocate the Ruby output string before any native zstd object exists:
  // if the allocation raises, there is nothing to leak.
  VALUE output = rb_str_new(NULL, uncompressed_size);
  char* output_data = RSTRING_PTR(output);

  const char* dict_buffer = RSTRING_PTR(dict);
  size_t dict_size = RSTRING_LEN(dict);
  ZSTD_DDict* const ddict = ZSTD_createDDict(dict_buffer, dict_size);
  if (ddict == NULL) {
    rb_raise(rb_eRuntimeError, "%s", "ZSTD_createDDict failed");
  }

  unsigned const expected_dict_id = ZSTD_getDictID_fromDDict(ddict);
  unsigned const actual_dict_id = ZSTD_getDictID_fromFrame(input_data, input_size);
  if (expected_dict_id != actual_dict_id) {
    ZSTD_freeDDict(ddict);
    // Report the two IDs; no zstd error code is available at this point.
    rb_raise(rb_eRuntimeError, "DictID mismatch: expected %u, got %u", expected_dict_id, actual_dict_id);
  }

  ZSTD_DCtx* const ctx = ZSTD_createDCtx();
  if (ctx == NULL) {
    ZSTD_freeDDict(ddict);
    rb_raise(rb_eRuntimeError, "%s", "ZSTD_createDCtx failed");
  }

  size_t const decompress_size = ZSTD_decompress_usingDDict(ctx, output_data, uncompressed_size,
                                                            input_data, input_size, ddict);

  // rb_raise does not return, so release the native objects first and keep
  // a single cleanup path for both the success and the error case.
  ZSTD_freeDCtx(ctx);
  ZSTD_freeDDict(ddict);

  if (ZSTD_isError(decompress_size)) {
    rb_raise(rb_eRuntimeError, "%s: %s", "decompress error", ZSTD_getErrorName(decompress_size));
  }
  return output;
}
189+
108190
void
109191
zstd_ruby_init(void)
110192
{
111193
rb_define_module_function(rb_mZstd, "zstd_version", zstdVersion, 0);
112-
rb_define_module_function(rb_mZstd, "compress", compress, -1);
113-
rb_define_module_function(rb_mZstd, "decompress", decompress, 1);
194+
rb_define_module_function(rb_mZstd, "compress", rb_compress, -1);
195+
rb_define_module_function(rb_mZstd, "compress_using_dict", rb_compress_using_dict, -1);
196+
rb_define_module_function(rb_mZstd, "decompress", rb_decompress, 1);
197+
rb_define_module_function(rb_mZstd, "decompress_using_dict", rb_decompress_using_dict, -1);
114198
}

lib/zstd-ruby.rb

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,4 @@
22
require "zstd-ruby/zstdruby"
33

44
module Zstd
5-
# Your code goes here...
65
end

spec/dictionary

110 KB
Binary file not shown.

spec/user_springmt.json

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
{
2+
"login": "SpringMT",
3+
"id": 579005,
4+
"avatar_url": "https://avatars.githubusercontent.com/u/579005?v=4",
5+
"gravatar_id": "",
6+
"url": "https://api.github.com/users/SpringMT",
7+
"html_url": "https://github.com/SpringMT",
8+
"followers_url": "https://api.github.com/users/SpringMT/followers",
9+
"following_url": "https://api.github.com/users/SpringMT/following{/other_user}",
10+
"gists_url": "https://api.github.com/users/SpringMT/gists{/gist_id}",
11+
"starred_url": "https://api.github.com/users/SpringMT/starred{/owner}{/repo}",
12+
"subscriptions_url": "https://api.github.com/users/SpringMT/subscriptions",
13+
"organizations_url": "https://api.github.com/users/SpringMT/orgs",
14+
"repos_url": "https://api.github.com/users/SpringMT/repos",
15+
"events_url": "https://api.github.com/users/SpringMT/events{/privacy}",
16+
"received_events_url": "https://api.github.com/users/SpringMT/received_events",
17+
"type": "User",
18+
"site_admin": false
19+
}

spec/zstd-ruby-using-dict_spec.rb

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
require "spec_helper"
2+
require 'zstd-ruby'
3+
require 'securerandom'
4+
5+
# How to generate a dictionary:
6+
# https://github.com/facebook/zstd#the-case-for-small-data-compression
7+
# https://github.com/facebook/zstd/releases/tag/v1.1.3
8+
9+
RSpec.describe Zstd do
  describe 'compress_using_dict' do
    # Fixtures are read from files that sit next to this spec.
    let(:user_json) { IO.read("#{__dir__}/user_springmt.json") }
    let(:dictionary) { IO.read("#{__dir__}/dictionary") }

    it 'should work' do
      with_dict = Zstd.compress_using_dict(user_json, dictionary)
      without_dict = Zstd.compress(user_json)

      # The dictionary is expected to beat plain compression on this payload.
      expect(with_dict.length).to be < without_dict.length
      expect(user_json).to eq(Zstd.decompress_using_dict(with_dict, dictionary))
    end

    it 'should work with simple string' do
      packed = Zstd.compress_using_dict("abc", dictionary)
      expect("abc").to eq(Zstd.decompress_using_dict(packed, dictionary))
    end

    it 'should work with blank' do
      packed = Zstd.compress_using_dict("", dictionary)
      expect("").to eq(Zstd.decompress_using_dict(packed, dictionary))
    end

    it 'should support compression levels' do
      default_level = Zstd.compress_using_dict(user_json, dictionary)
      level_10 = Zstd.compress_using_dict(user_json, dictionary, 10)
      expect(level_10.length).to be < default_level.length
    end
  end

end

0 commit comments

Comments
 (0)