@@ -172,6 +172,49 @@ get_zstd_state(PyObject *module)
172172 return (_zstd_state * )state ;
173173}
174174
175+ static Py_ssize_t
176+ calculate_samples_stats (PyBytesObject * samples_bytes , PyObject * samples_sizes ,
177+ size_t * * chunk_sizes )
178+ {
179+ Py_ssize_t chunks_number ;
180+ Py_ssize_t sizes_sum ;
181+ Py_ssize_t i ;
182+
183+ chunks_number = Py_SIZE (samples_sizes );
184+ if ((size_t ) chunks_number > UINT32_MAX ) {
185+ PyErr_Format (PyExc_ValueError ,
186+ "The number of samples should be <= %u." , UINT32_MAX );
187+ return -1 ;
188+ }
189+
190+ /* Prepare chunk_sizes */
191+ * chunk_sizes = PyMem_New (size_t , chunks_number );
192+ if (* chunk_sizes == NULL ) {
193+ PyErr_NoMemory ();
194+ return -1 ;
195+ }
196+
197+ sizes_sum = 0 ;
198+ for (i = 0 ; i < chunks_number ; i ++ ) {
199+ PyObject * size = PyTuple_GetItem (samples_sizes , i );
200+ (* chunk_sizes )[i ] = PyLong_AsSize_t (size );
201+ if ((* chunk_sizes )[i ] == (size_t )-1 && PyErr_Occurred ()) {
202+ PyErr_Format (PyExc_ValueError ,
203+ "Items in samples_sizes should be an int "
204+ "object, with a value between 0 and %u." , SIZE_MAX );
205+ return -1 ;
206+ }
207+ sizes_sum += (* chunk_sizes )[i ];
208+ }
209+
210+ if (sizes_sum != Py_SIZE (samples_bytes )) {
211+ PyErr_SetString (PyExc_ValueError ,
212+ "The samples size tuple doesn't match the concatenation's size." );
213+ return -1 ;
214+ }
215+ return chunks_number ;
216+ }
217+
175218
176219/*[clinic input]
177220_zstd.train_dict
@@ -192,54 +235,25 @@ _zstd_train_dict_impl(PyObject *module, PyBytesObject *samples_bytes,
192235 PyObject * samples_sizes , Py_ssize_t dict_size )
193236/*[clinic end generated code: output=8e87fe43935e8f77 input=d20dedb21c72cb62]*/
194237{
195- // TODO(emmatyping): The preamble and suffix to this function and _finalize_dict
196- // are pretty similar. We should see if we can refactor them to share that code.
197- Py_ssize_t chunks_number ;
198- size_t * chunk_sizes = NULL ;
199238 PyObject * dst_dict_bytes = NULL ;
239+ size_t * chunk_sizes = NULL ;
240+ Py_ssize_t chunks_number ;
200241 size_t zstd_ret ;
201- Py_ssize_t sizes_sum ;
202- Py_ssize_t i ;
203242
204243 /* Check arguments */
205244 if (dict_size <= 0 ) {
206245 PyErr_SetString (PyExc_ValueError , "dict_size argument should be positive number." );
207246 return NULL ;
208247 }
209248
210- chunks_number = Py_SIZE (samples_sizes );
211- if ((size_t ) chunks_number > UINT32_MAX ) {
212- PyErr_Format (PyExc_ValueError ,
213- "The number of samples should be <= %u." , UINT32_MAX );
249+ /* Check that the samples are valid and get their sizes */
250+ chunks_number = calculate_samples_stats (samples_bytes , samples_sizes ,
251+ & chunk_sizes );
252+ if (chunks_number < 0 )
253+ {
214254 return NULL ;
215255 }
216256
217- /* Prepare chunk_sizes */
218- chunk_sizes = PyMem_New (size_t , chunks_number );
219- if (chunk_sizes == NULL ) {
220- PyErr_NoMemory ();
221- goto error ;
222- }
223-
224- sizes_sum = 0 ;
225- for (i = 0 ; i < chunks_number ; i ++ ) {
226- PyObject * size = PyTuple_GetItem (samples_sizes , i );
227- chunk_sizes [i ] = PyLong_AsSize_t (size );
228- if (chunk_sizes [i ] == (size_t )-1 && PyErr_Occurred ()) {
229- PyErr_Format (PyExc_ValueError ,
230- "Items in samples_sizes should be an int "
231- "object, with a value between 0 and %u." , SIZE_MAX );
232- goto error ;
233- }
234- sizes_sum += chunk_sizes [i ];
235- }
236-
237- if (sizes_sum != Py_SIZE (samples_bytes )) {
238- PyErr_SetString (PyExc_ValueError ,
239- "The samples size tuple doesn't match the concatenation's size." );
240- goto error ;
241- }
242-
243257 /* Allocate dict buffer */
244258 dst_dict_bytes = PyBytes_FromStringAndSize (NULL , dict_size );
245259 if (dst_dict_bytes == NULL ) {
@@ -307,48 +321,21 @@ _zstd_finalize_dict_impl(PyObject *module, PyBytesObject *custom_dict_bytes,
307321 PyObject * dst_dict_bytes = NULL ;
308322 size_t zstd_ret ;
309323 ZDICT_params_t params ;
310- Py_ssize_t sizes_sum ;
311- Py_ssize_t i ;
312324
313325 /* Check arguments */
314326 if (dict_size <= 0 ) {
315327 PyErr_SetString (PyExc_ValueError , "dict_size argument should be positive number." );
316328 return NULL ;
317329 }
318330
319- chunks_number = Py_SIZE (samples_sizes );
320- if ((size_t ) chunks_number > UINT32_MAX ) {
321- PyErr_Format (PyExc_ValueError ,
322- "The number of samples should be <= %u." , UINT32_MAX );
331+ /* Check that the samples are valid and get their sizes */
332+ chunks_number = calculate_samples_stats (samples_bytes , samples_sizes ,
333+ & chunk_sizes );
334+ if (chunks_number < 0 )
335+ {
323336 return NULL ;
324337 }
325338
326- /* Prepare chunk_sizes */
327- chunk_sizes = PyMem_New (size_t , chunks_number );
328- if (chunk_sizes == NULL ) {
329- PyErr_NoMemory ();
330- goto error ;
331- }
332-
333- sizes_sum = 0 ;
334- for (i = 0 ; i < chunks_number ; i ++ ) {
335- PyObject * size = PyTuple_GetItem (samples_sizes , i );
336- chunk_sizes [i ] = PyLong_AsSize_t (size );
337- if (chunk_sizes [i ] == (size_t )-1 && PyErr_Occurred ()) {
338- PyErr_Format (PyExc_ValueError ,
339- "Items in samples_sizes should be an int "
340- "object, with a value between 0 and %u." , SIZE_MAX );
341- goto error ;
342- }
343- sizes_sum += chunk_sizes [i ];
344- }
345-
346- if (sizes_sum != Py_SIZE (samples_bytes )) {
347- PyErr_SetString (PyExc_ValueError ,
348- "The samples size tuple doesn't match the concatenation's size." );
349- goto error ;
350- }
351-
352339 /* Allocate dict buffer */
353340 dst_dict_bytes = PyBytes_FromStringAndSize (NULL , dict_size );
354341 if (dst_dict_bytes == NULL ) {
0 commit comments