Skip to content

Commit 958f88d

Browse files
committed
sha3: faster checks on RELEASE builds, safer ones on DEBUG builds
1 parent eb99719 commit 958f88d

File tree

1 file changed

+29
-17
lines changed

1 file changed

+29
-17
lines changed

Modules/sha3module.c

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,10 @@
2121
#endif
2222

2323
#include "Python.h"
24-
#include "pycore_strhex.h" // _Py_strhex()
25-
#include "pycore_typeobject.h" // _PyType_GetModuleState()
24+
#include "pycore_moduleobject.h" // _PyModule_GetState()
25+
#include "pycore_strhex.h" // _Py_strhex()
26+
#include "pycore_typeobject.h" // _PyType_GetModuleState()
27+
2628
#include "hashlib.h"
2729

2830
#include "_hacl/Hacl_Hash_SHA3.h"
@@ -42,21 +44,32 @@
4244

4345
// --- Module state -----------------------------------------------------------
4446

47+
static struct PyModuleDef sha3module_def;
48+
4549
typedef struct {
4650
PyTypeObject *sha3_224_type;
4751
PyTypeObject *sha3_256_type;
4852
PyTypeObject *sha3_384_type;
4953
PyTypeObject *sha3_512_type;
5054
PyTypeObject *shake_128_type;
5155
PyTypeObject *shake_256_type;
52-
} SHA3State;
56+
} sha3module_state;
57+
58+
static inline sha3module_state *
59+
get_sha3module_state(PyObject *module)
60+
{
61+
void *state = _PyModule_GetState(module);
62+
assert(state != NULL);
63+
return (sha3module_state *)state;
64+
}
5365

54-
static inline SHA3State*
55-
sha3_get_state(PyObject *module)
66+
static inline sha3module_state *
67+
get_sha3module_state_by_cls(PyTypeObject *cls)
5668
{
57-
void *state = PyModule_GetState(module);
69+
_Py_hashlib_check_exported_type(cls, &sha3module_def);
70+
void *state = _PyType_GetModuleState(cls);
5871
assert(state != NULL);
59-
return (SHA3State *)state;
72+
return (sha3module_state *)state;
6073
}
6174

6275
// --- Module objects ---------------------------------------------------------
@@ -90,6 +103,7 @@ class _sha3.shake_256 "SHA3object *" "&PyType_Type"
90103
static SHA3object *
91104
newSHA3object(PyTypeObject *type)
92105
{
106+
_Py_hashlib_check_exported_type(type, &sha3module_def);
93107
SHA3object *newobj = PyObject_GC_New(SHA3object, type);
94108
if (newobj == NULL) {
95109
return NULL;
@@ -142,14 +156,12 @@ py_sha3_new_impl(PyTypeObject *type, PyObject *data_obj, int usedforsecurity,
142156
}
143157

144158
Py_buffer buf = {NULL, NULL};
145-
SHA3State *state = _PyType_GetModuleState(type);
146159
SHA3object *self = newSHA3object(type);
147160
if (self == NULL) {
148161
goto error;
149162
}
150163

151-
assert(state != NULL);
152-
164+
sha3module_state *state = get_sha3module_state_by_cls(type);
153165
if (type == state->sha3_224_type) {
154166
self->hash_state = Hacl_Hash_SHA3_malloc(Spec_Hash_Definitions_SHA3_224);
155167
}
@@ -349,7 +361,7 @@ SHA3_get_name(PyObject *self, void *Py_UNUSED(closure))
349361
{
350362
PyTypeObject *type = Py_TYPE(self);
351363

352-
SHA3State *state = _PyType_GetModuleState(type);
364+
sha3module_state *state = get_sha3module_state_by_cls(type);
353365
assert(state != NULL);
354366

355367
if (type == state->sha3_224_type) {
@@ -617,7 +629,7 @@ SHA3_TYPE_SPEC(SHAKE256_spec, "shake_256", SHAKE256slots);
617629
static int
618630
_sha3_traverse(PyObject *module, visitproc visit, void *arg)
619631
{
620-
SHA3State *state = sha3_get_state(module);
632+
sha3module_state *state = get_sha3module_state(module);
621633
Py_VISIT(state->sha3_224_type);
622634
Py_VISIT(state->sha3_256_type);
623635
Py_VISIT(state->sha3_384_type);
@@ -630,7 +642,7 @@ _sha3_traverse(PyObject *module, visitproc visit, void *arg)
630642
static int
631643
_sha3_clear(PyObject *module)
632644
{
633-
SHA3State *state = sha3_get_state(module);
645+
sha3module_state *state = get_sha3module_state(module);
634646
Py_CLEAR(state->sha3_224_type);
635647
Py_CLEAR(state->sha3_256_type);
636648
Py_CLEAR(state->sha3_384_type);
@@ -649,7 +661,7 @@ _sha3_free(void *module)
649661
static int
650662
_sha3_exec(PyObject *m)
651663
{
652-
SHA3State *st = sha3_get_state(m);
664+
sha3module_state *st = get_sha3module_state(m);
653665

654666
#define init_sha3type(type, typespec) \
655667
do { \
@@ -689,10 +701,10 @@ static PyModuleDef_Slot _sha3_slots[] = {
689701
};
690702

691703
/* Initialize this module. */
692-
static struct PyModuleDef _sha3module = {
704+
static struct PyModuleDef sha3module_def = {
693705
PyModuleDef_HEAD_INIT,
694706
.m_name = "_sha3",
695-
.m_size = sizeof(SHA3State),
707+
.m_size = sizeof(sha3module_state),
696708
.m_slots = _sha3_slots,
697709
.m_traverse = _sha3_traverse,
698710
.m_clear = _sha3_clear,
@@ -703,5 +715,5 @@ static struct PyModuleDef _sha3module = {
703715
PyMODINIT_FUNC
704716
PyInit__sha3(void)
705717
{
706-
return PyModuleDef_Init(&_sha3module);
718+
return PyModuleDef_Init(&sha3module_def);
707719
}

0 commit comments

Comments
 (0)