Skip to content

Commit 935c06d

Browse files
committed
WIP: don't collect instead of filtering out
1 parent eb8592c commit 935c06d

File tree

4 files changed

+115
-53
lines changed

4 files changed

+115
-53
lines changed

doc/en/reference/reference.rst

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1301,6 +1301,20 @@ passed multiple times. The expected format is ``name=value``. For example::
13011301
variables, that will be expanded. For more information about cache plugin
13021302
please refer to :ref:`cache_provider`.
13031303

1304+
.. confval:: collect_imported_tests
1305+
1306+
.. versionadded:: 8.4
1307+
1308+
Setting this to ``false`` will make pytest collect classes/functions from test
1309+
files only if they are defined in that file (as opposed to imported there).
1310+
1311+
.. code-block:: ini
1312+
1313+
[pytest]
1314+
collect_imported_tests = false
1315+
1316+
Default: ``true``
1317+
13041318
.. confval:: consider_namespace_packages
13051319

13061320
Controls if pytest should attempt to identify `namespace packages <https://packaging.python.org/en/latest/guides/packaging-namespace-packages>`__
@@ -1838,17 +1852,6 @@ passed multiple times. The expected format is ``name=value``. For example::
18381852
18391853
pytest testing doc
18401854
1841-
1842-
.. confval:: collect_imported_tests
1843-
1844-
Setting this to `false` will make pytest collect classes/functions from test
1845-
files only if they are defined in that file (as opposed to imported there).
1846-
1847-
.. code-block:: ini
1848-
1849-
[pytest]
1850-
collect_imported_tests = false
1851-
18521855
.. confval:: tmp_path_retention_count
18531856

18541857
How many sessions should we keep the `tmp_path` directories,

src/_pytest/main.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -964,41 +964,16 @@ def collect(self) -> Iterator[nodes.Item | nodes.Collector]:
964964
self.trace.root.indent -= 1
965965

966966
def genitems(self, node: nodes.Item | nodes.Collector) -> Iterator[nodes.Item]:
967-
import inspect
968-
969-
from _pytest.python import Class
970-
from _pytest.python import Function
971-
from _pytest.python import Module
972-
973967
self.trace("genitems", node)
974968
if isinstance(node, nodes.Item):
975969
node.ihook.pytest_itemcollected(item=node)
976-
if not self.config.getini("collect_imported_tests"):
977-
if isinstance(node.parent, Module) and isinstance(node, Function):
978-
if inspect.isfunction(node._getobj()):
979-
fn_defined_at = node._getobj().__module__
980-
in_module = node.parent._getobj().__name__
981-
if fn_defined_at != in_module:
982-
return
983970
yield node
984971
else:
985972
assert isinstance(node, nodes.Collector)
986973
keepduplicates = self.config.getoption("keepduplicates")
987974
# For backward compat, dedup only applies to files.
988975
handle_dupes = not (keepduplicates and isinstance(node, nodes.File))
989976
rep, duplicate = self._collect_one_node(node, handle_dupes)
990-
991-
if not self.config.getini("collect_imported_tests"):
992-
for subnode in rep.result:
993-
if isinstance(subnode, Class) and isinstance(
994-
subnode.parent, Module
995-
):
996-
if inspect.isclass(subnode._getobj()):
997-
class_defined_at = subnode._getobj().__module__
998-
in_module = subnode.parent._getobj().__name__
999-
if class_defined_at != in_module:
1000-
rep.result.remove(subnode)
1001-
1002977
if duplicate and not keepduplicates:
1003978
return
1004979
if rep.passed:

src/_pytest/python.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,15 @@ def collect(self) -> Iterable[nodes.Item | nodes.Collector]:
416416
if name in seen:
417417
continue
418418
seen.add(name)
419+
420+
if not self.session.config.getini("collect_imported_tests"):
421+
# Do not collect imported functions
422+
if inspect.isfunction(obj) and isinstance(self, Module):
423+
fn_defined_at = obj.__module__
424+
in_module = self._getobj().__name__
425+
if fn_defined_at != in_module:
426+
continue
427+
419428
res = ihook.pytest_pycollect_makeitem(
420429
collector=self, name=name, obj=obj
421430
)
@@ -741,6 +750,16 @@ def newinstance(self):
741750
return self.obj()
742751

743752
def collect(self) -> Iterable[nodes.Item | nodes.Collector]:
753+
if not self.config.getini("collect_imported_tests"):
754+
# This entire branch will discard (not collect) a class
755+
# if it is imported (defined in a different module)
756+
if isinstance(self, Class) and isinstance(self.parent, Module):
757+
if inspect.isclass(self._getobj()):
758+
class_defined_at = self._getobj().__module__
759+
in_module = self.parent._getobj().__name__
760+
if class_defined_at != in_module:
761+
return []
762+
744763
if not safe_getattr(self.obj, "__test__", True):
745764
return []
746765
if hasinit(self.obj):

testing/test_collect_imports.py

Lines changed: 82 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
from _pytest.pytester import Pytester
66

77

8+
# Start of tests for classes
9+
10+
811
def run_import_class_test(pytester: Pytester, passed: int = 0, errors: int = 0) -> None:
912
src_dir = pytester.mkdir("src")
1013
tests_dir = pytester.mkdir("tests")
@@ -57,26 +60,37 @@ def test_collect_imports_disabled(pytester: Pytester) -> None:
5760

5861
run_import_class_test(pytester, passed=1)
5962

63+
# Verify that the state of hooks
64+
reprec = pytester.inline_run()
65+
items_collected = reprec.getcalls("pytest_itemcollected")
66+
assert len(items_collected) == 1
67+
for x in items_collected:
68+
assert x.item._getobj().__name__ == "test_testament"
6069

