Skip to content

Commit a2b4bb7

Browse files
author
Bas van Beek
committed
TST,ENH: Print the relevant expression whenever a test_fail or test_reveal test fails
1 parent 90a8d4a commit a2b4bb7

File tree

1 file changed

+28
-15
lines changed

1 file changed

+28
-15
lines changed

numpy/typing/tests/test_typing.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import pytest
1313
import numpy as np
14+
import numpy.typing as npt
1415
from numpy.typing.mypy_plugin import (
1516
_PRECISION_DICT,
1617
_EXTENDED_PRECISION_LIST,
@@ -150,9 +151,9 @@ def test_fail(path: str) -> None:
150151

151152
target_line = lines[lineno - 1]
152153
if "# E:" in target_line:
153-
marker = target_line.split("# E:")[-1].strip()
154-
expected_error = errors.get(lineno)
155-
_test_fail(path, marker, expected_error, lineno)
154+
expression, _, marker = target_line.partition(" # E: ")
155+
expected_error = errors[lineno]
156+
_test_fail(path, expression, marker.strip(), expected_error.strip(), lineno)
156157
else:
157158
pytest.fail(
158159
f"Unexpected mypy output at line {lineno}\n\n{errors[lineno]}"
@@ -161,26 +162,29 @@ def test_fail(path: str) -> None:
161162

162163
_FAIL_MSG1 = """Extra error at line {}
163164
165+
Expression: {}
164166
Extra error: {!r}
165167
"""
166168

167169
_FAIL_MSG2 = """Error mismatch at line {}
168170
171+
Expression: {}
169172
Expected error: {!r}
170173
Observed error: {!r}
171174
"""
172175

173176

174177
def _test_fail(
175178
path: str,
179+
expression: str,
176180
error: str,
177181
expected_error: None | str,
178182
lineno: int,
179183
) -> None:
180184
if expected_error is None:
181-
raise AssertionError(_FAIL_MSG1.format(lineno, error))
185+
raise AssertionError(_FAIL_MSG1.format(lineno, expression, error))
182186
elif error not in expected_error:
183-
raise AssertionError(_FAIL_MSG2.format(lineno, expected_error, error))
187+
raise AssertionError(_FAIL_MSG2.format(lineno, expression, expected_error, error))
184188

185189

186190
def _construct_format_dict() -> dict[str, str]:
@@ -259,7 +263,7 @@ def _construct_format_dict() -> dict[str, str]:
259263
FORMAT_DICT: dict[str, str] = _construct_format_dict()
260264

261265

262-
def _parse_reveals(file: IO[str]) -> list[str]:
266+
def _parse_reveals(file: IO[str]) -> tuple[npt.NDArray[np.str_], list[str]]:
263267
"""Extract and parse all ``" # E: "`` comments from the passed
264268
file-like object.
265269
@@ -269,8 +273,10 @@ def _parse_reveals(file: IO[str]) -> list[str]:
269273
"""
270274
string = file.read().replace("*", "")
271275

272-
# Grab all `# E:`-based comments
273-
comments_array = np.char.partition(string.split("\n"), sep=" # E: ")[:, 2]
276+
# Grab all `# E:`-based comments and matching expressions
277+
expression_array, _, comments_array = np.char.partition(
278+
string.split("\n"), sep=" # E: "
279+
).T
274280
comments = "/n".join(comments_array)
275281

276282
# Only search for the `{*}` pattern within comments, otherwise
@@ -282,7 +288,7 @@ def _parse_reveals(file: IO[str]) -> list[str]:
282288
}
283289
fmt_str = comments.format(**kwargs)
284290

285-
return fmt_str.split("/n")
291+
return expression_array, fmt_str.split("/n")
286292

287293

288294
@pytest.mark.slow
@@ -295,7 +301,7 @@ def test_reveal(path: str) -> None:
295301
__tracebackhide__ = True
296302

297303
with open(path) as fin:
298-
lines = _parse_reveals(fin)
304+
expression_array, reveal_list = _parse_reveals(fin)
299305

300306
output_mypy = OUTPUT_MYPY
301307
assert path in output_mypy
@@ -310,27 +316,30 @@ def test_reveal(path: str) -> None:
310316
lineno = int(match.group('lineno')) - 1
311317
assert "Revealed type is" in error_line
312318

313-
marker = lines[lineno]
314-
_test_reveal(path, marker, error_line, 1 + lineno)
319+
marker = reveal_list[lineno]
320+
expression = expression_array[lineno]
321+
_test_reveal(path, expression, marker, error_line, 1 + lineno)
315322

316323

317324
_REVEAL_MSG = """Reveal mismatch at line {}
318325
326+
Expression: {}
319327
Expected reveal: {!r}
320328
Observed reveal: {!r}
321329
"""
322330

323331

324332
def _test_reveal(
325333
path: str,
334+
expression: str,
326335
reveal: str,
327336
expected_reveal: str,
328337
lineno: int,
329338
) -> None:
330339
"""Error-reporting helper function for `test_reveal`."""
331340
if reveal not in expected_reveal:
332341
raise AssertionError(
333-
_REVEAL_MSG.format(lineno, expected_reveal, reveal)
342+
_REVEAL_MSG.format(lineno, expression, expected_reveal, reveal)
334343
)
335344

336345

@@ -375,11 +384,15 @@ def test_extended_precision() -> None:
375384
output_mypy = OUTPUT_MYPY
376385
assert path in output_mypy
377386

387+
with open(path, "r") as f:
388+
expression_list = f.readlines()
389+
378390
for _msg in output_mypy[path]:
379391
*_, _lineno, msg_typ, msg = _msg.split(":")
380392

381393
msg = _strip_filename(msg)
382394
lineno = int(_lineno)
395+
expression = expression_list[lineno - 1].rstrip("\n")
383396
msg_typ = msg_typ.strip()
384397
assert msg_typ in {"error", "note"}
385398

@@ -388,8 +401,8 @@ def test_extended_precision() -> None:
388401
raise ValueError(f"Unexpected reveal line format: {lineno}")
389402
else:
390403
marker = FORMAT_DICT[LINENO_MAPPING[lineno]]
391-
_test_reveal(path, marker, msg, lineno)
404+
_test_reveal(path, expression, marker, msg, lineno)
392405
else:
393406
if msg_typ == "error":
394407
marker = "Module has no attribute"
395-
_test_fail(path, marker, msg, lineno)
408+
_test_fail(path, expression, marker, msg, lineno)

0 commit comments

Comments
 (0)