Skip to content

Commit 3a81d2e

Browse files
committed
conftest files now use assertion rewriting
Fix #1619
1 parent 54872e9 commit 3a81d2e

File tree

4 files changed

+112
-40
lines changed

4 files changed

+112
-40
lines changed

_pytest/assertion/__init__.py

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import py
55
import os
66
import sys
7+
8+
from _pytest.config import hookimpl
79
from _pytest.monkeypatch import monkeypatch
810
from _pytest.assertion import util
911

@@ -42,9 +44,13 @@ def __init__(self, config, mode):
4244
self.trace = config.trace.root.get("assertion")
4345

4446

45-
def pytest_configure(config):
46-
mode = config.getvalue("assertmode")
47-
if config.getvalue("noassert") or config.getvalue("nomagic"):
47+
@hookimpl(tryfirst=True)
48+
def pytest_load_initial_conftests(early_config, parser, args):
49+
ns, ns_unknown_args = parser.parse_known_and_unknown_args(args)
50+
mode = ns.assertmode
51+
no_assert = ns.noassert
52+
no_magic = ns.nomagic
53+
if no_assert or no_magic:
4854
mode = "plain"
4955
if mode == "rewrite":
5056
try:
@@ -57,25 +63,30 @@ def pytest_configure(config):
5763
if (sys.platform.startswith('java') or
5864
sys.version_info[:3] == (2, 6, 0)):
5965
mode = "reinterp"
66+
67+
early_config._assertstate = AssertionState(early_config, mode)
68+
warn_about_missing_assertion(mode, early_config.pluginmanager)
69+
6070
if mode != "plain":
6171
_load_modules(mode)
6272
m = monkeypatch()
63-
config._cleanup.append(m.undo)
73+
early_config._cleanup.append(m.undo)
6474
m.setattr(py.builtin.builtins, 'AssertionError',
6575
reinterpret.AssertionError) # noqa
76+
6677
hook = None
6778
if mode == "rewrite":
6879
hook = rewrite.AssertionRewritingHook() # noqa
80+
hook.set_config(early_config)
6981
sys.meta_path.insert(0, hook)
70-
warn_about_missing_assertion(mode)
71-
config._assertstate = AssertionState(config, mode)
72-
config._assertstate.hook = hook
73-
config._assertstate.trace("configured with mode set to %r" % (mode,))
82+
83+
early_config._assertstate.hook = hook
84+
early_config._assertstate.trace("configured with mode set to %r" % (mode,))
7485
def undo():
75-
hook = config._assertstate.hook
86+
hook = early_config._assertstate.hook
7687
if hook is not None and hook in sys.meta_path:
7788
sys.meta_path.remove(hook)
78-
config.add_cleanup(undo)
89+
early_config.add_cleanup(undo)
7990

8091

8192
def pytest_collection(session):
@@ -154,7 +165,8 @@ def _load_modules(mode):
154165
from _pytest.assertion import rewrite # noqa
155166

156167

157-
def warn_about_missing_assertion(mode):
168+
def warn_about_missing_assertion(mode, pluginmanager):
169+
print('got here')
158170
try:
159171
assert False
160172
except AssertionError:
@@ -166,10 +178,18 @@ def warn_about_missing_assertion(mode):
166178
else:
167179
specifically = "failing tests may report as passing"
168180

169-
sys.stderr.write("WARNING: " + specifically +
170-
" because assert statements are not executed "
171-
"by the underlying Python interpreter "
172-
"(are you using python -O?)\n")
181+
# temporarily disable capture so we can print our warning
182+
capman = pluginmanager.getplugin('capturemanager')
183+
try:
184+
out, err = capman.suspendcapture()
185+
sys.stderr.write("WARNING: " + specifically +
186+
" because assert statements are not executed "
187+
"by the underlying Python interpreter "
188+
"(are you using python -O?)\n")
189+
finally:
190+
capman.resumecapture()
191+
sys.stdout.write(out)
192+
sys.stderr.write(err)
173193

174194

175195
# Expose this plugin's implementation for the pytest_assertrepr_compare hook

_pytest/assertion/rewrite.py

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,14 @@ def __init__(self):
5050
self._register_with_pkg_resources()
5151

5252
def set_session(self, session):
53-
self.fnpats = session.config.getini("python_files")
5453
self.session = session
5554

