Skip to content

Commit 38888ab

Browse files
authored
Fix more than 100 mypy errors in test_ssl.py (#1395)
1 parent ee017b2 commit 38888ab

File tree

1 file changed

+40
-22
lines changed

1 file changed

+40
-22
lines changed

tests/test_ssl.py

Lines changed: 40 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import select
1313
import sys
1414
import time
15+
import typing
1516
import uuid
1617
from errno import (
1718
EAFNOSUPPORT,
@@ -156,7 +157,7 @@
156157
"""
157158

158159

159-
def socket_any_family():
160+
def socket_any_family() -> socket:
160161
try:
161162
return socket(AF_INET)
162163
except OSError as e:
@@ -165,7 +166,7 @@ def socket_any_family():
165166
raise
166167

167168

168-
def loopback_address(socket):
169+
def loopback_address(socket: socket) -> str:
169170
if socket.family == AF_INET:
170171
return "127.0.0.1"
171172
else:
@@ -194,7 +195,7 @@ def verify_cb(conn, cert, errnum, depth, ok):
194195
return ok
195196

196197

197-
def socket_pair():
198+
def socket_pair() -> tuple[socket, socket]:
198199
"""
199200
Establish and return a pair of network sockets connected to each other.
200201
"""
@@ -225,7 +226,7 @@ def socket_pair():
225226
return (server, client)
226227

227228

228-
def handshake(client, server):
229+
def handshake(client: Connection, server: Connection) -> None:
229230
conns = [client, server]
230231
while conns:
231232
for conn in conns:
@@ -322,13 +323,17 @@ def _create_certificate_chain():
322323
]
323324

324325

325-
def loopback_client_factory(socket, version=SSLv23_METHOD):
326+
def loopback_client_factory(
327+
socket: socket, version: int = SSLv23_METHOD
328+
) -> Connection:
326329
client = Connection(Context(version), socket)
327330
client.set_connect_state()
328331
return client
329332

330333

331-
def loopback_server_factory(socket, version=SSLv23_METHOD):
334+
def loopback_server_factory(
335+
socket: socket | None, version: int = SSLv23_METHOD
336+
) -> Connection:
332337
ctx = Context(version)
333338
ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem))
334339
ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem))
@@ -337,7 +342,10 @@ def loopback_server_factory(socket, version=SSLv23_METHOD):
337342
return server
338343

339344

340-
def loopback(server_factory=None, client_factory=None):
345+
def loopback(
346+
server_factory: typing.Callable[[socket], Connection] | None = None,
347+
client_factory: typing.Callable[[socket], Connection] | None = None,
348+
) -> tuple[Connection, Connection]:
341349
"""
342350
Create a connected socket pair and force two connected SSL sockets
343351
to talk to each other via memory BIOs.
@@ -348,17 +356,19 @@ def loopback(server_factory=None, client_factory=None):
348356
client_factory = loopback_client_factory
349357

350358
(server, client) = socket_pair()
351-
server = server_factory(server)
352-
client = client_factory(client)
359+
tls_server = server_factory(server)
360+
tls_client = client_factory(client)
353361

354-
handshake(client, server)
362+
handshake(tls_client, tls_server)
355363

356-
server.setblocking(True)
357-
client.setblocking(True)
358-
return server, client
364+
tls_server.setblocking(True)
365+
tls_client.setblocking(True)
366+
return tls_server, tls_client
359367

360368

361-
def interact_in_memory(client_conn, server_conn):
369+
def interact_in_memory(
370+
client_conn: Connection, server_conn: Connection
371+
) -> None:
362372
"""
363373
Try to read application bytes from each of the two `Connection` objects.
364374
Copy bytes back and forth between their send/receive buffers for as long
@@ -404,7 +414,9 @@ def interact_in_memory(client_conn, server_conn):
404414
write.bio_write(dirty)
405415

406416

407-
def handshake_in_memory(client_conn, server_conn):
417+
def handshake_in_memory(
418+
client_conn: Connection, server_conn: Connection
419+
) -> None:
408420
"""
409421
Perform the TLS handshake between two `Connection` instances connected to
410422
each other via memory BIOs.
@@ -620,7 +632,7 @@ def test_method(self) -> None:
620632
Context(meth)
621633

622634
with pytest.raises(TypeError):
623-
Context("")
635+
Context("") # type: ignore[arg-type]
624636
with pytest.raises(ValueError):
625637
Context(13)
626638

@@ -690,11 +702,11 @@ def test_use_certificate_file_wrong_args(self) -> None:
690702
"""
691703
ctx = Context(SSLv23_METHOD)
692704
with pytest.raises(TypeError):
693-
ctx.use_certificate_file(object(), FILETYPE_PEM)
705+
ctx.use_certificate_file(object(), FILETYPE_PEM) # type: ignore[arg-type]
694706
with pytest.raises(TypeError):
695-
ctx.use_certificate_file(b"somefile", object())
707+
ctx.use_certificate_file(b"somefile", object()) # type: ignore[arg-type]
696708
with pytest.raises(TypeError):
697-
ctx.use_certificate_file(object(), FILETYPE_PEM)
709+
ctx.use_certificate_file(object(), FILETYPE_PEM) # type: ignore[arg-type]
698710

699711
def test_use_certificate_file_missing(self, tmpfile) -> None:
700712
"""
@@ -1070,7 +1082,7 @@ def _load_verify_locations_test(self, *args):
10701082
# connection will fail.
10711083
clientContext.set_verify(
10721084
VERIFY_PEER,
1073-
lambda conn, cert, errno, depth, preverify_ok: preverify_ok,
1085+
lambda conn, cert, errno, depth, preverify_ok: bool(preverify_ok),
10741086
)
10751087

10761088
clientSSL = Connection(clientContext, client)
@@ -1094,6 +1106,7 @@ def _load_verify_locations_test(self, *args):
10941106
handshake(clientSSL, serverSSL)
10951107

10961108
cert = clientSSL.get_peer_certificate()
1109+
assert cert is not None
10971110
assert cert.get_subject().CN == "Testing Root CA"
10981111

10991112
cryptography_cert = clientSSL.get_peer_certificate(
@@ -1228,6 +1241,7 @@ def test_fallback_default_verify_paths(self, monkeypatch) -> None:
12281241
)
12291242
context.set_default_verify_paths()
12301243
store = context.get_cert_store()
1244+
assert store is not None
12311245
sk_obj = _lib.X509_STORE_get0_objects(store._store)
12321246
assert sk_obj != _ffi.NULL
12331247
num = _lib.sk_X509_OBJECT_num(sk_obj)
@@ -1323,7 +1337,9 @@ def test_add_extra_chain_cert_invalid_cert(self) -> None:
13231337
with pytest.raises(TypeError):
13241338
context.add_extra_chain_cert(object())
13251339

1326-
def _handshake_test(self, serverContext, clientContext):
1340+
def _handshake_test(
1341+
self, serverContext: Context, clientContext: Context
1342+
) -> None:
13271343
"""
13281344
Verify that a client and server created with the given contexts can
13291345
successfully handshake and communicate.
@@ -2691,12 +2707,14 @@ def test_get_verified_chain(self) -> None:
26912707
interact_in_memory(client, server)
26922708

26932709
chain = client.get_verified_chain()
2710+
assert chain is not None
26942711
assert len(chain) == 3
26952712
assert "Server Certificate" == chain[0].get_subject().CN
26962713
assert "Intermediate Certificate" == chain[1].get_subject().CN
26972714
assert "Authority Certificate" == chain[2].get_subject().CN
26982715

26992716
cryptography_chain = client.get_verified_chain(as_cryptography=True)
2717+
assert cryptography_chain is not None
27002718
assert len(cryptography_chain) == 3
27012719
assert (
27022720
cryptography_chain[0].subject.rfc4514_string()
@@ -4509,7 +4527,7 @@ def pump_membio(label, source, sink):
45094527
sink.bio_write(chunk)
45104528
return True
45114529

4512-
def pump():
4530+
def pump() -> None:
45134531
# Raises if there was no data to pump, to avoid infinite loops if
45144532
# we aren't making progress.
45154533
assert pump_membio("s -> c", s, c) or pump_membio("c -> s", c, s)

0 commit comments

Comments
 (0)