Skip to content

Commit 62a0478

Browse files
committed
Add type hints to the main plugin components as per #176
1 parent f39adda commit 62a0478

File tree

4 files changed

+103
-31
lines changed

4 files changed

+103
-31
lines changed

src/pytest_flask/_internal.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
11
import functools
2+
from typing import Callable
3+
from typing import Literal
24
import warnings
35

6+
from pytest import Config as _PytestConfig
47

5-
def deprecated(reason):
8+
9+
_PytestScopeName = Literal["session", "package", "module", "class", "function"]
10+
11+
12+
def deprecated(reason: str) -> Callable:
613
"""Decorator which can be used to mark function or method as deprecated.
714
It will result a warning being emitted when the function is called."""
815

@@ -19,15 +26,15 @@ def deprecated_call(*args, **kwargs):
1926
return decorator
2027

2128

22-
def _rewrite_server_name(server_name, new_port):
29+
def _rewrite_server_name(server_name: str, new_port: str) -> str:
2330
"""Rewrite server port in ``server_name`` with ``new_port`` value."""
2431
sep = ":"
2532
if sep in server_name:
2633
server_name, _ = server_name.split(sep, 1)
2734
return sep.join((server_name, new_port))
2835

2936

30-
def _determine_scope(*, fixture_name, config):
37+
def _determine_scope(*, fixture_name: str, config: _PytestConfig) -> _PytestScopeName:
3138
return config.getini("live_server_scope")
3239

3340

src/pytest_flask/fixtures.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
11
#!/usr/bin/env python
22
import socket
3-
3+
from typing import Any
4+
from typing import cast
5+
from typing import Generator
6+
7+
from flask import Flask as _FlaskApp
8+
from flask.config import Config as _FlaskAppConfig
9+
from flask.testing import FlaskClient as _FlaskTestClient
10+
from pytest import Config as _PytestConfig
11+
from pytest import FixtureRequest as _PytestFixtureRequest
412
import pytest
513

614
from ._internal import _determine_scope
@@ -10,7 +18,7 @@
1018

1119

1220
@pytest.fixture
13-
def client(app):
21+
def client(app: _FlaskApp) -> Generator[_FlaskTestClient, Any, Any]:
1422
"""A Flask test client. An instance of :class:`flask.testing.TestClient`
1523
by default.
1624
"""
@@ -19,7 +27,7 @@ def client(app):
1927

2028

2129
@pytest.fixture
22-
def client_class(request, client):
30+
def client_class(request: _PytestFixtureRequest, client: _FlaskTestClient) -> None:
2331
"""Uses to set a ``client`` class attribute to current Flask test client::
2432
2533
@pytest.mark.usefixtures('client_class')
@@ -37,8 +45,10 @@ def test_login(self):
3745
request.cls.client = client
3846

3947

40-
@pytest.fixture(scope=_determine_scope)
41-
def live_server(request, app, pytestconfig): # pragma: no cover
48+
@pytest.fixture(scope=_determine_scope) # type: ignore[arg-type]
49+
def live_server(
50+
request: _PytestFixtureRequest, app: _FlaskApp, pytestconfig: _PytestConfig
51+
) -> Generator[LiveServer, Any, Any]: # pragma: no cover
4252
"""Run application in a separate process.
4353
4454
When the ``live_server`` fixture is applied, the ``url_for`` function
@@ -64,34 +74,36 @@ def test_server_is_up_and_running(live_server):
6474
port = s.getsockname()[1]
6575
s.close()
6676

67-
host = pytestconfig.getvalue("live_server_host")
77+
host = cast(str, pytestconfig.getvalue("live_server_host"))
6878

6979
# Explicitly set application ``SERVER_NAME`` for test suite
7080
original_server_name = app.config["SERVER_NAME"] or "localhost.localdomain"
7181
final_server_name = _rewrite_server_name(original_server_name, str(port))
7282
app.config["SERVER_NAME"] = final_server_name
7383

74-
wait = request.config.getvalue("live_server_wait")
75-
clean_stop = request.config.getvalue("live_server_clean_stop")
84+
wait = cast(int, request.config.getvalue("live_server_wait"))
85+
clean_stop = cast(bool, request.config.getvalue("live_server_clean_stop"))
86+
7687
server = LiveServer(app, host, port, wait, clean_stop)
7788
if request.config.getvalue("start_live_server"):
7889
server.start()
7990

8091
request.addfinalizer(server.stop)
92+
8193
yield server
8294

8395
if original_server_name is not None:
8496
app.config["SERVER_NAME"] = original_server_name
8597

8698

8799
@pytest.fixture
88-
def config(app):
100+
def config(app: _FlaskApp) -> _FlaskAppConfig:
89101
"""An application config."""
90102
return app.config
91103

92104

93105
@pytest.fixture(params=["application/json", "text/html"])
94-
def mimetype(request):
106+
def mimetype(request) -> str:
95107
return request.param
96108

97109

src/pytest_flask/live_server.py

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,34 @@
11
import logging
22
import multiprocessing
3+
from multiprocessing import Process
34
import os
45
import platform
56
import signal
67
import socket
78
import time
9+
from typing import Any
10+
from typing import cast
11+
from typing import Protocol
12+
from typing import Union
813

914
import pytest
1015

1116

17+
class _SupportsFlaskAppRun(Protocol):
18+
def run(
19+
self,
20+
host: str | None = None,
21+
port: int | None = None,
22+
debug: bool | None = None,
23+
load_dotenv: bool = True,
24+
**options: Any,
25+
) -> None:
26+
...
27+
28+
1229
# force 'fork' on macOS
1330
if platform.system() == "Darwin":
14-
multiprocessing = multiprocessing.get_context("fork")
31+
multiprocessing = multiprocessing.get_context("fork") # type: ignore
1532

