Skip to content

Commit 2dc1f0c

Browse files
Add a few annotations.
- Add a typing configuration file with pyright set to strict. - Add annotations for normalizers.py and validators.py.
1 parent b8ef055 commit 2dc1f0c

File tree

3 files changed

+68
-40
lines changed

3 files changed

+68
-40
lines changed

pyproject.toml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
[tool.pyright]
2+
include = ["src/rfc3986"]
3+
ignore = ["tests"]
4+
pythonVersion = "3.8"
5+
typeCheckingMode = "strict"
6+
7+
reportPrivateUsage = "none"
8+
reportImportCycles = "warning"
9+
reportPropertyTypeMismatch = "warning"
10+
reportUnnecessaryTypeIgnoreComment = "warning"

src/rfc3986/normalizers.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,21 @@
1313
# limitations under the License.
1414
"""Module with functions to normalize components."""
1515
import re
16+
import typing as t
1617
from urllib.parse import quote as urlquote
1718

1819
from . import compat
1920
from . import misc
2021

2122

22-
def normalize_scheme(scheme):
23+
def normalize_scheme(scheme: str) -> str:
2324
"""Normalize the scheme component."""
2425
return scheme.lower()
2526

2627

27-
def normalize_authority(authority):
28+
def normalize_authority(
29+
authority: t.Tuple[t.Optional[str], t.Optional[str], t.Optional[str]],
30+
) -> str:
2831
"""Normalize an authority tuple to a string."""
2932
userinfo, host, port = authority
3033
result = ""
@@ -37,17 +40,17 @@ def normalize_authority(authority):
3740
return result
3841

3942

40-
def normalize_username(username):
43+
def normalize_username(username: str) -> str:
4144
"""Normalize a username to make it safe to include in userinfo."""
4245
return urlquote(username)
4346

4447

45-
def normalize_password(password):
48+
def normalize_password(password: str) -> str:
4649
"""Normalize a password to make safe for userinfo."""
4750
return urlquote(password)
4851

4952

50-
def normalize_host(host):
53+
def normalize_host(host: str) -> str:
5154
"""Normalize a host string."""
5255
if misc.IPv6_MATCHER.match(host):
5356
percent = host.find("%")
@@ -70,7 +73,7 @@ def normalize_host(host):
7073
return host.lower()
7174

7275

73-
def normalize_path(path):
76+
def normalize_path(path: str) -> str:
7477
"""Normalize the path string."""
7578
if not path:
7679
return path
@@ -79,14 +82,14 @@ def normalize_path(path):
7982
return remove_dot_segments(path)
8083

8184

82-
def normalize_query(query):
85+
def normalize_query(query: str) -> str:
8386
"""Normalize the query string."""
8487
if not query:
8588
return query
8689
return normalize_percent_characters(query)
8790

8891

89-
def normalize_fragment(fragment):
92+
def normalize_fragment(fragment: str) -> str:
9093
"""Normalize the fragment string."""
9194
if not fragment:
9295
return fragment
@@ -96,7 +99,7 @@ def normalize_fragment(fragment):
9699
PERCENT_MATCHER = re.compile("%[A-Fa-f0-9]{2}")
97100

98101

99-
def normalize_percent_characters(s):
102+
def normalize_percent_characters(s: str) -> str:
100103
"""All percent characters should be upper-cased.
101104
102105
For example, ``"%3afoo%DF%ab"`` should be turned into ``"%3Afoo%DF%AB"``.
@@ -108,14 +111,14 @@ def normalize_percent_characters(s):
108111
return s
109112

110113

111-
def remove_dot_segments(s):
114+
def remove_dot_segments(s: str) -> str:
112115
"""Remove dot segments from the string.
113116
114117
See also Section 5.2.4 of :rfc:`3986`.
115118
"""
116119
# See http://tools.ietf.org/html/rfc3986#section-5.2.4 for pseudo-code
117120
segments = s.split("/") # Turn the path into a list of segments
118-
output = [] # Initialize the variable to use to store output
121+
output: list[str] = [] # Initialize the variable to use to store output
119122

120123
for segment in segments:
121124
# '.' is the current directory, so ignore it, it is superfluous
@@ -141,8 +144,13 @@ def remove_dot_segments(s):
141144

142145
return "/".join(output)
143146

144-
145-
def encode_component(uri_component, encoding):
147+
@t.overload
148+
def encode_component(uri_component: None, encoding: str) -> None:
149+
...
150+
@t.overload
151+
def encode_component(uri_component: str, encoding: str) -> str:
152+
...
153+
def encode_component(uri_component: t.Optional[str], encoding: str) -> t.Optional[str]:
146154
"""Encode the specific component in the provided encoding."""
147155
if uri_component is None:
148156
return uri_component

src/rfc3986/validators.py

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""Module containing the validation logic for rfc3986."""
15+
import typing as t
16+
1517
from . import exceptions
1618
from . import misc
1719
from . import normalizers
20+
from . import uri
1821