55+
def set_config(self, config):
56+
self.config = config
57+
self.fnpats = config.getini("python_files")
58+
5659
def find_module(self, name, path=None):
57-
if self.session is None:
58-
return None
59-
sess = self.session
60-
state = sess.config._assertstate
60+
state = self.config._assertstate
6161
state.trace("find_module called for: %s" % name)
6262
names = name.rsplit(".", 1)
6363
lastname = names[-1]
@@ -86,24 +86,11 @@ def find_module(self, name, path=None):
8686
return None
8787
else:
8888
fn = os.path.join(pth, name.rpartition(".")[2] + ".py")
89+
8990
fn_pypath = py.path.local(fn)
90-
# Is this a test file?
91-
if not sess.isinitpath(fn):
92-
# We have to be very careful here because imports in this code can
93-
# trigger a cycle.
94-
self.session = None
95-
try:
96-
for pat in self.fnpats:
97-
if fn_pypath.fnmatch(pat):
98-
state.trace("matched test file %r" % (fn,))
99-
break
100-
else:
101-
return None
102-
finally:
103-
self.session = sess
104-
else:
105-
state.trace("matched test file (was specified on cmdline): %r" %
106-
(fn,))
91+
if not self._should_rewrite(fn_pypath, state):
92+
return
93+
10794
# The requested module looks like a test file, so rewrite it. This is
10895
# the most magical part of the process: load the source, rewrite the
10996
# asserts, and load the rewritten source. We also cache the rewritten
@@ -151,6 +138,32 @@ def find_module(self, name, path=None):
151138
self.modules[name] = co, pyc
152139
return self
153140

141+
def _should_rewrite(self, fn_pypath, state):
142+
# always rewrite conftest files
143+
fn = str(fn_pypath)
144+
if fn_pypath.basename == 'conftest.py':
145+
state.trace("rewriting conftest file: %r" % (fn,))
146+
return True
147+
elif self.session is not None:
148+
if self.session.isinitpath(fn):
149+
state.trace("matched test file (was specified on cmdline): %r" %
150+
(fn,))
151+
return True
152+
else:
153+
# modules not passed explicitly on the command line are only
154+
# rewritten if they match the naming convention for test files
155+
session = self.session # avoid a cycle here
156+
self.session = None
157+
try:
158+
for pat in self.fnpats:
159+
if fn_pypath.fnmatch(pat):
160+
state.trace("matched test file %r" % (fn,))
161+
return True
162+
finally:
163+
self.session = session
164+
del session
165+
return False
166+
154167
def load_module(self, name):
155168
# If there is an existing module object named 'fullname' in
156169
# sys.modules, the loader must use that existing module. (Otherwise,

testing/test_assertrewrite.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -694,6 +694,40 @@ def test_foo(self):
694694
result = testdir.runpytest()
695695
result.stdout.fnmatch_lines('*1 passed*')
696696

697+
@pytest.mark.parametrize('initial_conftest', [True, False])
698+
@pytest.mark.parametrize('mode', ['plain', 'rewrite', 'reinterp'])
699+
def test_conftest_assertion_rewrite(self, testdir, initial_conftest, mode):
700+
"""Test that conftest files are using assertion rewrite on import.
701+
(#1619)
702+
"""
703+
testdir.tmpdir.join('foo/tests').ensure(dir=1)
704+
conftest_path = 'conftest.py' if initial_conftest else 'foo/conftest.py'
705+
contents = {
706+
conftest_path: """
707+
import pytest
708+
@pytest.fixture
709+
def check_first():
710+
def check(values, value):
711+
assert values.pop(0) == value
712+
return check
713+
""",
714+
'foo/tests/test_foo.py': """
715+
def test(check_first):
716+
check_first([10, 30], 30)
717+
"""
718+
}
719+
testdir.makepyfile(**contents)
720+
result = testdir.runpytest_subprocess('--assert=%s' % mode)
721+
if mode == 'plain':
722+
expected = 'E AssertionError'
723+
elif mode == 'rewrite':
724+
expected = '*assert 10 == 30*'
725+
elif mode == 'reinterp':
726+
expected = '*AssertionError:*was re-run*'
727+
else:
728+
assert 0
729+
result.stdout.fnmatch_lines([expected])
730+
697731

698732
def test_issue731(testdir):
699733
testdir.makepyfile("""

testing/test_config.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -485,9 +485,14 @@ def pytest_load_initial_conftests(self):
485485
pm.register(m)
486486
hc = pm.hook.pytest_load_initial_conftests
487487
l = hc._nonwrappers + hc._wrappers
488-
assert l[-1].function.__module__ == "_pytest.capture"
489-
assert l[-2].function == m.pytest_load_initial_conftests
490-
assert l[-3].function.__module__ == "_pytest.config"
488+
expected = [
489+
"_pytest.config",
490+
'test_config',
491+
'_pytest.assertion',
492+
'_pytest.capture',
493+
]
494+
assert [x.function.__module__ for x in l] == expected
495+
491496

492497
class TestWarning:
493498
def test_warn_config(self, testdir):

0 commit comments

Comments
 (0)