Skip to content

Commit 7a2d2d8

Browse files
authored
Merge pull request #5847 from bluetech/type-annotations-4
2/X Fix check_untyped_defs = True mypy errors
2 parents 6242777 + 1cc1ac5 commit 7a2d2d8

File tree

10 files changed

+130
-78
lines changed

10 files changed

+130
-78
lines changed

src/_pytest/assertion/rewrite.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ def find_spec(self, name, path=None, target=None):
7878
# there's nothing to rewrite there
7979
# python3.5 - python3.6: `namespace`
8080
# python3.7+: `None`
81-
or spec.origin in {None, "namespace"}
81+
or spec.origin == "namespace"
82+
or spec.origin is None
8283
# we can only rewrite source files
8384
or not isinstance(spec.loader, importlib.machinery.SourceFileLoader)
8485
# if the file doesn't exist, we can't rewrite it
@@ -743,8 +744,7 @@ def visit_Assert(self, assert_):
743744
from _pytest.warning_types import PytestAssertRewriteWarning
744745
import warnings
745746

746-
# Ignore type: typeshed bug https://github.com/python/typeshed/pull/3121
747-
warnings.warn_explicit( # type: ignore
747+
warnings.warn_explicit(
748748
PytestAssertRewriteWarning(
749749
"assertion is always true, perhaps remove parentheses?"
750750
),

src/_pytest/cacheprovider.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import json
88
import os
99
from collections import OrderedDict
10+
from typing import List
1011

1112
import attr
1213
import py
@@ -15,6 +16,9 @@
1516
from .pathlib import Path
1617
from .pathlib import resolve_from_str
1718
from .pathlib import rm_rf
19+
from _pytest import nodes
20+
from _pytest.config import Config
21+
from _pytest.main import Session
1822

1923
README_CONTENT = """\
2024
# pytest cache directory #
@@ -263,10 +267,12 @@ def __init__(self, config):
263267
self.active = config.option.newfirst
264268
self.cached_nodeids = config.cache.get("cache/nodeids", [])
265269

266-
def pytest_collection_modifyitems(self, session, config, items):
270+
def pytest_collection_modifyitems(
271+
self, session: Session, config: Config, items: List[nodes.Item]
272+
) -> None:
267273
if self.active:
268-
new_items = OrderedDict()
269-
other_items = OrderedDict()
274+
new_items = OrderedDict() # type: OrderedDict[str, nodes.Item]
275+
other_items = OrderedDict() # type: OrderedDict[str, nodes.Item]
270276
for item in items:
271277
if item.nodeid not in self.cached_nodeids:
272278
new_items[item.nodeid] = item

src/_pytest/capture.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import pytest
1414
from _pytest.compat import CaptureIO
15+
from _pytest.fixtures import FixtureRequest
1516

1617
patchsysdict = {0: "stdin", 1: "stdout", 2: "stderr"}
1718

@@ -241,13 +242,12 @@ def pytest_internalerror(self, excinfo):
241242
capture_fixtures = {"capfd", "capfdbinary", "capsys", "capsysbinary"}
242243

243244

244-
def _ensure_only_one_capture_fixture(request, name):
245-
fixtures = set(request.fixturenames) & capture_fixtures - {name}
245+
def _ensure_only_one_capture_fixture(request: FixtureRequest, name):
246+
fixtures = sorted(set(request.fixturenames) & capture_fixtures - {name})
246247
if fixtures:
247-
fixtures = sorted(fixtures)
248-
fixtures = fixtures[0] if len(fixtures) == 1 else fixtures
248+
arg = fixtures[0] if len(fixtures) == 1 else fixtures
249249
raise request.raiseerror(
250-
"cannot use {} and {} at the same time".format(fixtures, name)
250+
"cannot use {} and {} at the same time".format(arg, name)
251251
)
252252

253253

src/_pytest/doctest.py

Lines changed: 58 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,12 @@
66
import traceback
77
import warnings
88
from contextlib import contextmanager
9+
from typing import Dict
10+
from typing import List
11+
from typing import Optional
912
from typing import Sequence
1013
from typing import Tuple
14+
from typing import Union
1115

1216
import pytest
1317
from _pytest import outcomes
@@ -20,6 +24,10 @@
2024
from _pytest.python_api import approx
2125
from _pytest.warning_types import PytestWarning
2226

27+
if False: # TYPE_CHECKING
28+
import doctest
29+
from typing import Type
30+
2331
DOCTEST_REPORT_CHOICE_NONE = "none"
2432
DOCTEST_REPORT_CHOICE_CDIFF = "cdiff"
2533
DOCTEST_REPORT_CHOICE_NDIFF = "ndiff"
@@ -36,6 +44,8 @@
3644

3745
# Lazy definition of runner class
3846
RUNNER_CLASS = None
47+
# Lazy definition of output checker class
48+
CHECKER_CLASS = None # type: Optional[Type[doctest.OutputChecker]]
3949

4050

4151
def pytest_addoption(parser):
@@ -139,7 +149,7 @@ def __init__(self, failures):
139149
self.failures = failures
140150

141151

142-
def _init_runner_class():
152+
def _init_runner_class() -> "Type[doctest.DocTestRunner]":
143153
import doctest
144154

145155
class PytestDoctestRunner(doctest.DebugRunner):
@@ -177,12 +187,19 @@ def report_unexpected_exception(self, out, test, example, exc_info):
177187
return PytestDoctestRunner
178188

179189

180-
def _get_runner(checker=None, verbose=None, optionflags=0, continue_on_failure=True):
190+
def _get_runner(
191+
checker: Optional["doctest.OutputChecker"] = None,
192+
verbose: Optional[bool] = None,
193+
optionflags: int = 0,
194+
continue_on_failure: bool = True,
195+
) -> "doctest.DocTestRunner":
181196
# We need this in order to do a lazy import on doctest
182197
global RUNNER_CLASS
183198
if RUNNER_CLASS is None:
184199
RUNNER_CLASS = _init_runner_class()
185-
return RUNNER_CLASS(
200+
# Type ignored because the continue_on_failure argument is only defined on
201+
# PytestDoctestRunner, which is lazily defined so can't be used as a type.
202+
return RUNNER_CLASS( # type: ignore
186203
checker=checker,
187204
verbose=verbose,
188205
optionflags=optionflags,
@@ -211,7 +228,7 @@ def setup(self):
211228
def runtest(self):
212229
_check_all_skipped(self.dtest)
213230
self._disable_output_capturing_for_darwin()
214-
failures = []
231+
failures = [] # type: List[doctest.DocTestFailure]
215232
self.runner.run(self.dtest, out=failures)
216233
if failures:
217234
raise MultipleDoctestFailures(failures)
@@ -232,7 +249,9 @@ def _disable_output_capturing_for_darwin(self):
232249
def repr_failure(self, excinfo):
233250
import doctest
234251

235-
failures = None
252+
failures = (
253+
None
254+
) # type: Optional[List[Union[doctest.DocTestFailure, doctest.UnexpectedException]]]
236255
if excinfo.errisinstance((doctest.DocTestFailure, doctest.UnexpectedException)):
237256
failures = [excinfo.value]
238257
elif excinfo.errisinstance(MultipleDoctestFailures):
@@ -255,8 +274,10 @@ def repr_failure(self, excinfo):
255274
self.config.getoption("doctestreport")
256275
)
257276
if lineno is not None:
277+
assert failure.test.docstring is not None
258278
lines = failure.test.docstring.splitlines(False)
259279
# add line numbers to the left of the error message
280+
assert test.lineno is not None
260281
lines = [
261282
"%03d %s" % (i + test.lineno + 1, x)
262283
for (i, x) in enumerate(lines)
@@ -288,7 +309,7 @@ def reportinfo(self):
288309
return self.fspath, self.dtest.lineno, "[doctest] %s" % self.name
289310

290311

291-
def _get_flag_lookup():
312+
def _get_flag_lookup() -> Dict[str, int]:
292313
import doctest
293314

294315
return dict(
@@ -340,7 +361,7 @@ def collect(self):
340361
optionflags = get_optionflags(self)
341362

342363
runner = _get_runner(
343-
verbose=0,
364+
verbose=False,
344365
optionflags=optionflags,
345366
checker=_get_checker(),
346367
continue_on_failure=_get_continue_on_failure(self.config),
@@ -419,7 +440,8 @@ def _find(self, tests, obj, name, module, source_lines, globs, seen):
419440
return
420441
with _patch_unwrap_mock_aware():
421442

422-
doctest.DocTestFinder._find(
443+
# Type ignored because this is a private function.
444+
doctest.DocTestFinder._find( # type: ignore
423445
self, tests, obj, name, module, source_lines, globs, seen
424446
)
425447

@@ -437,7 +459,7 @@ def _find(self, tests, obj, name, module, source_lines, globs, seen):
437459
finder = MockAwareDocTestFinder()
438460
optionflags = get_optionflags(self)
439461
runner = _get_runner(
440-
verbose=0,
462+
verbose=False,
441463
optionflags=optionflags,
442464
checker=_get_checker(),
443465
continue_on_failure=_get_continue_on_failure(self.config),
@@ -466,24 +488,7 @@ def func():
466488
return fixture_request
467489

468490

469-
def _get_checker():
470-
"""
471-
Returns a doctest.OutputChecker subclass that supports some
472-
additional options:
473-
474-
* ALLOW_UNICODE and ALLOW_BYTES options to ignore u'' and b''
475-
prefixes (respectively) in string literals. Useful when the same
476-
doctest should run in Python 2 and Python 3.
477-
478-
* NUMBER to ignore floating-point differences smaller than the
479-
precision of the literal number in the doctest.
480-
481-
An inner class is used to avoid importing "doctest" at the module
482-
level.
483-
"""
484-
if hasattr(_get_checker, "LiteralsOutputChecker"):
485-
return _get_checker.LiteralsOutputChecker()
486-
491+
def _init_checker_class() -> "Type[doctest.OutputChecker]":
487492
import doctest
488493
import re
489494

@@ -573,11 +578,31 @@ def _remove_unwanted_precision(self, want, got):
573578
offset += w.end() - w.start() - (g.end() - g.start())
574579
return got
575580

576-
_get_checker.LiteralsOutputChecker = LiteralsOutputChecker
577-
return _get_checker.LiteralsOutputChecker()
581+
return LiteralsOutputChecker
582+
583+
584+
def _get_checker() -> "doctest.OutputChecker":
585+
"""
586+
Returns a doctest.OutputChecker subclass that supports some
587+
additional options:
588+
589+
* ALLOW_UNICODE and ALLOW_BYTES options to ignore u'' and b''
590+
prefixes (respectively) in string literals. Useful when the same
591+
doctest should run in Python 2 and Python 3.
592+
593+
* NUMBER to ignore floating-point differences smaller than the
594+
precision of the literal number in the doctest.
595+
596+
An inner class is used to avoid importing "doctest" at the module
597+
level.
598+
"""
599+
global CHECKER_CLASS
600+
if CHECKER_CLASS is None:
601+
CHECKER_CLASS = _init_checker_class()
602+
return CHECKER_CLASS()
578603

579604

580-
def _get_allow_unicode_flag():
605+
def _get_allow_unicode_flag() -> int:
581606
"""
582607
Registers and returns the ALLOW_UNICODE flag.
583608
"""
@@ -586,7 +611,7 @@ def _get_allow_unicode_flag():
586611
return doctest.register_optionflag("ALLOW_UNICODE")
587612

588613

589-
def _get_allow_bytes_flag():
614+
def _get_allow_bytes_flag() -> int:
590615
"""
591616
Registers and returns the ALLOW_BYTES flag.
592617
"""
@@ -595,7 +620,7 @@ def _get_allow_bytes_flag():
595620
return doctest.register_optionflag("ALLOW_BYTES")
596621

597622

598-
def _get_number_flag():
623+
def _get_number_flag() -> int:
599624
"""
600625
Registers and returns the NUMBER flag.
601626
"""
@@ -604,7 +629,7 @@ def _get_number_flag():
604629
return doctest.register_optionflag("NUMBER")
605630

606631

607-
def _get_report_choice(key):
632+
def _get_report_choice(key: str) -> int:
608633
"""
609634
This function returns the actual `doctest` module flag value, we want to do it as late as possible to avoid
610635
importing `doctest` and all its dependencies when parsing options, as it adds overhead and breaks tests.

src/_pytest/logging.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
import logging
33
import re
44
from contextlib import contextmanager
5+
from typing import AbstractSet
6+
from typing import Dict
7+
from typing import List
8+
from typing import Mapping
59

610
import py
711

@@ -32,14 +36,15 @@ class ColoredLevelFormatter(logging.Formatter):
3236
logging.INFO: {"green"},
3337
logging.DEBUG: {"purple"},
3438
logging.NOTSET: set(),
35-
}
39+
} # type: Mapping[int, AbstractSet[str]]
3640
LEVELNAME_FMT_REGEX = re.compile(r"%\(levelname\)([+-.]?\d*s)")
3741