61-
def test_collect_imports_default(pytester: Pytester) -> None:
62-
pytester.makeini("""
63-
[pytest]
64-
testpaths = "tests"
65-
""")
6670

71+
def test_collect_imports_default(pytester: Pytester) -> None:
6772
run_import_class_test(pytester, errors=1)
6873

74+
# TODO, hooks
75+
6976

7077
def test_collect_imports_enabled(pytester: Pytester) -> None:
7178
pytester.makeini("""
7279
[pytest]
73-
testpaths = "tests"
7480
collect_imported_tests = true
7581
""")
7682

7783
run_import_class_test(pytester, errors=1)
7884

7985

86+
# # TODO, hooks
87+
88+
89+
# End of tests for classes
90+
#################################
91+
# Start of tests for functions
92+
93+
8094
def run_import_functions_test(
8195
pytester: Pytester, passed: int, errors: int, failed: int
8296
) -> None:
@@ -85,8 +99,8 @@ def run_import_functions_test(
8599

86100
src_file = src_dir / "foo.py"
87101

88-
# Note that these "tests" are should _not_ be treated as tests.
89-
# They are normal functions that happens to have test_* or *_test in the name.
102+
# Note that these "tests" should _not_ be treated as tests if `collect_imported_tests = false`
103+
# They are normal functions in that case, that happens to have test_* or *_test in the name.
90104
# Thus should _not_ be collected!
91105
src_file.write_text(
92106
textwrap.dedent("""\
@@ -121,7 +135,6 @@ def test_important(self):
121135
res = test_function()
122136
if res == 5:
123137
pass
124-
125138
"""),
126139
encoding="utf-8",
127140
)
@@ -138,31 +151,83 @@ def test_collect_function_imports_enabled(pytester: Pytester) -> None:
138151
""")
139152

140153
run_import_functions_test(pytester, passed=2, errors=0, failed=1)
154+
reprec = pytester.inline_run()
155+
items_collected = reprec.getcalls("pytest_itemcollected")
156+
# Recall that the default is `collect_imported_tests = true`.
157+
# Which means that the normal functions are now interpreted as
158+
# valid tests and `test_function()` will fail.
159+
assert len(items_collected) == 3
160+
for x in items_collected:
161+
assert x.item._getobj().__name__ in [
162+
"test_important",
163+
"test_bar",
164+
"test_function",
165+
]
141166

142167

143-
def test_collect_function_imports_disabled(pytester: Pytester) -> None:
168+
def test_behaviour_without_testpaths_set_and_false(pytester: Pytester) -> None:
169+
# Make sure `collect_imported_tests` has no dependence on `testpaths`
144170
pytester.makeini("""
145171
[pytest]
146-
# testpaths = "tests"
147172
collect_imported_tests = false
148173
""")
149174

150175
run_import_functions_test(pytester, passed=1, errors=0, failed=0)
176+
reprec = pytester.inline_run()
177+
items_collected = reprec.getcalls("pytest_itemcollected")
178+
assert len(items_collected) == 1
179+
for x in items_collected:
180+
assert x.item._getobj().__name__ == "test_important"
151181

152182

153-
def test_behaviour_without_testpaths_set_and_false(pytester: Pytester) -> None:
183+
def test_behaviour_without_testpaths_set_and_true(pytester: Pytester) -> None:
184+
# Make sure `collect_imported_tests` has no dependence on `testpaths`
154185
pytester.makeini("""
155186
[pytest]
156-
collect_imported_tests = false
187+
collect_imported_tests = true
157188
""")
158189

159-
run_import_functions_test(pytester, passed=1, errors=0, failed=0)
190+
run_import_functions_test(pytester, passed=2, errors=0, failed=1)
191+
reprec = pytester.inline_run()
192+
items_collected = reprec.getcalls("pytest_itemcollected")
193+
assert len(items_collected) == 3
160194

161195

162-
def test_behaviour_without_testpaths_set_and_true(pytester: Pytester) -> None:
196+
def test_hook_behaviour_when_collect_off(pytester: Pytester) -> None:
163197
pytester.makeini("""
164198
[pytest]
165-
collect_imported_tests = true
199+
collect_imported_tests = false
166200
""")
167201

168-
run_import_functions_test(pytester, passed=2, errors=0, failed=1)
202+
run_import_functions_test(pytester, passed=1, errors=0, failed=0)
203+
reprec = pytester.inline_run()
204+
205+
# reports = reprec.getreports("pytest_collectreport")
206+
items_collected = reprec.getcalls("pytest_itemcollected")
207+
modified = reprec.getcalls("pytest_collection_modifyitems")
208+
209+
# print("Reports: ----------------")
210+
# print(reports)
211+
# for r in reports:
212+
# print(r)
213+
214+
# TODO this is want I want, I think....
215+
# <CollectReport '' lenresult=1 outcome='passed'>
216+
# <CollectReport 'tests/foo_test.py::TestDomain' lenresult=1 outcome='passed'>
217+
# <CollectReport 'tests/foo_test.py' lenresult=1 outcome='passed'>
218+
# <CollectReport 'tests' lenresult=1 outcome='passed'>
219+
# <CollectReport '.' lenresult=1 outcome='passed'>
220+
221+
# TODO
222+
# assert(reports.outcome == "passed")
223+
# assert(len(reports.result) == 1)
224+
225+
# print("Items collected: ----------------")
226+
# print(items_collected)
227+
# print("Modified : ----------------")
228+
229+
assert len(items_collected) == 1
230+
for x in items_collected:
231+
assert x.item._getobj().__name__ == "test_important"
232+
233+
assert len(modified) == 1

0 commit comments

Comments
 (0)