1922

2023
class Validator:
@@ -50,9 +53,9 @@ class Validator:
5053

5154
def __init__(self):
5255
"""Initialize our default validations."""
53-
self.allowed_schemes = set()
54-
self.allowed_hosts = set()
55-
self.allowed_ports = set()
56+
self.allowed_schemes: set[str] = set()
57+
self.allowed_hosts: set[str] = set()
58+
self.allowed_ports: set[str] = set()
5659
self.allow_password = True
5760
self.required_components = {
5861
"scheme": False,
@@ -65,7 +68,7 @@ def __init__(self):
6568
}
6669
self.validated_components = self.required_components.copy()
6770

68-
def allow_schemes(self, *schemes):
71+
def allow_schemes(self, *schemes: str):
6972
"""Require the scheme to be one of the provided schemes.
7073
7174
.. versionadded:: 1.0
@@ -81,7 +84,7 @@ def allow_schemes(self, *schemes):
8184
self.allowed_schemes.add(normalizers.normalize_scheme(scheme))
8285
return self
8386

84-
def allow_hosts(self, *hosts):
87+
def allow_hosts(self, *hosts: str):
8588
"""Require the host to be one of the provided hosts.
8689
8790
.. versionadded:: 1.0
@@ -97,7 +100,7 @@ def allow_hosts(self, *hosts):
97100
self.allowed_hosts.add(normalizers.normalize_host(host))
98101
return self
99102

100-
def allow_ports(self, *ports):
103+
def allow_ports(self, *ports: str):
101104
"""Require the port to be one of the provided ports.
102105
103106
.. versionadded:: 1.0
@@ -141,7 +144,7 @@ def forbid_use_of_password(self):
141144
self.allow_password = False
142145
return self
143146

144-
def check_validity_of(self, *components):
147+
def check_validity_of(self, *components: str):
145148
"""Check the validity of the components provided.
146149
147150
This can be specified repeatedly.
@@ -155,7 +158,7 @@ def check_validity_of(self, *components):
155158
:rtype:
156159
Validator
157160
"""
158-
components = [c.lower() for c in components]
161+
components = tuple(c.lower() for c in components)
159162
for component in components:
160163
if component not in self.COMPONENT_NAMES:
161164
raise ValueError(f'"{component}" is not a valid component')
@@ -164,7 +167,7 @@ def check_validity_of(self, *components):
164167
)
165168
return self
166169

167-
def require_presence_of(self, *components):
170+
def require_presence_of(self, *components: str):
168171
"""Require the components provided.
169172
170173
This can be specified repeatedly.
@@ -178,7 +181,7 @@ def require_presence_of(self, *components):
178181
:rtype:
179182
Validator
180183
"""
181-
components = [c.lower() for c in components]
184+
components = tuple(c.lower() for c in components)
182185
for component in components:
183186
if component not in self.COMPONENT_NAMES:
184187
raise ValueError(f'"{component}" is not a valid component')
@@ -187,7 +190,7 @@ def require_presence_of(self, *components):
187190
)
188191
return self
189192

190-
def validate(self, uri):
193+
def validate(self, uri: "uri.URIReference"):
191194
"""Check a URI for conditions specified on this validator.
192195
193196
.. versionadded:: 1.0
@@ -229,7 +232,7 @@ def validate(self, uri):
229232
ensure_one_of(self.allowed_ports, uri, "port")
230233

231234

232-
def check_password(uri):
235+
def check_password(uri: "uri.URIReference") -> None:
233236
"""Assert that there is no password present in the uri."""
234237
userinfo = uri.userinfo
235238
if not userinfo:
@@ -240,7 +243,11 @@ def check_password(uri):
240243
raise exceptions.PasswordForbidden(uri)
241244

242245

243-
def ensure_one_of(allowed_values, uri, attribute):
246+
def ensure_one_of(
247+
allowed_values: t.Container[object],
248+
uri: "uri.URIReference",
249+
attribute: str,
250+
) -> None:
244251
"""Assert that the uri's attribute is one of the allowed values."""
245252
value = getattr(uri, attribute)
246253
if value is not None and allowed_values and value not in allowed_values:
@@ -251,7 +258,10 @@ def ensure_one_of(allowed_values, uri, attribute):
251258
)
252259

253260

