Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 25 additions & 6 deletions Lib/test/test_capi/test_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,12 @@ def test_getmodule(self):

nonexistent = 'nonexistent'
self.assertNotIn(nonexistent, sys.modules)
self.assertIsNone(_testlimitedcapi.PyImport_GetModule(nonexistent))
self.assertIsNone(_testlimitedcapi.PyImport_GetModule(''))
self.assertIsNone(_testlimitedcapi.PyImport_GetModule(object()))
for name in (nonexistent, '', object()):
with self.subTest(name=name):
with self.assertRaises(KeyError):
_testlimitedcapi.PyImport_GetModule(name)

# CRASHES PyImport_GetModule(NULL)

def check_addmodule(self, add_module, accept_nonstr=False):
# create a new module
Expand All @@ -84,6 +87,8 @@ def test_addmoduleobject(self):
self.check_addmodule(_testlimitedcapi.PyImport_AddModuleObject,
accept_nonstr=True)

# CRASHES PyImport_AddModuleObject(NULL)

def test_addmodule(self):
# Test PyImport_AddModule()
self.check_addmodule(_testlimitedcapi.PyImport_AddModule)
Expand All @@ -110,15 +115,22 @@ def test_import(self):
# Test PyImport_Import()
self.check_import_func(_testlimitedcapi.PyImport_Import)

with self.assertRaises(SystemError):
_testlimitedcapi.PyImport_Import(NULL)

def test_importmodule(self):
# Test PyImport_ImportModule()
self.check_import_func(_testlimitedcapi.PyImport_ImportModule)

# CRASHES PyImport_ImportModule(NULL)

def test_importmodulenoblock(self):
# Test deprecated PyImport_ImportModuleNoBlock()
with check_warnings(('', DeprecationWarning)):
self.check_import_func(_testlimitedcapi.PyImport_ImportModuleNoBlock)

# CRASHES PyImport_ImportModuleNoBlock(NULL)

def check_frozen_import(self, import_frozen_module):
# Importing a frozen module executes its code, so start by unloading
# the module to execute the code in a new (temporary) module.
Expand Down Expand Up @@ -199,12 +211,19 @@ def test_executecodemodule(self):

def test_executecodemoduleex(self):
# Test PyImport_ExecCodeModuleEx()
pathname = os.path.abspath('pathname')

def execute_code(name, code):
# Test NULL path (it should not crash)
def execute_code1(name, code):
return _testlimitedcapi.PyImport_ExecCodeModuleEx(name, code,
NULL)
self.check_executecodemodule(execute_code1)

# Test non-NULL path
pathname = os.path.abspath('pathname')
def execute_code2(name, code):
return _testlimitedcapi.PyImport_ExecCodeModuleEx(name, code,
pathname)
self.check_executecodemodule(execute_code, pathname)
self.check_executecodemodule(execute_code2, pathname)

def check_executecode_pathnames(self, execute_code_func):
# Test non-NULL pathname and NULL cpathname
Expand Down
31 changes: 21 additions & 10 deletions Modules/_testlimitedcapi/import.c
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,11 @@ static PyObject *
pyimport_getmodule(PyObject *Py_UNUSED(module), PyObject *name)
{
assert(!PyErr_Occurred());
NULLABLE(name);
PyObject *module = PyImport_GetModule(name);
if (module == NULL && !PyErr_Occurred()) {
Py_RETURN_NONE;
PyErr_SetString(PyExc_KeyError, "PyImport_GetModule() returned NULL");
return NULL;
}
return module;
}
Expand All @@ -51,6 +53,7 @@ pyimport_getmodule(PyObject *Py_UNUSED(module), PyObject *name)
static PyObject *
pyimport_addmoduleobject(PyObject *Py_UNUSED(module), PyObject *name)
{
NULLABLE(name);
return Py_XNewRef(PyImport_AddModuleObject(name));
}

Expand All @@ -60,7 +63,7 @@ static PyObject *
pyimport_addmodule(PyObject *Py_UNUSED(module), PyObject *args)
{
const char *name;
if (!PyArg_ParseTuple(args, "s", &name)) {
if (!PyArg_ParseTuple(args, "z", &name)) {
return NULL;
}

Expand All @@ -85,6 +88,7 @@ pyimport_addmoduleref(PyObject *Py_UNUSED(module), PyObject *args)
static PyObject *
pyimport_import(PyObject *Py_UNUSED(module), PyObject *name)
{
NULLABLE(name);
return PyImport_Import(name);
}

Expand All @@ -94,7 +98,7 @@ static PyObject *
pyimport_importmodule(PyObject *Py_UNUSED(module), PyObject *args)
{
const char *name;
if (!PyArg_ParseTuple(args, "s", &name)) {
if (!PyArg_ParseTuple(args, "z", &name)) {
return NULL;
}

Expand All @@ -107,7 +111,7 @@ static PyObject *
pyimport_importmodulenoblock(PyObject *Py_UNUSED(module), PyObject *args)
{
const char *name;
if (!PyArg_ParseTuple(args, "s", &name)) {
if (!PyArg_ParseTuple(args, "z", &name)) {
return NULL;
}

Expand All @@ -128,6 +132,9 @@ pyimport_importmoduleex(PyObject *Py_UNUSED(module), PyObject *args)
&name, &globals, &locals, &fromlist)) {
return NULL;
}
NULLABLE(globals);
NULLABLE(locals);
NULLABLE(fromlist);

return PyImport_ImportModuleEx(name, globals, locals, fromlist);
}
Expand All @@ -144,6 +151,9 @@ pyimport_importmodulelevel(PyObject *Py_UNUSED(module), PyObject *args)
&name, &globals, &locals, &fromlist, &level)) {
return NULL;
}
NULLABLE(globals);
NULLABLE(locals);
NULLABLE(fromlist);

return PyImport_ImportModuleLevel(name, globals, locals, fromlist, level);
}
Expand All @@ -159,6 +169,10 @@ pyimport_importmodulelevelobject(PyObject *Py_UNUSED(module), PyObject *args)
&name, &globals, &locals, &fromlist, &level)) {
return NULL;
}
NULLABLE(name);
NULLABLE(globals);
NULLABLE(locals);
NULLABLE(fromlist);

return PyImport_ImportModuleLevelObject(name, globals, locals, fromlist, level);
}
Expand All @@ -174,10 +188,7 @@ pyimport_importfrozenmodule(PyObject *Py_UNUSED(module), PyObject *args)
}

int res = PyImport_ImportFrozenModule(name);
if (res < 0) {
return NULL;
}
return PyLong_FromLong(res);
RETURN_INT(res);
}


Expand Down Expand Up @@ -214,7 +225,7 @@ pyimport_executecodemoduleex(PyObject *Py_UNUSED(module), PyObject *args)
const char *name;
PyObject *code;
const char *pathname;
if (!PyArg_ParseTuple(args, "sOs", &name, &code, &pathname)) {
if (!PyArg_ParseTuple(args, "zOz", &name, &code, &pathname)) {
return NULL;
}

Expand All @@ -230,7 +241,7 @@ pyimport_executecodemodulewithpathnames(PyObject *Py_UNUSED(module), PyObject *a
PyObject *code;
const char *pathname;
const char *cpathname;
if (!PyArg_ParseTuple(args, "sOzz", &name, &code, &pathname, &cpathname)) {
if (!PyArg_ParseTuple(args, "zOzz", &name, &code, &pathname, &cpathname)) {
return NULL;
}

Expand Down
Loading