Skip to content

Commit ec31019

Browse files
committed
Put locks around ZstdDict usage
This also reverts the setdefault usage which we don't need anymore and refactors the _zstd_load_(c,d)_dict functions to not use goto/properly lock around dictionary usage.
1 parent 84c5b30 commit ec31019

File tree

4 files changed

+114
-93
lines changed

4 files changed

+114
-93
lines changed

Modules/_zstd/compressor.c

Lines changed: 57 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -165,12 +165,11 @@ _get_CDict(ZstdDict *self, int compressionLevel)
165165
}
166166

167167
/* Get PyCapsule object from self->c_dicts */
168-
int result = PyDict_GetItemRef(self->c_dicts, level, &capsule);
169-
if (result < 0) {
170-
goto error;
171-
}
172-
168+
capsule = PyDict_GetItemWithError(self->c_dicts, level);
173169
if (capsule == NULL) {
170+
if (PyErr_Occurred()) {
171+
goto error;
172+
}
174173
/* Create ZSTD_CDict instance */
175174
char *dict_buffer = PyBytes_AS_STRING(self->dict_content);
176175
Py_ssize_t dict_len = Py_SIZE(self->dict_content);
@@ -198,18 +197,15 @@ _get_CDict(ZstdDict *self, int compressionLevel)
198197
}
199198

200199
/* Add PyCapsule object to self->c_dicts if not already inserted */
201-
PyObject *capsule_value;
202-
int result = PyDict_SetDefaultRef(self->c_dicts, level, capsule,
203-
&capsule_value);
204-
if (result < 0) {
200+
if (PyDict_SetItem(self->c_dicts, level, capsule) < 0) {
201+
Py_DECREF(capsule);
205202
goto error;
206203
}
207-
Py_XDECREF(capsule_value);
204+
Py_DECREF(capsule);
208205
}
209206
else {
210207
/* ZSTD_CDict instance already exists */
211208
cdict = PyCapsule_GetPointer(capsule, NULL);
212-
Py_DECREF(capsule);
213209
}
214210
goto success;
215211

@@ -221,10 +217,50 @@ _get_CDict(ZstdDict *self, int compressionLevel)
221217
}
222218

223219
static int
224-
_zstd_load_c_dict(ZstdCompressor *self, PyObject *dict)
220+
_zstd_load_impl(ZstdCompressor *self, ZstdDict *zd,
221+
_zstd_state *mod_state, int type)
225222
{
226-
227223
size_t zstd_ret;
224+
if (type == DICT_TYPE_DIGESTED) {
225+
/* Get ZSTD_CDict */
226+
ZSTD_CDict *c_dict = _get_CDict(zd, self->compression_level);
227+
if (c_dict == NULL) {
228+
return -1;
229+
}
230+
/* Reference a prepared dictionary.
231+
It overrides some compression context's parameters. */
232+
zstd_ret = ZSTD_CCtx_refCDict(self->cctx, c_dict);
233+
}
234+
else if (type == DICT_TYPE_UNDIGESTED) {
235+
/* Load a dictionary.
236+
It doesn't override compression context's parameters. */
237+
zstd_ret = ZSTD_CCtx_loadDictionary(
238+
self->cctx,
239+
PyBytes_AS_STRING(zd->dict_content),
240+
Py_SIZE(zd->dict_content));
241+
}
242+
else if (type == DICT_TYPE_PREFIX) {
243+
/* Load a prefix */
244+
zstd_ret = ZSTD_CCtx_refPrefix(
245+
self->cctx,
246+
PyBytes_AS_STRING(zd->dict_content),
247+
Py_SIZE(zd->dict_content));
248+
}
249+
else {
250+
Py_UNREACHABLE();
251+
}
252+
253+
/* Check error */
254+
if (ZSTD_isError(zstd_ret)) {
255+
set_zstd_error(mod_state, ERR_LOAD_C_DICT, zstd_ret);
256+
return -1;
257+
}
258+
return 0;
259+
}
260+
261+
static int
262+
_zstd_load_c_dict(ZstdCompressor *self, PyObject *dict)
263+
{
228264
_zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self));
229265
if (mod_state == NULL) {
230266
return -1;
@@ -241,7 +277,10 @@ _zstd_load_c_dict(ZstdCompressor *self, PyObject *dict)
241277
/* When compressing, use undigested dictionary by default. */
242278
zd = (ZstdDict*)dict;
243279
type = DICT_TYPE_UNDIGESTED;
244-
goto load;
280+
PyMutex_Lock(&zd->lock);
281+
ret = _zstd_load_impl(self, zd, mod_state, type);
282+
PyMutex_Unlock(&zd->lock);
283+
return ret;
245284
}
246285