1633

1734
class LiveServer: # pragma: no cover
@@ -25,27 +42,37 @@ class LiveServer: # pragma: no cover
2542
application is not started.
2643
"""
2744

28-
def __init__(self, app, host, port, wait, clean_stop=False):
45+
def __init__(
46+
self,
47+
app: _SupportsFlaskAppRun,
48+
host: str,
49+
port: int,
50+
wait: int,
51+
clean_stop: bool = False,
52+
):
2953
self.app = app
3054
self.port = port
3155
self.host = host
3256
self.wait = wait
3357
self.clean_stop = clean_stop
34-
self._process = None
58+
self._process: Union[Process, None] = None
3559

36-
def start(self):
60+
def start(self) -> None:
3761
"""Start application in a separate process."""
3862

39-
def worker(app, host, port):
63+
def worker(app: _SupportsFlaskAppRun, host: str, port: int) -> None:
4064
app.run(host=host, port=port, use_reloader=False, threaded=True)
4165

42-
self._process = multiprocessing.Process(
43-
target=worker, args=(self.app, self.host, self.port)
66+
self._process = cast(
67+
Process,
68+
multiprocessing.Process(
69+
target=worker, args=(self.app, self.host, self.port)
70+
),
4471
)
4572
self._process.daemon = True
4673
self._process.start()
4774

48-
keep_trying = True
75+
keep_trying: bool = True
4976
start_time = time.time()
5077
while keep_trying:
5178
elapsed_time = time.time() - start_time
@@ -57,7 +84,7 @@ def worker(app, host, port):
5784
if self._is_ready():
5885
keep_trying = False
5986

60-
def _is_ready(self):
87+
def _is_ready(self) -> bool:
6188
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
6289
try:
6390
sock.connect((self.host, self.port))
@@ -69,13 +96,13 @@ def _is_ready(self):
6996
sock.close()
7097
return ret
7198

72-
def url(self, url=""):
99+
def url(self, url: str = "") -> str:
73100
"""Returns the complete url based on server options."""
74101
return "http://{host!s}:{port!s}{url!s}".format(
75102
host=self.host, port=self.port, url=url
76103
)
77104

78-
def stop(self):
105+
def stop(self) -> None:
79106
"""Stop application process."""
80107
if self._process:
81108
if self.clean_stop and self._stop_cleanly():
@@ -84,14 +111,17 @@ def stop(self):
84111
# If it's still alive, kill it
85112
self._process.terminate()
86113

87-
def _stop_cleanly(self, timeout=5):
114+
def _stop_cleanly(self, timeout: int = 5) -> bool:
88115
"""Attempts to stop the server cleanly by sending a SIGINT
89116
signal and waiting for ``timeout`` seconds.
90117
91118
:return: True if the server was cleanly stopped, False otherwise.
92119
"""
120+
if not self._process:
121+
return True
122+
93123
try:
94-
os.kill(self._process.pid, signal.SIGINT)
124+
os.kill(cast(int, self._process.pid), signal.SIGINT)
95125
self._process.join(timeout)
96126
return True
97127
except Exception as ex:

src/pytest_flask/plugin.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@
55
:copyright: (c) by Vital Kudzelka
66
:license: MIT
77
"""
8+
from typing import Any
9+
from typing import Protocol
10+
from typing import Type
11+
from typing import TypeVar
12+
13+
from _pytest.config import Config as _PytestConfig
814
import pytest
915

1016
from .fixtures import accept_any
@@ -18,19 +24,36 @@
1824
from .pytest_compat import getfixturevalue
1925

2026

27+
_Response = TypeVar("_Response")
28+
29+
30+
class _SupportsPytestFlaskEqual(Protocol):
31+
status_code: int
32+
33+
def __eq__(self, other: Any) -> bool:
34+
...
35+
36+
def __ne__(self, other: Any) -> bool:
37+
...
38+
39+
2140
class JSONResponse:
2241
"""Mixin with testing helper methods for JSON responses."""
2342

24-
def __eq__(self, other):
43+
status_code: int
44+
45+
def __eq__(self, other) -> bool:
2546
if isinstance(other, int):
2647
return self.status_code == other
2748
return super().__eq__(other)
2849

29-
def __ne__(self, other):
50+
def __ne__(self, other) -> bool:
3051
return not self == other
3152

3253

33-
def pytest_assertrepr_compare(op, left, right):
54+
def pytest_assertrepr_compare(
55+
op: str, left: _SupportsPytestFlaskEqual, right: int
56+
) -> list[str] | None:
3457
if isinstance(left, JSONResponse) and op == "==" and isinstance(right, int):
3558
return [
3659
"Mismatch in status code for response: {} != {}".format(
@@ -42,7 +65,7 @@ def pytest_assertrepr_compare(op, left, right):
4265
return None
4366

4467

45-
def _make_test_response_class(response_class):
68+
def _make_test_response_class(response_class: Type[_Response]) -> Type[_Response]:
4669
"""Extends the response class with special attribute to test JSON
4770
responses. Don't override user-defined `json` attribute if any.
4871
@@ -186,7 +209,7 @@ def pytest_addoption(parser):
186209
)
187210

188211

189-
def pytest_configure(config):
212+
def pytest_configure(config: _PytestConfig) -> None:
190213
config.addinivalue_line(
191214
"markers", "app(options): pass options to your application factory"
192215
)

0 commit comments

Comments
 (0)