38-
def __init__(self, terminalwriter, *args, **kwargs):
42+
def __init__(self, terminalwriter, *args, **kwargs) -> None:
3943
super().__init__(*args, **kwargs)
4044
self._original_fmt = self._style._fmt
41-
self._level_to_fmt_mapping = {}
45+
self._level_to_fmt_mapping = {} # type: Dict[int, str]
4246

47+
assert self._fmt is not None
4348
levelname_fmt_match = self.LEVELNAME_FMT_REGEX.search(self._fmt)
4449
if not levelname_fmt_match:
4550
return
@@ -216,31 +221,31 @@ def catching_logs(handler, formatter=None, level=None):
216221
class LogCaptureHandler(logging.StreamHandler):
217222
"""A logging handler that stores log records and the log text."""
218223

219-
def __init__(self):
224+
def __init__(self) -> None:
220225
"""Creates a new log handler."""
221226
logging.StreamHandler.__init__(self, py.io.TextIO())
222-
self.records = []
227+
self.records = [] # type: List[logging.LogRecord]
223228

224-
def emit(self, record):
229+
def emit(self, record: logging.LogRecord) -> None:
225230
"""Keep the log records in a list in addition to the log text."""
226231
self.records.append(record)
227232
logging.StreamHandler.emit(self, record)
228233

