2121#endif
2222
2323#include "Python.h"
24- #include "pycore_strhex.h" // _Py_strhex()
25- #include "pycore_typeobject.h" // _PyType_GetModuleState()
24+ #include "pycore_moduleobject.h" // _PyModule_GetState()
25+ #include "pycore_strhex.h" // _Py_strhex()
26+ #include "pycore_typeobject.h" // _PyType_GetModuleState()
27+
2628#include "hashlib.h"
2729
2830#include "_hacl/Hacl_Hash_SHA3.h"
4244
4345// --- Module state -----------------------------------------------------------
4446
47+ static struct PyModuleDef sha3module_def ;
48+
4549typedef struct {
4650 PyTypeObject * sha3_224_type ;
4751 PyTypeObject * sha3_256_type ;
4852 PyTypeObject * sha3_384_type ;
4953 PyTypeObject * sha3_512_type ;
5054 PyTypeObject * shake_128_type ;
5155 PyTypeObject * shake_256_type ;
52- } SHA3State ;
56+ } sha3module_state ;
57+
58+ static inline sha3module_state *
59+ get_sha3module_state (PyObject * module )
60+ {
61+ void * state = _PyModule_GetState (module );
62+ assert (state != NULL );
63+ return (sha3module_state * )state ;
64+ }
5365
54- static inline SHA3State *
55- sha3_get_state ( PyObject * module )
66+ static inline sha3module_state *
67+ get_sha3module_state_by_cls ( PyTypeObject * cls )
5668{
57- void * state = PyModule_GetState (module );
69+ _Py_hashlib_check_exported_type (cls , & sha3module_def );
70+ void * state = _PyType_GetModuleState (cls );
5871 assert (state != NULL );
59- return (SHA3State * )state ;
72+ return (sha3module_state * )state ;
6073}
6174
6275// --- Module objects ---------------------------------------------------------
@@ -90,6 +103,7 @@ class _sha3.shake_256 "SHA3object *" "&PyType_Type"
90103static SHA3object *
91104newSHA3object (PyTypeObject * type )
92105{
106+ _Py_hashlib_check_exported_type (type , & sha3module_def );
93107 SHA3object * newobj = PyObject_GC_New (SHA3object , type );
94108 if (newobj == NULL ) {
95109 return NULL ;
@@ -142,14 +156,12 @@ py_sha3_new_impl(PyTypeObject *type, PyObject *data_obj, int usedforsecurity,
142156 }
143157
144158 Py_buffer buf = {NULL , NULL };
145- SHA3State * state = _PyType_GetModuleState (type );
146159 SHA3object * self = newSHA3object (type );
147160 if (self == NULL ) {
148161 goto error ;
149162 }
150163
151- assert (state != NULL );
152-
164+ sha3module_state * state = get_sha3module_state_by_cls (type );
153165 if (type == state -> sha3_224_type ) {
154166 self -> hash_state = Hacl_Hash_SHA3_malloc (Spec_Hash_Definitions_SHA3_224 );
155167 }
@@ -349,7 +361,7 @@ SHA3_get_name(PyObject *self, void *Py_UNUSED(closure))
349361{
350362 PyTypeObject * type = Py_TYPE (self );
351363
352- SHA3State * state = _PyType_GetModuleState (type );
364+ sha3module_state * state = get_sha3module_state_by_cls (type );
353365 assert (state != NULL );
354366
355367 if (type == state -> sha3_224_type ) {
@@ -617,7 +629,7 @@ SHA3_TYPE_SPEC(SHAKE256_spec, "shake_256", SHAKE256slots);
617629static int
618630_sha3_traverse (PyObject * module , visitproc visit , void * arg )
619631{
620- SHA3State * state = sha3_get_state (module );
632+ sha3module_state * state = get_sha3module_state (module );
621633 Py_VISIT (state -> sha3_224_type );
622634 Py_VISIT (state -> sha3_256_type );
623635 Py_VISIT (state -> sha3_384_type );
@@ -630,7 +642,7 @@ _sha3_traverse(PyObject *module, visitproc visit, void *arg)
630642static int
631643_sha3_clear (PyObject * module )
632644{
633- SHA3State * state = sha3_get_state (module );
645+ sha3module_state * state = get_sha3module_state (module );
634646 Py_CLEAR (state -> sha3_224_type );
635647 Py_CLEAR (state -> sha3_256_type );
636648 Py_CLEAR (state -> sha3_384_type );
@@ -649,7 +661,7 @@ _sha3_free(void *module)
649661static int
650662_sha3_exec (PyObject * m )
651663{
652- SHA3State * st = sha3_get_state (m );
664+ sha3module_state * st = get_sha3module_state (m );
653665
654666#define init_sha3type (type , typespec ) \
655667 do { \
@@ -689,10 +701,10 @@ static PyModuleDef_Slot _sha3_slots[] = {
689701};
690702
691703/* Initialize this module. */
692- static struct PyModuleDef _sha3module = {
704+ static struct PyModuleDef sha3module_def = {
693705 PyModuleDef_HEAD_INIT ,
694706 .m_name = "_sha3" ,
695- .m_size = sizeof (SHA3State ),
707+ .m_size = sizeof (sha3module_state ),
696708 .m_slots = _sha3_slots ,
697709 .m_traverse = _sha3_traverse ,
698710 .m_clear = _sha3_clear ,
@@ -703,5 +715,5 @@ static struct PyModuleDef _sha3module = {
703715PyMODINIT_FUNC
704716PyInit__sha3 (void )
705717{
706- return PyModuleDef_Init (& _sha3module );
718+ return PyModuleDef_Init (& sha3module_def );
707719}
0 commit comments