247286
/* Check (ZstdDict, type) */
@@ -261,7 +300,10 @@ _zstd_load_c_dict(ZstdCompressor *self, PyObject *dict)
261300
{
262301
assert(type >= 0);
263302
zd = (ZstdDict*)PyTuple_GET_ITEM(dict, 0);
264-
goto load;
303+
PyMutex_Lock(&zd->lock);
304+
ret = _zstd_load_impl(self, zd, mod_state, type);
305+
PyMutex_Unlock(&zd->lock);
306+
return ret;
265307
}
266308
}
267309
}
@@ -270,43 +312,6 @@ _zstd_load_c_dict(ZstdCompressor *self, PyObject *dict)
270312
PyErr_SetString(PyExc_TypeError,
271313
"zstd_dict argument should be ZstdDict object.");
272314
return -1;
273-
274-
load:
275-
if (type == DICT_TYPE_DIGESTED) {
276-
/* Get ZSTD_CDict */
277-
ZSTD_CDict *c_dict = _get_CDict(zd, self->compression_level);
278-
if (c_dict == NULL) {
279-
return -1;
280-
}
281-
/* Reference a prepared dictionary.
282-
It overrides some compression context's parameters. */
283-
zstd_ret = ZSTD_CCtx_refCDict(self->cctx, c_dict);
284-
}
285-
else if (type == DICT_TYPE_UNDIGESTED) {
286-
/* Load a dictionary.
287-
It doesn't override compression context's parameters. */
288-
zstd_ret = ZSTD_CCtx_loadDictionary(
289-
self->cctx,
290-
PyBytes_AS_STRING(zd->dict_content),
291-
Py_SIZE(zd->dict_content));
292-
}
293-
else if (type == DICT_TYPE_PREFIX) {
294-
/* Load a prefix */
295-
zstd_ret = ZSTD_CCtx_refPrefix(
296-
self->cctx,
297-
PyBytes_AS_STRING(zd->dict_content),
298-
Py_SIZE(zd->dict_content));
299-
}
300-
else {
301-
Py_UNREACHABLE();
302-
}
303-
304-
/* Check error */
305-
if (ZSTD_isError(zstd_ret)) {
306-
set_zstd_error(mod_state, ERR_LOAD_C_DICT, zstd_ret);
307-
return -1;
308-
}
309-
return 0;
310315
}
311316

312317
/*[clinic input]

Modules/_zstd/decompressor.c

Lines changed: 51 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -147,11 +147,53 @@ _zstd_set_d_parameters(ZstdDecompressor *self, PyObject *options)
147147
return 0;
148148
}
149149

150+
static int
151+
_zstd_load_impl(ZstdDecompressor *self, ZstdDict *zd,
152+
_zstd_state *mod_state, int type)
153+
{
154+
size_t zstd_ret;
155+
if (type == DICT_TYPE_DIGESTED) {
156+
/* Get ZSTD_DDict */
157+
ZSTD_DDict *d_dict = _get_DDict(zd);
158+
if (d_dict == NULL) {
159+
return -1;
160+
}
161+
/* Reference a prepared dictionary */
162+
zstd_ret = ZSTD_DCtx_refDDict(self->dctx, d_dict);
163+
}
164+
else if (type == DICT_TYPE_UNDIGESTED) {
165+
/* Load a dictionary */
166+
zstd_ret = ZSTD_DCtx_loadDictionary(
167+
self->dctx,
168+
PyBytes_AS_STRING(zd->dict_content),
169+
Py_SIZE(zd->dict_content));
170+
}
171+
else if (type == DICT_TYPE_PREFIX) {
172+
/* Load a prefix */
173+
zstd_ret = ZSTD_DCtx_refPrefix(
174+
self->dctx,
175+
PyBytes_AS_STRING(zd->dict_content),
176+
Py_SIZE(zd->dict_content));
177+
}
178+
else {
179+
/* Impossible code path */
180+
PyErr_SetString(PyExc_SystemError,
181+
"load_d_dict() impossible code path");
182+
return -1;
183+
}
184+
185+
/* Check error */
186+
if (ZSTD_isError(zstd_ret)) {
187+
set_zstd_error(mod_state, ERR_LOAD_D_DICT, zstd_ret);
188+
return -1;
189+
}
190+
return 0;
191+
}
192+
150193
/* Load dictionary or prefix to decompression context */
151194
static int
152195
_zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict)
153196
{
154-
size_t zstd_ret;
155197
_zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self));
156198
if (mod_state == NULL) {
157199
return -1;
@@ -168,7 +210,10 @@ _zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict)
168210
/* When decompressing, use digested dictionary by default. */
169211
zd = (ZstdDict*)dict;
170212
type = DICT_TYPE_DIGESTED;
171-
goto load;
213+
PyMutex_Lock(&zd->lock);
214+
ret = _zstd_load_impl(self, zd, mod_state, type);
215+
PyMutex_Unlock(&zd->lock);
216+
return ret;
172217
}
173218

