Skip to content

Commit 6d21218

Browse files
committed
Add strict type annotations
1 parent 53e090e commit 6d21218

23 files changed

+493
-266
lines changed

CHANGES.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
* Provide ``X-Forwarded`` middleware that filters out trusted values (#153)
99

10+
* Add type annotations
11+
1012
0.1.2 (2018-03-01)
1113
==================
1214

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ endif
1111

1212
.PHONY: lint
1313
lint: fmt
14-
mypy aiohttp_remotes
14+
mypy --strict --show-error-codes aiohttp_remotes tests
1515

1616
test:
1717
pytest tests

aiohttp_remotes/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@
99
__version__ = "1.0.0a0"
1010

1111

12+
from typing_extensions import Protocol
13+
14+
from aiohttp import web
15+
1216
from .allowed_hosts import AllowedHosts
1317
from .basic_auth import BasicAuth
1418
from .cloudflare import Cloudflare
@@ -17,7 +21,12 @@
1721
from .x_forwarded import XForwardedFiltered, XForwardedRelaxed, XForwardedStrict
1822

1923

20-
async def setup(app, *tools):
24+
class _Tool(Protocol):
25+
async def setup(self, app: web.Application) -> None:
26+
...
27+
28+
29+
async def setup(app: web.Application, *tools: _Tool) -> None:
2130
for tool in tools:
2231
await tool.setup(app)
2332

aiohttp_remotes/abc.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import abc
22

3+
from typing_extensions import NoReturn
4+
35
from aiohttp import web
46

57

68
class ABC(abc.ABC):
79
@abc.abstractmethod
8-
async def setup(self, app):
10+
async def setup(self, app: web.Application) -> None:
911
pass # pragma: no cover
1012

11-
async def raise_error(self, request):
13+
async def raise_error(self, request: web.Request) -> NoReturn:
1214
raise web.HTTPBadRequest()

aiohttp_remotes/allowed_hosts.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,43 @@
1+
from typing import Awaitable, Callable, Iterable, Set, Union
2+
13
from aiohttp import web
24

35
from .abc import ABC
46

57

68
class ANY:
7-
def __contains__(self, item):
9+
def __contains__(self, item: object) -> bool:
810
return True
911

1012

1113
class AllowedHosts(ABC):
12-
def __init__(self, allowed_hosts=("*",), *, white_paths=()):
13-
allowed_hosts = set(allowed_hosts)
14-
15-
if "*" in allowed_hosts:
16-
allowed_hosts = ANY()
17-
18-
self._allowed_hosts = allowed_hosts
14+
def __init__(
15+
self,
16+
allowed_hosts: Iterable[str] = ("*",),
17+
*,
18+
white_paths: Iterable[str] = (),
19+
) -> None:
20+
real_allowed_hosts: Union[Set[str], ANY] = set(allowed_hosts)
21+
22+
if "*" in real_allowed_hosts:
23+
real_allowed_hosts = ANY()
24+
25+
self._allowed_hosts = real_allowed_hosts
1926
self._white_paths = set(white_paths)
2027

21-
async def setup(self, app):
28+
async def setup(self, app: web.Application) -> None:
2229
app.middlewares.append(self.middleware)
2330

2431
@web.middleware
25-
async def middleware(self, request, handler):
32+
async def middleware(
33+
self,
34+
request: web.Request,
35+
handler: Callable[[web.Request], Awaitable[web.StreamResponse]],
36+
) -> web.StreamResponse:
2637
if (
2738
request.path not in self._white_paths
2839
and request.host not in self._allowed_hosts
2940
):
30-
await self.raise_error(request)
41+
return await self.raise_error(request)
3142

3243
return await handler(request)

aiohttp_remotes/basic_auth.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,43 @@
11
import base64
22
import binascii
3+
from typing import Awaitable, Callable, Iterable
4+
5+
from typing_extensions import NoReturn
36

47
from aiohttp import hdrs, web
58

69
from .abc import ABC
710

811

912
class BasicAuth(ABC):
10-
def __init__(self, username, password, realm, *, white_paths=()):
13+
def __init__(
14+
self,
15+
username: str,
16+
password: str,
17+
realm: str,
18+
*,
19+
white_paths: Iterable[str] = (),
20+
) -> None:
1121
self._username = username
1222
self._password = password
1323
self._realm = realm
1424
self._white_paths = set(white_paths)
1525

16-
async def setup(self, app):
26+
async def setup(self, app: web.Application) -> None:
1727
app.middlewares.append(self.middleware)
1828

19-
async def raise_error(self, request):
29+
async def raise_error(self, request: web.Request) -> NoReturn:
2030
raise web.HTTPUnauthorized(
2131
headers={hdrs.WWW_AUTHENTICATE: f"Basic realm={self._realm}"},
2232
)
2333

2434
@web.middleware
25-
async def middleware(self, request, handler):
35+
async def middleware(
36+
self,
37+
request: web.Request,
38+
handler: Callable[[web.Request], Awaitable[web.StreamResponse]],
39+
) -> web.StreamResponse:
40+
2641
if request.path not in self._white_paths:
2742
auth_header = request.headers.get(hdrs.AUTHORIZATION)
2843

@@ -34,16 +49,16 @@ async def middleware(self, request, handler):
3449

3550
auth_decoded = base64.decodebytes(secret).decode("utf-8")
3651
except (UnicodeDecodeError, UnicodeEncodeError, binascii.Error):
37-
await self.raise_error(request)
52+
return await self.raise_error(request)
3853

3954
credentials = auth_decoded.split(":")
4055

4156
if len(credentials) != 2:
42-
await self.raise_error(request)
57+
return await self.raise_error(request)
4358

4459
username, password = credentials
4560

4661
if username != self._username or password != self._password:
47-
await self.raise_error(request)
62+
return await self.raise_error(request)
4863

4964
return await handler(request)

aiohttp_remotes/cloudflare.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,31 @@
11
from ipaddress import ip_address, ip_network
2+
from typing import Awaitable, Callable, Optional, Set
23

34
import aiohttp
45
from aiohttp import web
56

67
from .abc import ABC
8+
from .exceptions import IPNetwork
79
from .log import logger
810

911

1012
class Cloudflare(ABC):
11-
def __init__(self, client=None):
12-
self._ip_networks = set()
13+
def __init__(self, client: Optional[aiohttp.ClientSession] = None) -> None:
14+
self._ip_networks: Set[IPNetwork] = set()
1315
self._client = client
1416

15-
def _parse_mask(self, text):
17+
def _parse_mask(self, text: str) -> Set[IPNetwork]:
1618
ret = set()
1719
for mask in text.splitlines():
1820
try:
19-
mask = ip_network(mask)
21+
real_mask = ip_network(mask)
2022
except (ValueError, TypeError):
2123
continue
2224

23-
ret.add(mask)
25+
ret.add(real_mask)
2426
return ret
2527

26-
async def setup(self, app):
28+
async def setup(self, app: web.Application) -> None:
2729
if self._client is not None: # pragma: no branch
2830
client = self._client
2931
else:
@@ -42,7 +44,11 @@ async def setup(self, app):
4244
app.middlewares.append(self.middleware)
4345

4446
@web.middleware
45-
async def middleware(self, request, handler):
47+
async def middleware(
48+
self,
49+
request: web.Request,
50+
handler: Callable[[web.Request], Awaitable[web.StreamResponse]],
51+
) -> web.StreamResponse:
4652
remote_ip = ip_address(request.remote)
4753

4854
for network in self._ip_networks:
@@ -54,4 +60,4 @@ async def middleware(self, request, handler):
5460
context = {"remote_ip": remote_ip}
5561
logger.error(msg, context)
5662

57-
await self.raise_error(request)
63+
return await self.raise_error(request)

aiohttp_remotes/exceptions.py

Lines changed: 40 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,87 +1,98 @@
1+
import builtins
2+
from ipaddress import IPv4Address, IPv4Network, IPv6Address, IPv6Network
3+
from typing import Any, Dict, Sequence, Union, cast
4+
5+
from aiohttp import web
6+
17
from .log import logger
28

9+
IPAddress = Union[IPv4Address, IPv6Address]
10+
IPNetwork = Union[IPv4Network, IPv6Network]
11+
IPRule = Union[IPAddress, IPNetwork]
12+
Trusted = Sequence[Union["builtins.ellipsis", Sequence[IPRule]]]
13+
314

415
class RemoteError(Exception):
5-
def log(self, request):
16+
def log(self, request: web.Request) -> None:
617
raise NotImplementedError # pragma: no cover
718

819

920
class TooManyHeaders(RemoteError):
1021
@property
11-
def header(self):
12-
return self.args[0]
22+
def header(self) -> str:
23+
return cast(str, self.args[0])
1324

14-
def log(self, request):
25+
def log(self, request: web.Request) -> None:
1526
msg = "Too many headers for %(header)s"
16-
context = {"header": self.header}
27+
context: Dict[str, Any] = {"header": self.header}
1728
extra = context.copy()
1829
extra["request"] = request
1930
logger.error(msg, context, extra=extra)
2031

2132

2233
class IncorrectIPCount(RemoteError):
2334
@property
24-
def expected(self):
25-
return self.args[0]
35+
def expected(self) -> int:
36+
return cast(int, self.args[0])
2637

2738
@property
28-
def actual(self):
29-
return self.args[1]
39+
def actual(self) -> Sequence[IPAddress]:
40+
return cast(Sequence[IPAddress], self.args[1])
3041

31-
def log(self, request):
42+
def log(self, request: web.Request) -> None:
3243
msg = "Too many X-Forwarded-For values: %(actual)s, " "expected %(expected)s"
33-
context = {"actual": self.actual, "expected": self.expected}
44+
context: Dict[str, Any] = {"actual": self.actual, "expected": self.expected}
3445
extra = context.copy()
3546
extra["request"] = request
3647
logger.error(msg, context, extra=extra)
3748

3849

3950
class IncorrectForwardedCount(RemoteError):
4051
@property
41-
def expected(self):
42-
return self.args[0]
52+
def expected(self) -> int:
53+
return cast(int, self.args[0])
4354

4455
@property
45-
def actual(self):
46-
return self.args[1]
56+
def actual(self) -> int:
57+
return cast(int, self.args[1])
4758

48-
def log(self, request):
59+
def log(self, request: web.Request) -> None:
4960
msg = "Too many Forwarded values: %(actual)s, " "expected %(expected)s"
50-
context = {"actual": self.actual, "expected": self.expected}
61+
context: Dict[str, Any] = {"actual": self.actual, "expected": self.expected}
5162
extra = context.copy()
5263
extra["request"] = request
5364
logger.error(msg, context, extra=extra)
5465

5566

5667
class IncorrectProtoCount(RemoteError):
5768
@property
58-
def expected(self):
59-
return self.args[0]
69+
def expected(self) -> int:
70+
return cast(int, self.args[0])
6071

6172
@property
62-
def actual(self):
63-
return self.args[1]
73+
def actual(self) -> Sequence[str]:
74+
return cast(Sequence[str], self.args[1])
6475

65-
def log(self, request):
76+
def log(self, request: web.Request) -> None:
6677
msg = "Too many X-Forwarded-Proto values: %(actual)s, " "expected %(expected)s"
67-
context = {"actual": self.actual, "expected": self.expected}
78+
context: Dict[str, Any] = {"actual": self.actual, "expected": self.expected}
6879
extra = context.copy()
6980
extra["request"] = request
7081
logger.error(msg, context, extra=extra)
7182

7283

7384
class UntrustedIP(RemoteError):
7485
@property
75-
def ip(self):
76-
return self.args[0]
86+
def ip(self) -> IPAddress:
87+
return cast(IPAddress, self.args[0])
7788

7889
@property
79-
def trusted(self):
80-
return self.args[1]
90+
def trusted(self) -> Sequence[IPAddress]:
91+
return cast(Sequence[IPAddress], self.args[1])
8192

82-
def log(self, request):
93+
def log(self, request: web.Request) -> None:
8394
msg = "Untrusted IP: %(ip)s, trusted: %(trusted)s"
84-
context = {"ip": self.ip, "trusted": self.trusted}
95+
context: Dict[str, Any] = {"ip": self.ip, "trusted": self.trusted}
8596
extra = context.copy()
8697
extra["request"] = request
8798
logger.error(msg, context, extra=extra)

0 commit comments

Comments
 (0)