Skip to content

Commit 743f59a

Browse files
committed
Introduce pytest.register_assert_rewrite()
Plugins can now explicitly mark modules to be re-written. By default only the modules containing the plugin entrypoint are re-written.
1 parent 944da5b commit 743f59a

File tree

4 files changed

+127
-21
lines changed

4 files changed

+127
-21
lines changed

_pytest/assertion/__init__.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from _pytest.monkeypatch import monkeypatch
99
from _pytest.assertion import util
10+
from _pytest.assertion import rewrite
1011

1112

1213
def pytest_addoption(parser):
@@ -26,6 +27,34 @@ def pytest_addoption(parser):
2627
provide assert expression information. """)
2728

2829

30+
def pytest_namespace():
31+
return {'register_assert_rewrite': register_assert_rewrite}
32+
33+
34+
def register_assert_rewrite(*names):
35+
"""Register a module name to be rewritten on import.
36+
37+
This function will make sure that the module will get it's assert
38+
statements rewritten when it is imported. Thus you should make
39+
sure to call this before the module is actually imported, usually
40+
in your __init__.py if you are a plugin using a package.
41+
"""
42+
for hook in sys.meta_path:
43+
if isinstance(hook, rewrite.AssertionRewritingHook):
44+
importhook = hook
45+
break
46+
else:
47+
importhook = DummyRewriteHook()
48+
importhook.mark_rewrite(*names)
49+
50+
51+
class DummyRewriteHook(object):
52+
"""A no-op import hook for when rewriting is disabled."""
53+
54+
def mark_rewrite(self, *names):
55+
pass
56+
57+
2958
class AssertionState:
3059
"""State for the assertion plugin."""
3160

_pytest/assertion/rewrite.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,9 @@ def _should_rewrite(self, name, fn_pypath, state):
163163
self.session = session
164164
del session
165165
else:
166-
toplevel_name = name.split('.', 1)[0]
167-
if toplevel_name in self._must_rewrite:
168-
return True
166+
for marked in self._must_rewrite:
167+
if marked.startswith(name):
168+
return True
169169
return False
170170

171171
def mark_rewrite(self, *names):

_pytest/config.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import sys, os
1212
import _pytest._code
1313
import _pytest.hookspec # the extension point definitions
14+
import _pytest.assertion
1415
from _pytest._pluggy import PluginManager, HookimplMarker, HookspecMarker
1516

1617
hookimpl = HookimplMarker("pytest")
@@ -154,6 +155,9 @@ def __init__(self):
154155
self.trace.root.setwriter(err.write)
155156
self.enable_tracing()
156157

158+
# Config._consider_importhook will set a real object if required.
159+
self.rewrite_hook = _pytest.assertion.DummyRewriteHook()
160+
157161
def addhooks(self, module_or_class):
158162
"""
159163
.. deprecated:: 2.8
@@ -362,7 +366,9 @@ def consider_env(self):
362366
self._import_plugin_specs(os.environ.get("PYTEST_PLUGINS"))
363367

364368
def consider_module(self, mod):
365-
self._import_plugin_specs(getattr(mod, "pytest_plugins", None))
369+
plugins = getattr(mod, 'pytest_plugins', [])
370+
self.rewrite_hook.mark_rewrite(*plugins)
371+
self._import_plugin_specs(plugins)
366372

