Skip to content

Commit 68ac4a1

Browse files
committed
collect_imported_tests option
1 parent 222457d commit 68ac4a1

File tree

4 files changed

+133
-51
lines changed

4 files changed

+133
-51
lines changed

changelog/12749.feature.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
Add :confval:`discover_imports`, when disabled (default) will make sure to not consider classes which are imported by a test file and starts with Test.
1+
Add :confval:`collect_imported_tests`, when enabled (default is disabled) will make sure to not consider classes/functions which are imported by a test file and contains Test/test_*/*_test.
22
3-
-- by :user:`FreerGit`
3+
-- by :user:`FreerGit`

src/_pytest/main.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,9 @@ def pytest_addoption(parser: Parser) -> None:
7979
default=[],
8080
)
8181
parser.addini(
82-
"discover_imports",
83-
"Whether to discover tests in imported modules outside `testpaths`",
82+
"collect_imported_tests",
83+
"Whether to collect tests in imported modules outside `testpaths`",
84+
type="bool",
8485
default=False,
8586
)
8687
group = parser.getgroup("general", "Running and selection options")
@@ -963,16 +964,41 @@ def collect(self) -> Iterator[nodes.Item | nodes.Collector]:
963964
self.trace.root.indent -= 1
964965

965966
def genitems(self, node: nodes.Item | nodes.Collector) -> Iterator[nodes.Item]:
967+
import inspect
968+
969+
from _pytest.python import Class
970+
from _pytest.python import Function
971+
from _pytest.python import Module
972+
966973
self.trace("genitems", node)
967974
if isinstance(node, nodes.Item):
968975
node.ihook.pytest_itemcollected(item=node)
976+
if self.config.getini("collect_imported_tests"):
977+
if isinstance(node.parent, Module) and isinstance(node, Function):
978+
if inspect.isfunction(node._getobj()):
979+
fn_defined_at = node._getobj().__module__
980+
in_module = node.parent._getobj().__name__
981+
if fn_defined_at != in_module:
982+
return
969983
yield node
970984
else:
971985
assert isinstance(node, nodes.Collector)
972986
keepduplicates = self.config.getoption("keepduplicates")
973987
# For backward compat, dedup only applies to files.
974988
handle_dupes = not (keepduplicates and isinstance(node, nodes.File))
975989
rep, duplicate = self._collect_one_node(node, handle_dupes)
990+
991+
if self.config.getini("collect_imported_tests"):
992+
for subnode in rep.result:
993+
if isinstance(subnode, Class) and isinstance(
994+
subnode.parent, Module
995+
):
996+
if inspect.isclass(subnode._getobj()):
997+
class_defined_at = subnode._getobj().__module__
998+
in_module = subnode.parent._getobj().__name__
999+
if class_defined_at != in_module:
1000+
rep.result.remove(subnode)
1001+
9761002
if duplicate and not keepduplicates:
9771003
return
9781004
if rep.passed:

src/_pytest/python.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -741,12 +741,6 @@ def newinstance(self):
741741
return self.obj()
742742

743743
def collect(self) -> Iterable[nodes.Item | nodes.Collector]:
744-
if self.config.getini("discover_imports") == ("false" or False):
745-
paths = self.config.getini("testpaths")
746-
class_file = inspect.getfile(self.obj)
747-
if not any(string in class_file for string in paths):
748-
return []
749-
750744
if not safe_getattr(self.obj, "__test__", True):
751745
return []
752746
if hasinit(self.obj):

testing/test_discover_imports.py

Lines changed: 103 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,31 @@
1-
import pytest
1+
from __future__ import annotations
2+
23
import textwrap
34

4-
def test_discover_imports_enabled(pytester):
5+
from _pytest.pytester import Pytester
6+
7+
8+
def run_import_class_test(pytester: Pytester, passed: int = 0, errors: int = 0) -> None:
59
src_dir = pytester.mkdir("src")
610
tests_dir = pytester.mkdir("tests")
7-
pytester.makeini("""
8-
[pytest]
9-
testpaths = "tests"
10-
discover_imports = true
11-
""")
12-
1311
src_file = src_dir / "foo.py"
1412

15-
src_file.write_text(textwrap.dedent("""\
16-
class TestClass(object):
13+
src_file.write_text(
14+
textwrap.dedent("""\
15+
class Testament(object):
1716
def __init__(self):
1817
super().__init__()
18+
self.collections = ["stamp", "coin"]
1919
20-
def test_foobar(self):
21-
return true
22-
"""
23-
), encoding="utf-8")
20+
def personal_property(self):
21+
return [f"my {x} collection" for x in self.collections]
22+
"""),
23+
encoding="utf-8",
24+
)
2425

2526
test_file = tests_dir / "foo_test.py"
26-
test_file.write_text(textwrap.dedent("""\
27+
test_file.write_text(
28+
textwrap.dedent("""\
2729
import sys
2830
import os
2931
@@ -32,42 +34,78 @@ def test_foobar(self):
3234
parent_dir = os.path.abspath(os.path.join(current_dir, '..'))
3335
sys.path.append(parent_dir)
3436
35-
from src.foo import TestClass
37+
from src.foo import Testament
3638
3739
class TestDomain:
3840
def test_testament(self):
39-
testament = TestClass()
40-
pass
41-
"""), encoding="utf-8")
41+
testament = Testament()
42+
assert testament.personal_property()
43+
"""),
44+
encoding="utf-8",
45+
)
4246

4347
result = pytester.runpytest()
44-
result.assert_outcomes(errors=1)
48+
result.assert_outcomes(passed=passed, errors=errors)
4549

46-
def test_discover_imports_disabled(pytester):
47-
48-
src_dir = pytester.mkdir("src")
49-
tests_dir = pytester.mkdir("tests")
50+
51+
def test_collect_imports_disabled(pytester: Pytester) -> None:
52+
pytester.makeini("""
53+
[pytest]
54+
testpaths = "tests"
55+
collect_imported_tests = false
56+
""")
57+
58+
run_import_class_test(pytester, errors=1)
59+
60+
61+
def test_collect_imports_default(pytester: Pytester) -> None:
62+
pytester.makeini("""
63+
[pytest]
64+
testpaths = "tests"
65+
""")
66+
67+
run_import_class_test(pytester, errors=1)
68+
69+
70+
def test_collect_imports_enabled(pytester: Pytester) -> None:
5071
pytester.makeini("""
5172
[pytest]
5273
testpaths = "tests"
53-
discover_imports = false
74+
collect_imported_tests = true
5475
""")
5576

77+
run_import_class_test(pytester, passed=1)
78+
79+
80+
def run_import_functions_test(
81+
pytester: Pytester, passed: int, errors: int, failed: int
82+
) -> None:
83+
src_dir = pytester.mkdir("src")
84+
tests_dir = pytester.mkdir("tests")
85+
5686
src_file = src_dir / "foo.py"
5787

58-
src_file.write_text(textwrap.dedent("""\
59-
class Testament(object):
60-
def __init__(self):
61-
super().__init__()
62-
self.collections = ["stamp", "coin"]
88+
# Note that these "tests" are should _not_ be treated as tests.
89+
# They are normal functions that happens to have test_* or *_test in the name.
90+
# Thus should _not_ be collected!
91+
src_file.write_text(
92+
textwrap.dedent("""\
93+
def test_function():
94+
some_random_computation = 5
95+
return some_random_computation
6396
64-
def personal_property(self):
65-
return [f"my {x} collection" for x in self.collections]
66-
"""
67-
), encoding="utf-8")
97+
def test_bar():
98+
pass
99+
"""),
100+
encoding="utf-8",
101+
)
68102

69103
test_file = tests_dir / "foo_test.py"
70-
test_file.write_text(textwrap.dedent("""\
104+
105+
# Inferred from the comment above, this means that there is _only_ one actual test
106+
# which should result in only 1 passing test being ran.
107+
test_file.write_text(
108+
textwrap.dedent("""\
71109
import sys
72110
import os
73111
@@ -76,13 +114,37 @@ def personal_property(self):
76114
parent_dir = os.path.abspath(os.path.join(current_dir, '..'))
77115
sys.path.append(parent_dir)
78116
79-
from src.foo import Testament
117+
from src.foo import *
80118
81119
class TestDomain:
82-
def test_testament(self):
83-
testament = Testament()
84-
assert testament.personal_property()
85-
"""), encoding="utf-8")
120+
def test_important(self):
121+
res = test_function()
122+
if res == 5:
123+
pass
124+
125+
"""),
126+
encoding="utf-8",
127+
)
86128

87129
result = pytester.runpytest()
88-
result.assert_outcomes(passed=1)
130+
result.assert_outcomes(passed=passed, errors=errors, failed=failed)
131+
132+
133+
def test_collect_function_imports_enabled(pytester: Pytester) -> None:
134+
pytester.makeini("""
135+
[pytest]
136+
testpaths = "tests"
137+
collect_imported_tests = true
138+
""")
139+
140+
run_import_functions_test(pytester, passed=1, errors=0, failed=0)
141+
142+
143+
def test_collect_function_imports_disabled(pytester: Pytester) -> None:
144+
pytester.makeini("""
145+
[pytest]
146+
testpaths = "tests"
147+
collect_imported_tests = false
148+
""")
149+
150+
run_import_functions_test(pytester, passed=2, errors=0, failed=1)

0 commit comments

Comments
 (0)