Skip to content

Commit 6e3105d

Browse files
authored
Merge pull request #1787 from nicoddemus/fix-rewrite-conftest
Rewrite asserts in test-modules loaded very early in the startup
2 parents d5be6cb + 6711b1d commit 6e3105d

File tree

3 files changed

+35
-22
lines changed

3 files changed

+35
-22
lines changed

_pytest/assertion/__init__.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,11 @@ def pytest_namespace():
3131
def register_assert_rewrite(*names):
3232
"""Register a module name to be rewritten on import.
3333
34-
This function will make sure that the module will get it's assert
35-
statements rewritten when it is imported. Thus you should make
36-
sure to call this before the module is actually imported, usually
37-
in your __init__.py if you are a plugin using a package.
34+
This function will make sure that this module or all modules inside
35+
the package will get their assert statements rewritten.
36+
Thus you should make sure to call this before the module is
37+
actually imported, usually in your __init__.py if you are a plugin
38+
using a package.
3839
"""
3940
for hook in sys.meta_path:
4041
if isinstance(hook, rewrite.AssertionRewritingHook):

_pytest/assertion/rewrite.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import struct
1212
import sys
1313
import types
14+
from fnmatch import fnmatch
1415

1516
import py
1617
from _pytest.assertion import util
@@ -144,28 +145,29 @@ def _should_rewrite(self, name, fn_pypath, state):
144145
if fn_pypath.basename == 'conftest.py':
145146
state.trace("rewriting conftest file: %r" % (fn,))
146147
return True
147-
elif self.session is not None:
148+
149+
if self.session is not None:
148150
if self.session.isinitpath(fn):
149151
state.trace("matched test file (was specified on cmdline): %r" %
150152
(fn,))
151153
return True
152-
else:
153-
# modules not passed explicitly on the command line are only
154-
# rewritten if they match the naming convention for test files
155-
session = self.session # avoid a cycle here
156-
self.session = None
157-
try:
158-
for pat in self.fnpats:
159-
if fn_pypath.fnmatch(pat):
160-
state.trace("matched test file %r" % (fn,))
161-
return True
162-
finally:
163-
self.session = session
164-
del session
165-
else:
166-
for marked in self._must_rewrite:
167-
if marked.startswith(name):
168-
return True
154+
155+
# modules not passed explicitly on the command line are only
156+
# rewritten if they match the naming convention for test files
157+
for pat in self.fnpats:
158+
# use fnmatch instead of fn_pypath.fnmatch because the
159+
# latter might trigger an import to fnmatch.fnmatch
160+
# internally, which would cause this method to be
161+
# called recursively
162+
if fnmatch(fn_pypath.basename, pat):
163+
state.trace("matched test file %r" % (fn,))
164+
return True
165+
166+
for marked in self._must_rewrite:
167+
if name.startswith(marked):
168+
state.trace("matched marked file %r (from %r)" % (name, marked))
169+
return True
170+
169171
return False
170172

171173
def mark_rewrite(self, *names):

testing/test_assertrewrite.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,16 @@ def mywarn(code, msg):
533533
hook.mark_rewrite('_pytest')
534534
assert '_pytest' in warnings[0][1]
535535

536+
def test_rewrite_module_imported_from_conftest(self, testdir):
537+
testdir.makeconftest('''
538+
import test_rewrite_module_imported
539+
''')
540+
testdir.makepyfile(test_rewrite_module_imported='''
541+
def test_rewritten():
542+
assert "@py_builtins" in globals()
543+
''')
544+
assert testdir.runpytest_subprocess().ret == 0
545+
536546

537547
class TestAssertionRewriteHookDetails(object):
538548
def test_loader_is_package_false_for_module(self, testdir):

0 commit comments

Comments
 (0)