@@ -8,41 +8,74 @@ static VALUE zstdVersion(VALUE self)
8
8
return INT2NUM (version );
9
9
}
10
10
11
- static VALUE compress (int argc , VALUE * argv , VALUE self )
11
+ static VALUE rb_compress (int argc , VALUE * argv , VALUE self )
12
12
{
13
13
VALUE input_value ;
14
14
VALUE compression_level_value ;
15
15
rb_scan_args (argc , argv , "11" , & input_value , & compression_level_value );
16
+ int compression_level = convert_compression_level (compression_level_value );
16
17
17
18
StringValue (input_value );
18
- const char * input_data = RSTRING_PTR (input_value );
19
+ char * input_data = RSTRING_PTR (input_value );
19
20
size_t input_size = RSTRING_LEN (input_value );
21
+ size_t max_compressed_size = ZSTD_compressBound (input_size );
20
22
21
- int compression_level ;
22
- if (NIL_P (compression_level_value )) {
23
- compression_level = 0 ; // The default. See ZSTD_CLEVEL_DEFAULT in zstd_compress.c
24
- } else {
25
- compression_level = NUM2INT (compression_level_value );
23
+ VALUE output = rb_str_new (NULL , max_compressed_size );
24
+ char * output_data = RSTRING_PTR (output );
25
+ size_t compressed_size = ZSTD_compress ((void * )output_data , max_compressed_size ,
26
+ (void * )input_data , input_size , compression_level );
27
+ if (ZSTD_isError (compressed_size )) {
28
+ rb_raise (rb_eRuntimeError , "%s: %s" , "compress failed" , ZSTD_getErrorName (compressed_size ));
26
29
}
27
30
28
- // do compress
31
+ rb_str_resize (output , compressed_size );
32
+ return output ;
33
+ }
34
+
35
+ static VALUE rb_compress_using_dict (int argc , VALUE * argv , VALUE self )
36
+ {
37
+ VALUE input_value ;
38
+ VALUE dict ;
39
+ VALUE compression_level_value ;
40
+ rb_scan_args (argc , argv , "21" , & input_value , & dict , & compression_level_value );
41
+ int compression_level = convert_compression_level (compression_level_value );
42
+
43
+ StringValue (input_value );
44
+ char * input_data = RSTRING_PTR (input_value );
45
+ size_t input_size = RSTRING_LEN (input_value );
29
46
size_t max_compressed_size = ZSTD_compressBound (input_size );
30
47
48
+ char * dict_buffer = RSTRING_PTR (dict );
49
+ size_t dict_size = RSTRING_LEN (dict );
50
+
51
+ ZSTD_CDict * const cdict = ZSTD_createCDict (dict_buffer , dict_size , compression_level );
52
+ if (cdict == NULL ) {
53
+ rb_raise (rb_eRuntimeError , "%s" , "ZSTD_createCDict failed" );
54
+ }
55
+ ZSTD_CCtx * const ctx = ZSTD_createCCtx ();
56
+ if (ctx == NULL ) {
57
+ ZSTD_freeCDict (cdict );
58
+ rb_raise (rb_eRuntimeError , "%s" , "ZSTD_createCCtx failed" );
59
+ }
60
+
31
61
VALUE output = rb_str_new (NULL , max_compressed_size );
32
62
char * output_data = RSTRING_PTR (output );
33
-
34
- size_t compressed_size = ZSTD_compress ((void * )output_data , max_compressed_size ,
35
- (const void * )input_data , input_size , compression_level );
63
+ size_t const compressed_size = ZSTD_compress_usingCDict (ctx , (void * )output_data , max_compressed_size ,
64
+ (void * )input_data , input_size , cdict );
36
65
37
66
if (ZSTD_isError (compressed_size )) {
67
+ ZSTD_freeCDict (cdict );
68
+ ZSTD_freeCCtx (ctx );
38
69
rb_raise (rb_eRuntimeError , "%s: %s" , "compress failed" , ZSTD_getErrorName (compressed_size ));
39
- } else {
40
- rb_str_resize (output , compressed_size );
41
70
}
42
71
72
+ rb_str_resize (output , compressed_size );
73
+ ZSTD_freeCDict (cdict );
74
+ ZSTD_freeCCtx (ctx );
43
75
return output ;
44
76
}
45
77
78
+
46
79
static VALUE decompress_buffered (const char * input_data , size_t input_size )
47
80
{
48
81
const size_t outputBufferSize = 4096 ;
@@ -58,7 +91,6 @@ static VALUE decompress_buffered(const char* input_data, size_t input_size)
58
91
rb_raise (rb_eRuntimeError , "%s: %s" , "ZSTD_initDStream failed" , ZSTD_getErrorName (initResult ));
59
92
}
60
93
61
-
62
94
VALUE output_string = rb_str_new (NULL , 0 );
63
95
ZSTD_outBuffer output = { NULL , 0 , 0 };
64
96
@@ -80,23 +112,24 @@ static VALUE decompress_buffered(const char* input_data, size_t input_size)
80
112
return output_string ;
81
113
}
82
114
83
- static VALUE decompress (VALUE self , VALUE input )
115
+ static VALUE rb_decompress (VALUE self , VALUE input_value )
84
116
{
85
- StringValue (input );
86
- const char * input_data = RSTRING_PTR (input );
87
- size_t input_size = RSTRING_LEN (input );
88
-
89
- uint64_t uncompressed_size = ZSTD_getDecompressedSize (input_data , input_size );
117
+ StringValue (input_value );
118
+ char * input_data = RSTRING_PTR (input_value );
119
+ size_t input_size = RSTRING_LEN (input_value );
90
120
91
- if (uncompressed_size == 0 ) {
121
+ unsigned long long const uncompressed_size = ZSTD_getFrameContentSize (input_data , input_size );
122
+ if (uncompressed_size == ZSTD_CONTENTSIZE_ERROR ) {
123
+ rb_raise (rb_eRuntimeError , "%s: %s" , "not compressed by zstd" , ZSTD_getErrorName (uncompressed_size ));
124
+ }
125
+ if (uncompressed_size == ZSTD_CONTENTSIZE_UNKNOWN ) {
92
126
return decompress_buffered (input_data , input_size );
93
127
}
94
128
95
129
VALUE output = rb_str_new (NULL , uncompressed_size );
96
130
char * output_data = RSTRING_PTR (output );
97
-
98
- size_t decompress_size = ZSTD_decompress ((void * )output_data , uncompressed_size ,
99
- (const void * )input_data , input_size );
131
+ size_t const decompress_size = ZSTD_decompress ((void * )output_data , uncompressed_size ,
132
+ (void * )input_data , input_size );
100
133
101
134
if (ZSTD_isError (decompress_size )) {
102
135
rb_raise (rb_eRuntimeError , "%s: %s" , "decompress error" , ZSTD_getErrorName (decompress_size ));
@@ -105,10 +138,61 @@ static VALUE decompress(VALUE self, VALUE input)
105
138
return output ;
106
139
}
107
140
141
+ static VALUE rb_decompress_using_dict (int argc , VALUE * argv , VALUE self )
142
+ {
143
+ VALUE input_value ;
144
+ VALUE dict ;
145
+ rb_scan_args (argc , argv , "20" , & input_value , & dict );
146
+
147
+ StringValue (input_value );
148
+ char * input_data = RSTRING_PTR (input_value );
149
+ size_t input_size = RSTRING_LEN (input_value );
150
+ unsigned long long const uncompressed_size = ZSTD_getFrameContentSize (input_data , input_size );
151
+ if (uncompressed_size == ZSTD_CONTENTSIZE_ERROR ) {
152
+ rb_raise (rb_eRuntimeError , "%s: %s" , "not compressed by zstd" , ZSTD_getErrorName (uncompressed_size ));
153
+ }
154
+ if (uncompressed_size == ZSTD_CONTENTSIZE_UNKNOWN ) {
155
+ return decompress_buffered (input_data , input_size );
156
+ }
157
+ VALUE output = rb_str_new (NULL , uncompressed_size );
158
+ char * output_data = RSTRING_PTR (output );
159
+
160
+ char * dict_buffer = RSTRING_PTR (dict );
161
+ size_t dict_size = RSTRING_LEN (dict );
162
+ ZSTD_DDict * const ddict = ZSTD_createDDict (dict_buffer , dict_size );
163
+ if (ddict == NULL ) {
164
+ rb_raise (rb_eRuntimeError , "%s" , "ZSTD_createDDict failed" );
165
+ }
166
+
167
+ unsigned const expected_dict_id = ZSTD_getDictID_fromDDict (ddict );
168
+ unsigned const actual_dict_id = ZSTD_getDictID_fromFrame (input_data , input_size );
169
+ if (expected_dict_id != actual_dict_id ) {
170
+ ZSTD_freeDDict (ddict );
171
+ rb_raise (rb_eRuntimeError , "%s: %s" , "DictID mismatch" , ZSTD_getErrorName (uncompressed_size ));
172
+ }
173
+
174
+ ZSTD_DCtx * const ctx = ZSTD_createDCtx ();
175
+ if (ctx == NULL ) {
176
+ ZSTD_freeDDict (ddict );
177
+ rb_raise (rb_eRuntimeError , "%s" , "ZSTD_createDCtx failed" );
178
+ }
179
+ size_t const decompress_size = ZSTD_decompress_usingDDict (ctx , output_data , uncompressed_size , input_data , input_size , ddict );
180
+ if (ZSTD_isError (decompress_size )) {
181
+ ZSTD_freeDDict (ddict );
182
+ ZSTD_freeDCtx (ctx );
183
+ rb_raise (rb_eRuntimeError , "%s: %s" , "decompress error" , ZSTD_getErrorName (decompress_size ));
184
+ }
185
+ ZSTD_freeDDict (ddict );
186
+ ZSTD_freeDCtx (ctx );
187
+ return output ;
188
+ }
189
+
108
190
void
109
191
zstd_ruby_init (void )
110
192
{
111
193
rb_define_module_function (rb_mZstd , "zstd_version" , zstdVersion , 0 );
112
- rb_define_module_function (rb_mZstd , "compress" , compress , -1 );
113
- rb_define_module_function (rb_mZstd , "decompress" , decompress , 1 );
194
+ rb_define_module_function (rb_mZstd , "compress" , rb_compress , -1 );
195
+ rb_define_module_function (rb_mZstd , "compress_using_dict" , rb_compress_using_dict , -1 );
196
+ rb_define_module_function (rb_mZstd , "decompress" , rb_decompress , 1 );
197
+ rb_define_module_function (rb_mZstd , "decompress_using_dict" , rb_decompress_using_dict , -1 );
114
198
}
0 commit comments