229-
def reset(self):
234+
def reset(self) -> None:
230235
self.records = []
231236
self.stream = py.io.TextIO()
232237

233238

234239
class LogCaptureFixture:
235240
"""Provides access and control of log capturing."""
236241

237-
def __init__(self, item):
242+
def __init__(self, item) -> None:
238243
"""Creates a new funcarg."""
239244
self._item = item
240245
# dict of log name -> log level
241-
self._initial_log_levels = {} # Dict[str, int]
246+
self._initial_log_levels = {} # type: Dict[str, int]
242247

243-
def _finalize(self):
248+
def _finalize(self) -> None:
244249
"""Finalizes the fixture.
245250
246251
This restores the log levels changed by :meth:`set_level`.
@@ -453,7 +458,7 @@ def _create_formatter(self, log_format, log_date_format):
453458
):
454459
formatter = ColoredLevelFormatter(
455460
create_terminal_writer(self._config), log_format, log_date_format
456-
)
461+
) # type: logging.Formatter
457462
else:
458463
formatter = logging.Formatter(log_format, log_date_format)
459464

src/_pytest/nodes.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,7 @@ def warn(self, warning):
139139
)
140140
)
141141
path, lineno = get_fslocation_from_item(self)
142-
# Type ignored: https://github.com/python/typeshed/pull/3121
143-
warnings.warn_explicit( # type: ignore
142+
warnings.warn_explicit(
144143
warning,
145144
category=None,
146145
filename=str(path),

0 commit comments

Comments
 (0)