1111
1212import pytest
1313import numpy as np
14+ import numpy .typing as npt
1415from numpy .typing .mypy_plugin import (
1516 _PRECISION_DICT ,
1617 _EXTENDED_PRECISION_LIST ,
@@ -150,9 +151,9 @@ def test_fail(path: str) -> None:
150151
151152 target_line = lines [lineno - 1 ]
152153 if "# E:" in target_line :
153- marker = target_line .split ( " # E:" )[ - 1 ]. strip ( )
154- expected_error = errors . get ( lineno )
155- _test_fail (path , marker , expected_error , lineno )
154+ expression , _ , marker = target_line .partition ( " # E: " )
155+ expected_error = errors [ lineno ]
156+ _test_fail (path , expression , marker . strip () , expected_error . strip () , lineno )
156157 else :
157158 pytest .fail (
158159 f"Unexpected mypy output at line { lineno } \n \n { errors [lineno ]} "
@@ -161,26 +162,29 @@ def test_fail(path: str) -> None:
161162
162163_FAIL_MSG1 = """Extra error at line {}
163164
165+ Expression: {}
164166Extra error: {!r}
165167"""
166168
167169_FAIL_MSG2 = """Error mismatch at line {}
168170
171+ Expression: {}
169172Expected error: {!r}
170173Observed error: {!r}
171174"""
172175
173176
174177def _test_fail (
175178 path : str ,
179+ expression : str ,
176180 error : str ,
177181 expected_error : None | str ,
178182 lineno : int ,
179183) -> None :
180184 if expected_error is None :
181- raise AssertionError (_FAIL_MSG1 .format (lineno , error ))
185+ raise AssertionError (_FAIL_MSG1 .format (lineno , expression , error ))
182186 elif error not in expected_error :
183- raise AssertionError (_FAIL_MSG2 .format (lineno , expected_error , error ))
187+ raise AssertionError (_FAIL_MSG2 .format (lineno , expression , expected_error , error ))
184188
185189
186190def _construct_format_dict () -> dict [str , str ]:
@@ -259,7 +263,7 @@ def _construct_format_dict() -> dict[str, str]:
259263FORMAT_DICT : dict [str , str ] = _construct_format_dict ()
260264
261265
262- def _parse_reveals (file : IO [str ]) -> list [str ]:
266+ def _parse_reveals (file : IO [str ]) -> tuple [ npt . NDArray [ np . str_ ], list [str ] ]:
263267 """Extract and parse all ``" # E: "`` comments from the passed
264268 file-like object.
265269
@@ -269,8 +273,10 @@ def _parse_reveals(file: IO[str]) -> list[str]:
269273 """
270274 string = file .read ().replace ("*" , "" )
271275
272- # Grab all `# E:`-based comments
273- comments_array = np .char .partition (string .split ("\n " ), sep = " # E: " )[:, 2 ]
276+ # Grab all `# E:`-based comments and matching expressions
277+ expression_array , _ , comments_array = np .char .partition (
278+ string .split ("\n " ), sep = " # E: "
279+ ).T
274280 comments = "/n" .join (comments_array )
275281
276282 # Only search for the `{*}` pattern within comments, otherwise
@@ -282,7 +288,7 @@ def _parse_reveals(file: IO[str]) -> list[str]:
282288 }
283289 fmt_str = comments .format (** kwargs )
284290
285- return fmt_str .split ("/n" )
291+ return expression_array , fmt_str .split ("/n" )
286292
287293
288294@pytest .mark .slow
@@ -295,7 +301,7 @@ def test_reveal(path: str) -> None:
295301 __tracebackhide__ = True
296302
297303 with open (path ) as fin :
298- lines = _parse_reveals (fin )
304+ expression_array , reveal_list = _parse_reveals (fin )
299305
300306 output_mypy = OUTPUT_MYPY
301307 assert path in output_mypy
@@ -310,27 +316,30 @@ def test_reveal(path: str) -> None:
310316 lineno = int (match .group ('lineno' )) - 1
311317 assert "Revealed type is" in error_line
312318
313- marker = lines [lineno ]
314- _test_reveal (path , marker , error_line , 1 + lineno )
319+ marker = reveal_list [lineno ]
320+ expression = expression_array [lineno ]
321+ _test_reveal (path , expression , marker , error_line , 1 + lineno )
315322
316323
317324_REVEAL_MSG = """Reveal mismatch at line {}
318325
326+ Expression: {}
319327Expected reveal: {!r}
320328Observed reveal: {!r}
321329"""
322330
323331
324332def _test_reveal (
325333 path : str ,
334+ expression : str ,
326335 reveal : str ,
327336 expected_reveal : str ,
328337 lineno : int ,
329338) -> None :
330339 """Error-reporting helper function for `test_reveal`."""
331340 if reveal not in expected_reveal :
332341 raise AssertionError (
333- _REVEAL_MSG .format (lineno , expected_reveal , reveal )
342+ _REVEAL_MSG .format (lineno , expression , expected_reveal , reveal )
334343 )
335344
336345
@@ -375,11 +384,15 @@ def test_extended_precision() -> None:
375384 output_mypy = OUTPUT_MYPY
376385 assert path in output_mypy
377386
387+ with open (path , "r" ) as f :
388+ expression_list = f .readlines ()
389+
378390 for _msg in output_mypy [path ]:
379391 * _ , _lineno , msg_typ , msg = _msg .split (":" )
380392
381393 msg = _strip_filename (msg )
382394 lineno = int (_lineno )
395+ expression = expression_list [lineno - 1 ].rstrip ("\n " )
383396 msg_typ = msg_typ .strip ()
384397 assert msg_typ in {"error" , "note" }
385398
@@ -388,8 +401,8 @@ def test_extended_precision() -> None:
388401 raise ValueError (f"Unexpected reveal line format: { lineno } " )
389402 else :
390403 marker = FORMAT_DICT [LINENO_MAPPING [lineno ]]
391- _test_reveal (path , marker , msg , lineno )
404+ _test_reveal (path , expression , marker , msg , lineno )
392405 else :
393406 if msg_typ == "error" :
394407 marker = "Module has no attribute"
395- _test_fail (path , marker , msg , lineno )
408+ _test_fail (path , expression , marker , msg , lineno )
0 commit comments