367373
def _import_plugin_specs(self, spec):
368374
if spec:
@@ -926,15 +932,13 @@ def _consider_importhook(self, args, entrypoint_name):
926932
and find all the installed plugins to mark them for re-writing
927933
by the importhook.
928934
"""
929-
import _pytest.assertion
930935
ns, unknown_args = self._parser.parse_known_and_unknown_args(args)
931936
mode = ns.assertmode
932-
if ns.noassert or ns.nomagic:
933-
mode = "plain"
934937
self._warn_about_missing_assertion(mode)
935938
if mode != 'plain':
936939
hook = _pytest.assertion.install_importhook(self, mode)
937940
if hook:
941+
self.pluginmanager.rewrite_hook = hook
938942
for entrypoint in pkg_resources.iter_entry_points('pytest11'):
939943
for entry in entrypoint.dist._get_metadata('RECORD'):
940944
fn = entry.split(',')[0]

testing/test_assertion.py

Lines changed: 87 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -63,22 +63,53 @@ def test(check_first):
6363
assert 0
6464
result.stdout.fnmatch_lines([expected])
6565

66+
@pytest.mark.parametrize('mode', ['plain', 'rewrite', 'reinterp'])
67+
def test_pytest_plugins_rewrite(self, testdir, mode):
68+
contents = {
69+
'conftest.py': """
70+
pytest_plugins = ['ham']
71+
""",
72+
'ham.py': """
73+
import pytest
74+
@pytest.fixture
75+
def check_first():
76+
def check(values, value):
77+
assert values.pop(0) == value
78+
return check
79+
""",
80+
'test_foo.py': """
81+
def test_foo(check_first):
82+
check_first([10, 30], 30)
83+
""",
84+
}
85+
testdir.makepyfile(**contents)
86+
result = testdir.runpytest_subprocess('--assert=%s' % mode)
87+
if mode == 'plain':
88+
expected = 'E AssertionError'
89+
elif mode == 'rewrite':
90+
expected = '*assert 10 == 30*'
91+
elif mode == 'reinterp':
92+
expected = '*AssertionError:*was re-run*'
93+
else:
94+
assert 0
95+
result.stdout.fnmatch_lines([expected])
96+
6697
@pytest.mark.parametrize('mode', ['plain', 'rewrite', 'reinterp'])
6798
def test_installed_plugin_rewrite(self, testdir, mode):
6899
# Make sure the hook is installed early enough so that plugins
69100
# installed via setuptools are re-written.
70-
ham = testdir.tmpdir.join('hampkg').ensure(dir=1)
71-
ham.join('__init__.py').write("""
72-
import pytest
101+
testdir.tmpdir.join('hampkg').ensure(dir=1)
102+
contents = {
103+
'hampkg/__init__.py': """
104+
import pytest
73105
74-
@pytest.fixture
75-
def check_first2():
76-
def check(values, value):
77-
assert values.pop(0) == value
78-
return check
79-
""")
80-
testdir.makepyfile(
81-
spamplugin="""
106+
@pytest.fixture
107+
def check_first2():
108+
def check(values, value):
109+
assert values.pop(0) == value
110+
return check
111+
""",
112+
'spamplugin.py': """
82113
import pytest
83114
from hampkg import check_first2
84115
@@ -88,7 +119,7 @@ def check(values, value):
88119
assert values.pop(0) == value
89120
return check
90121
""",
91-
mainwrapper="""
122+
'mainwrapper.py': """
92123
import pytest, pkg_resources
93124
94125
class DummyDistInfo:
@@ -116,14 +147,15 @@ def iter_entry_points(name):
116147
pkg_resources.iter_entry_points = iter_entry_points
117148
pytest.main()
118149
""",
119-
test_foo="""
150+
'test_foo.py': """
120151
def test(check_first):
121152
check_first([10, 30], 30)
122153
123154
def test2(check_first2):
124155
check_first([10, 30], 30)
125156
""",
126-
)
157+
}
158+
testdir.makepyfile(**contents)
127159
result = testdir.run(sys.executable, 'mainwrapper.py', '-s', '--assert=%s' % mode)
128160
if mode == 'plain':
129161
expected = 'E AssertionError'
@@ -135,6 +167,47 @@ def test2(check_first2):
135167
assert 0
136168
result.stdout.fnmatch_lines([expected])
137169

170+
def test_rewrite_ast(self, testdir):
171+
testdir.tmpdir.join('pkg').ensure(dir=1)
172+
contents = {
173+
'pkg/__init__.py': """
174+
import pytest
175+
pytest.register_assert_rewrite('pkg.helper')
176+
""",
177+
'pkg/helper.py': """
178+
def tool():
179+
a, b = 2, 3
180+
assert a == b
181+
""",
182+
'pkg/plugin.py': """
183+
import pytest, pkg.helper
184+
@pytest.fixture
185+
def tool():
186+
return pkg.helper.tool
187+
""",
188+
'pkg/other.py': """
189+
l = [3, 2]
190+
def tool():
191+
assert l.pop() == 3
192+
""",
193+
'conftest.py': """
194+
pytest_plugins = ['pkg.plugin']
195+
""",
196+
'test_pkg.py': """
197+
import pkg.other
198+
def test_tool(tool):
199+
tool()
200+
def test_other():
201+
pkg.other.tool()
202+
""",
203+
}
204+
testdir.makepyfile(**contents)
205+
result = testdir.runpytest_subprocess('--assert=rewrite')
206+
result.stdout.fnmatch_lines(['>*assert a == b*',
207+
'E*assert 2 == 3*',
208+
'>*assert l.pop() == 3*',
209+
'E*AssertionError*re-run*'])
210+
138211

139212
class TestBinReprIntegration:
140213

0 commit comments

Comments
 (0)