@@ -39,61 +39,91 @@ static VALUE rb_compress(int argc, VALUE *argv, VALUE self)
39
39
return output ;
40
40
}
41
41
42
- static VALUE decompress_buffered (ZSTD_DCtx * dctx , const char * input_data , size_t input_size )
43
- {
44
- ZSTD_inBuffer input = { input_data , input_size , 0 };
45
- VALUE result = rb_str_new (0 , 0 );
42
+ static VALUE decode_one_frame (ZSTD_DCtx * dctx , const unsigned char * src , size_t size , VALUE kwargs ) {
43
+ VALUE out = rb_str_buf_new (0 );
44
+ size_t cap = ZSTD_DStreamOutSize ();
45
+ char * buf = ALLOC_N (char , cap );
46
+ ZSTD_inBuffer in = (ZSTD_inBuffer ){ src , size , 0 };
46
47
47
- while (input .pos < input .size ) {
48
- ZSTD_outBuffer output = { NULL , 0 , 0 };
49
- output .size += ZSTD_DStreamOutSize ();
50
- VALUE output_string = rb_str_new (NULL , output .size );
51
- output .dst = RSTRING_PTR (output_string );
48
+ ZSTD_DCtx_reset (dctx , ZSTD_reset_session_only );
49
+ set_decompress_params (dctx , kwargs );
52
50
53
- size_t ret = zstd_stream_decompress (dctx , & output , & input , false);
51
+ for (;;) {
52
+ ZSTD_outBuffer o = (ZSTD_outBuffer ){ buf , cap , 0 };
53
+ size_t ret = ZSTD_decompressStream (dctx , & o , & in );
54
54
if (ZSTD_isError (ret )) {
55
- ZSTD_freeDCtx (dctx );
56
- rb_raise (rb_eRuntimeError , "%s: %s" , "ZSTD_decompressStream failed" , ZSTD_getErrorName (ret ));
55
+ xfree (buf );
56
+ rb_raise (rb_eRuntimeError , "ZSTD_decompressStream failed: %s" , ZSTD_getErrorName (ret ));
57
+ }
58
+ if (o .pos ) {
59
+ rb_str_cat (out , buf , o .pos );
60
+ }
61
+ if (ret == 0 ) {
62
+ break ;
57
63
}
58
- rb_str_cat (result , output .dst , output .pos );
59
64
}
60
- ZSTD_freeDCtx (dctx );
61
- return result ;
65
+ xfree (buf );
66
+ return out ;
67
+ }
68
+
69
+ static VALUE decompress_buffered (ZSTD_DCtx * dctx , const char * data , size_t len ) {
70
+ return decode_one_frame (dctx , (const unsigned char * )data , len , Qnil );
62
71
}
63
72
64
73
static VALUE rb_decompress (int argc , VALUE * argv , VALUE self )
65
74
{
66
- VALUE input_value ;
67
- VALUE kwargs ;
75
+ VALUE input_value , kwargs ;
68
76
rb_scan_args (argc , argv , "10:" , & input_value , & kwargs );
69
77
StringValue (input_value );
70
- char * input_data = RSTRING_PTR (input_value );
71
- size_t input_size = RSTRING_LEN (input_value );
72
- ZSTD_DCtx * const dctx = ZSTD_createDCtx ();
73
- if (dctx == NULL ) {
74
- rb_raise (rb_eRuntimeError , "%s" , "ZSTD_createDCtx failed" );
75
- }
76
- set_decompress_params (dctx , kwargs );
77
78
78
- unsigned long long const uncompressed_size = ZSTD_getFrameContentSize (input_data , input_size );
79
- if (uncompressed_size == ZSTD_CONTENTSIZE_ERROR ) {
80
- rb_raise (rb_eRuntimeError , "%s: %s" , "not compressed by zstd" , ZSTD_getErrorName (uncompressed_size ));
81
- }
82
- // ZSTD_decompressStream may be called multiple times when ZSTD_CONTENTSIZE_UNKNOWN, causing slowness.
83
- // Therefore, we will not standardize on ZSTD_decompressStream
84
- if (uncompressed_size == ZSTD_CONTENTSIZE_UNKNOWN ) {
85
- return decompress_buffered (dctx , input_data , input_size );
86
- }
79
+ size_t in_size = RSTRING_LEN (input_value );
80
+ const unsigned char * in_r = (const unsigned char * )RSTRING_PTR (input_value );
81
+ unsigned char * in = ALLOC_N (unsigned char , in_size );
82
+ memcpy (in , in_r , in_size );
83
+
84
+ size_t off = 0 ;
85
+ const uint32_t ZSTD_MAGIC = 0xFD2FB528U ;
86
+ const uint32_t SKIP_LO = 0x184D2A50U ; /* ...5F */
87
+
88
+ while (off + 4 <= in_size ) {
89
+ uint32_t magic = (uint32_t )in [off ]
90
+ | ((uint32_t )in [off + 1 ] << 8 )
91
+ | ((uint32_t )in [off + 2 ] << 16 )
92
+ | ((uint32_t )in [off + 3 ] << 24 );
93
+
94
+ if ((magic & 0xFFFFFFF0U ) == (SKIP_LO & 0xFFFFFFF0U )) {
95
+ if (off + 8 > in_size ) break ;
96
+ uint32_t skipLen = (uint32_t )in [off + 4 ]
97
+ | ((uint32_t )in [off + 5 ] << 8 )
98
+ | ((uint32_t )in [off + 6 ] << 16 )
99
+ | ((uint32_t )in [off + 7 ] << 24 );
100
+ size_t adv = (size_t )8 + (size_t )skipLen ;
101
+ if (off + adv > in_size ) break ;
102
+ off += adv ;
103
+ continue ;
104
+ }
87
105
88
- VALUE output = rb_str_new (NULL , uncompressed_size );
89
- char * output_data = RSTRING_PTR (output );
106
+ if (magic == ZSTD_MAGIC ) {
107
+ ZSTD_DCtx * dctx = ZSTD_createDCtx ();
108
+ if (!dctx ) {
109
+ xfree (in );
110
+ rb_raise (rb_eRuntimeError , "ZSTD_createDCtx failed" );
111
+ }
112
+
113
+ VALUE out = decode_one_frame (dctx , in + off , in_size - off , kwargs );
90
114
91
- size_t const decompress_size = zstd_decompress (dctx , output_data , uncompressed_size , input_data , input_size , false);
92
- if (ZSTD_isError (decompress_size )) {
93
- rb_raise (rb_eRuntimeError , "%s: %s" , "decompress error" , ZSTD_getErrorName (decompress_size ));
115
+ ZSTD_freeDCtx (dctx );
116
+ xfree (in );
117
+ RB_GC_GUARD (input_value );
118
+ return out ;
119
+ }
120
+
121
+ off += 1 ;
94
122
}
95
- ZSTD_freeDCtx (dctx );
96
- return output ;
123
+
124
+ xfree (in );
125
+ RB_GC_GUARD (input_value );
126
+ rb_raise (rb_eRuntimeError , "not a zstd frame (magic not found)" );
97
127
}
98
128
99
129
static void free_cdict (void * dict )
0 commit comments