254-
def ensure_required_components_exist(uri, required_components):
261+
def ensure_required_components_exist(
262+
uri: "uri.URIReference",
263+
required_components: t.Iterable[str],
264+
):
255265
"""Assert that all required components are present in the URI."""
256266
missing_components = sorted(
257267
component
@@ -262,7 +272,7 @@ def ensure_required_components_exist(uri, required_components):
262272
raise exceptions.MissingComponentError(uri, *missing_components)
263273

264274

265-
def is_valid(value, matcher, require):
275+
def is_valid(value: t.Optional[str], matcher: t.Pattern[str], require: bool) -> bool:
266276
"""Determine if a value is valid based on the provided matcher.
267277
268278
:param str value:
@@ -273,13 +283,13 @@ def is_valid(value, matcher, require):
273283
Whether or not the value is required.
274284
"""
275285
if require:
276-
return value is not None and matcher.match(value)
286+
return value is not None and bool(matcher.match(value))
277287

278288
# require is False and value is not None
279-
return value is None or matcher.match(value)
289+
return value is None or bool(matcher.match(value))
280290

281291

282-
def authority_is_valid(authority, host=None, require=False):
292+
def authority_is_valid(authority: str, host: t.Optional[str] = None, require: bool = False) -> bool:
283293
"""Determine if the authority string is valid.
284294
285295
:param str authority:
@@ -299,7 +309,7 @@ def authority_is_valid(authority, host=None, require=False):
299309
return validated
300310

301311

302-
def host_is_valid(host, require=False):
312+
def host_is_valid(host: t.Optional[str], require: bool = False) -> bool:
303313
"""Determine if the host string is valid.
304314
305315
:param str host:
@@ -319,7 +329,7 @@ def host_is_valid(host, require=False):
319329
return validated
320330

321331

322-
def scheme_is_valid(scheme, require=False):
332+
def scheme_is_valid(scheme: t.Optional[str], require: bool = False) -> bool:
323333
"""Determine if the scheme is valid.
324334
325335
:param str scheme:
@@ -334,7 +344,7 @@ def scheme_is_valid(scheme, require=False):
334344
return is_valid(scheme, misc.SCHEME_MATCHER, require)
335345

336346

337-
def path_is_valid(path, require=False):
347+
def path_is_valid(path: t.Optional[str], require: bool = False) -> bool:
338348
"""Determine if the path component is valid.
339349
340350
:param str path:
@@ -349,7 +359,7 @@ def path_is_valid(path, require=False):
349359
return is_valid(path, misc.PATH_MATCHER, require)
350360

351361

352-
def query_is_valid(query, require=False):
362+
def query_is_valid(query: t.Optional[str], require: bool = False) -> bool:
353363
"""Determine if the query component is valid.
354364
355365
:param str query:
@@ -364,7 +374,7 @@ def query_is_valid(query, require=False):
364374
return is_valid(query, misc.QUERY_MATCHER, require)
365375

366376

367-
def fragment_is_valid(fragment, require=False):
377+
def fragment_is_valid(fragment: t.Optional[str], require: bool = False) -> bool:
368378
"""Determine if the fragment component is valid.
369379
370380
:param str fragment:
@@ -379,7 +389,7 @@ def fragment_is_valid(fragment, require=False):
379389
return is_valid(fragment, misc.FRAGMENT_MATCHER, require)
380390

381391

382-
def valid_ipv4_host_address(host):
392+
def valid_ipv4_host_address(host: str) -> bool:
383393
"""Determine if the given host is a valid IPv4 address."""
384394
# If the host exists, and it might be IPv4, check each byte in the
385395
# address.
@@ -396,7 +406,7 @@ def valid_ipv4_host_address(host):
396406
_SUBAUTHORITY_VALIDATORS = {"userinfo", "host", "port"}
397407

398408

399-
def subauthority_component_is_valid(uri, component):
409+
def subauthority_component_is_valid(uri: "uri.URIReference", component: str) -> bool:
400410
"""Determine if the userinfo, host, and port are valid."""
401411
try:
402412
subauthority_dict = uri.authority_info()
@@ -420,9 +430,9 @@ def subauthority_component_is_valid(uri, component):
420430
return 0 <= port <= 65535
421431

422432

423-
def ensure_components_are_valid(uri, validated_components):
433+
def ensure_components_are_valid(uri: "uri.URIReference", validated_components: t.List[str]) -> None:
424434
"""Assert that all components are valid in the URI."""
425-
invalid_components = set()
435+
invalid_components: set[str] = set()
426436
for component in validated_components:
427437
if component in _SUBAUTHORITY_VALIDATORS:
428438
if not subauthority_component_is_valid(uri, component):

0 commit comments

Comments
 (0)