Skip to content

Commit 2717b9a

Browse files
committed
Python: Extend import resolution tests
Extends the tests to 1. Account parts of the test code that may be specific to Python 2 or 3, 2. Also track which arguments passed to `check` are references to modules. The latter revealed a bunch of spurious results, which I have annotated accordingly.
1 parent f92d836 commit 2717b9a

File tree

3 files changed

+72
-14
lines changed

3 files changed

+72
-14
lines changed

python/ql/test/experimental/import-resolution/importflow.ql

Lines changed: 64 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ import python
22
import semmle.python.dataflow.new.DataFlow
33
import semmle.python.ApiGraphs
44
import TestUtilities.InlineExpectationsTest
5+
import semmle.python.dataflow.new.internal.ImportResolution
56

7+
/** A string that appears on the right hand side of an assignment. */
68
private class SourceString extends DataFlow::Node {
79
string contents;
810

@@ -14,6 +16,45 @@ private class SourceString extends DataFlow::Node {
1416
string getContents() { result = contents }
1517
}
1618

19+
/** An argument that is checked using the `check` function. */
20+
private class CheckArgument extends DataFlow::Node {
21+
CheckArgument() { this = API::moduleImport("trace").getMember("check").getACall().getArg(1) }
22+
}
23+
24+
/** A data-flow node that is a reference to a module. */
25+
private class ModuleRef extends DataFlow::Node {
26+
Module mod;
27+
28+
ModuleRef() {
29+
this = ImportResolution::getModuleReference(mod) and
30+
not mod.getName() in ["__future__", "trace"]
31+
}
32+
33+
string getName() { result = mod.getName() }
34+
}
35+
36+
/**
37+
* A data-flow node that is guarded by a version check. Only supports checks of the form `if
38+
*sys.version_info[0] == ...` where the right hand side is either `2` or `3`.
39+
*/
40+
private class VersionGuardedNode extends DataFlow::Node {
41+
int version;
42+
43+
VersionGuardedNode() {
44+
version in [2, 3] and
45+
exists(If parent, CompareNode c | parent.getBody().contains(this.asExpr()) |
46+
c.operands(API::moduleImport("sys")
47+
.getMember("version_info")
48+
.getASubscript()
49+
.asSource()
50+
.asCfgNode(), any(Eq eq),
51+
any(IntegerLiteral lit | lit.getValue() = version).getAFlowNode())
52+
)
53+
}
54+
55+
int getVersion() { result = version }
56+
}
57+
1758
private class ImportConfiguration extends DataFlow::Configuration {
1859
ImportConfiguration() { this = "ImportConfiguration" }
1960

@@ -30,12 +71,29 @@ class ResolutionTest extends InlineExpectationsTest {
3071
override string getARelevantTag() { result = "prints" }
3172

3273
override predicate hasActualResult(Location location, string element, string tag, string value) {
33-
exists(DataFlow::PathNode source, DataFlow::PathNode sink, ImportConfiguration config |
34-
config.hasFlowPath(source, sink) and
35-
tag = "prints" and
36-
location = sink.getNode().getLocation() and
37-
value = source.getNode().(SourceString).getContents() and
38-
element = sink.getNode().toString()
74+
(
75+
exists(DataFlow::PathNode source, DataFlow::PathNode sink, ImportConfiguration config |
76+
config.hasFlowPath(source, sink) and
77+
correct_version(sink.getNode()) and
78+
tag = "prints" and
79+
location = sink.getNode().getLocation() and
80+
value = source.getNode().(SourceString).getContents() and
81+
element = sink.getNode().toString()
82+
)
83+
or
84+
exists(ModuleRef ref |
85+
correct_version(ref) and
86+
ref instanceof CheckArgument and
87+
tag = "prints" and
88+
location = ref.getLocation() and
89+
value = "\"<module " + ref.getName() + ">\"" and
90+
element = ref.toString()
91+
)
3992
)
4093
}
4194
}
95+
96+
private predicate correct_version(DataFlow::Node n) {
97+
not n instanceof VersionGuardedNode or
98+
n.(VersionGuardedNode).getVersion() = major_version()
99+
}

python/ql/test/experimental/import-resolution/main.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,15 @@
3939

4040
# A simple "import from" statement.
4141
from bar import bar_attr
42-
check("bar_attr", bar_attr, "bar_attr", globals()) #$ prints=bar_attr
42+
check("bar_attr", bar_attr, "bar_attr", globals()) #$ prints=bar_attr SPURIOUS: prints="<module bar>"
4343

4444
# Importing an attribute from a subpackage of a package.
4545
from package.subpackage import subpackage_attr
46-
check("subpackage_attr", subpackage_attr, "subpackage_attr", globals()) #$ prints=subpackage_attr
46+
check("subpackage_attr", subpackage_attr, "subpackage_attr", globals()) #$ prints=subpackage_attr SPURIOUS: prints="<module package.subpackage.__init__>"
4747

4848
# Importing a package attribute under an alias.
4949
from package import package_attr as package_attr_alias
50-
check("package_attr_alias", package_attr_alias, "package_attr", globals()) #$ prints=package_attr
50+
check("package_attr_alias", package_attr_alias, "package_attr", globals()) #$ prints=package_attr SPURIOUS: prints="<module package.__init__>"
5151

5252
# Importing a subpackage under an alias.
5353
from package import subpackage as aliased_subpackage #$ imports=package.subpackage.__init__ as=aliased_subpackage
@@ -68,15 +68,15 @@ def local_import():
6868
import package.subpackage #$ imports=package.__init__ as=package
6969
check("package.package_attr", package.package_attr, "package_attr", globals()) #$ prints=package_attr
7070

71-
if sys.version_info[0] >= 3:
71+
if sys.version_info[0] == 3:
7272
# Importing from a namespace module.
7373
from namespace_package.namespace_module import namespace_module_attr
74-
check("namespace_module_attr", namespace_module_attr, "namespace_module_attr", globals()) #$ prints=namespace_module_attr
74+
check("namespace_module_attr", namespace_module_attr, "namespace_module_attr", globals()) #$ prints=namespace_module_attr SPURIOUS: prints="<module namespace_package.namespace_module>"
7575

7676

7777
from attr_clash import clashing_attr, non_clashing_submodule #$ imports=attr_clash.clashing_attr as=clashing_attr imports=attr_clash.non_clashing_submodule as=non_clashing_submodule
78-
check("clashing_attr", clashing_attr, "clashing_attr", globals()) #$ prints=clashing_attr
79-
check("non_clashing_submodule", non_clashing_submodule, "<module attr_clash.non_clashing_submodule>", globals())
78+
check("clashing_attr", clashing_attr, "clashing_attr", globals()) #$ prints=clashing_attr SPURIOUS: prints="<module attr_clash.clashing_attr>" SPURIOUS: prints="<module attr_clash.__init__>"
79+
check("non_clashing_submodule", non_clashing_submodule, "<module attr_clash.non_clashing_submodule>", globals()) #$ prints="<module attr_clash.non_clashing_submodule>" SPURIOUS: prints="<module attr_clash.__init__>"
8080

8181
exit(__file__)
8282

python/ql/test/experimental/import-resolution/package/subpackage/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
# Importing an attribute from the parent package.
77
from .. import attr_used_in_subpackage as imported_attr
8-
check("imported_attr", imported_attr, "attr_used_in_subpackage", globals()) #$ prints=attr_used_in_subpackage
8+
check("imported_attr", imported_attr, "attr_used_in_subpackage", globals()) #$ prints=attr_used_in_subpackage SPURIOUS: prints="<module package.__init__>"
99

1010
# Importing an irrelevant attribute from a sibling module binds the name to the module.
1111
from .submodule import irrelevant_attr

0 commit comments

Comments
 (0)