174219
/* Check (ZstdDict, type) */
@@ -188,7 +233,10 @@ _zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict)
188233
{
189234
assert(type >= 0);
190235
zd = (ZstdDict*)PyTuple_GET_ITEM(dict, 0);
191-
goto load;
236+
PyMutex_Lock(&zd->lock);
237+
ret = _zstd_load_impl(self, zd, mod_state, type);
238+
PyMutex_Unlock(&zd->lock);
239+
return ret;
192240
}
193241
}
194242
}
@@ -197,44 +245,6 @@ _zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict)
197245
PyErr_SetString(PyExc_TypeError,
198246
"zstd_dict argument should be ZstdDict object.");
199247
return -1;
200-
201-
load:
202-
if (type == DICT_TYPE_DIGESTED) {
203-
/* Get ZSTD_DDict */
204-
ZSTD_DDict *d_dict = _get_DDict(zd);
205-
if (d_dict == NULL) {
206-
return -1;
207-
}
208-
/* Reference a prepared dictionary */
209-
zstd_ret = ZSTD_DCtx_refDDict(self->dctx, d_dict);
210-
}
211-
else if (type == DICT_TYPE_UNDIGESTED) {
212-
/* Load a dictionary */
213-
zstd_ret = ZSTD_DCtx_loadDictionary(
214-
self->dctx,
215-
PyBytes_AS_STRING(zd->dict_content),
216-
Py_SIZE(zd->dict_content));
217-
}
218-
else if (type == DICT_TYPE_PREFIX) {
219-
/* Load a prefix */
220-
zstd_ret = ZSTD_DCtx_refPrefix(
221-
self->dctx,
222-
PyBytes_AS_STRING(zd->dict_content),
223-
Py_SIZE(zd->dict_content));
224-
}
225-
else {
226-
/* Impossible code path */
227-
PyErr_SetString(PyExc_SystemError,
228-
"load_d_dict() impossible code path");
229-
return -1;
230-
}
231-
232-
/* Check error */
233-
if (ZSTD_isError(zstd_ret)) {
234-
set_zstd_error(mod_state, ERR_LOAD_D_DICT, zstd_ret);
235-
return -1;
236-
}
237-
return 0;
238248
}
239249

240250
/*

Modules/_zstd/zstddict.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ _zstd_ZstdDict_new_impl(PyTypeObject *type, PyObject *dict_content,
5353
self->dict_content = NULL;
5454
self->d_dict = NULL;
5555
self->dict_id = 0;
56+
self->lock = (PyMutex){0};
5657

5758
/* ZSTD_CDict dict */
5859
self->c_dicts = PyDict_New();
@@ -109,6 +110,8 @@ ZstdDict_dealloc(PyObject *ob)
109110
ZSTD_freeDDict(self->d_dict);
110111
}
111112

113+
assert(!PyMutex_IsLocked(&self->lock));
114+
112115
/* Release dict_content after Free ZSTD_CDict/ZSTD_DDict instances */
113116
Py_CLEAR(self->dict_content);
114117
Py_CLEAR(self->c_dicts);

Modules/_zstd/zstddict.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ typedef struct {
1919
PyObject *dict_content;
2020
/* Dictionary id */
2121
uint32_t dict_id;
22+
23+
/* Lock to protect the digested dictionaries */
24+
PyMutex lock;
2225
} ZstdDict;
2326

2427
#endif // !ZSTD_DICT_H

0 commit comments

Comments
 (0)