Skip to content

Commit 9422c36

Browse files
authored
Bring us under 30 test_ssl type-check issues (#1406)
1 parent a3972a0 commit 9422c36

File tree

1 file changed

+64
-46
lines changed

1 file changed

+64
-46
lines changed

tests/test_ssl.py

Lines changed: 64 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -901,7 +901,9 @@ def test_set_passwd_cb(self, tmpfile: bytes) -> None:
901901
pemFile = self._write_encrypted_pem(passphrase, tmpfile)
902902
calledWith = []
903903

904-
def passphraseCallback(maxlen, verify, extra):
904+
def passphraseCallback(
905+
maxlen: int, verify: bool, extra: None
906+
) -> bytes:
905907
calledWith.append((maxlen, verify, extra))
906908
return passphrase
907909

@@ -920,7 +922,9 @@ def test_passwd_callback_exception(self, tmpfile: bytes) -> None:
920922
"""
921923
pemFile = self._write_encrypted_pem(b"monkeys are nice", tmpfile)
922924

923-
def passphraseCallback(maxlen, verify, extra):
925+
def passphraseCallback(
926+
maxlen: int, verify: bool, extra: None
927+
) -> bytes:
924928
raise RuntimeError("Sorry, I am a fail.")
925929

926930
context = Context(SSLv23_METHOD)
@@ -935,7 +939,9 @@ def test_passwd_callback_false(self, tmpfile: bytes) -> None:
935939
"""
936940
pemFile = self._write_encrypted_pem(b"monkeys are nice", tmpfile)
937941

938-
def passphraseCallback(maxlen, verify, extra):
942+
def passphraseCallback(
943+
maxlen: int, verify: bool, extra: None
944+
) -> bytes:
939945
return b""
940946

941947
context = Context(SSLv23_METHOD)
@@ -950,11 +956,11 @@ def test_passwd_callback_non_string(self, tmpfile: bytes) -> None:
950956
"""
951957
pemFile = self._write_encrypted_pem(b"monkeys are nice", tmpfile)
952958

953-
def passphraseCallback(maxlen, verify, extra):
959+
def passphraseCallback(maxlen: int, verify: bool, extra: None) -> int:
954960
return 10
955961

956962
context = Context(SSLv23_METHOD)
957-
context.set_passwd_cb(passphraseCallback)
963+
context.set_passwd_cb(passphraseCallback) # type: ignore[arg-type]
958964
# TODO: Surely this is the wrong error?
959965
with pytest.raises(ValueError):
960966
context.use_privatekey_file(pemFile)
@@ -968,7 +974,9 @@ def test_passwd_callback_too_long(self, tmpfile: bytes) -> None:
968974
passphrase = b"x" * 1024
969975
pemFile = self._write_encrypted_pem(passphrase, tmpfile)
970976

971-
def passphraseCallback(maxlen, verify, extra):
977+
def passphraseCallback(
978+
maxlen: int, verify: bool, extra: None
979+
) -> bytes:
972980
assert maxlen == 1024
973981
return passphrase + b"y"
974982

@@ -990,7 +998,7 @@ def test_set_info_callback(self) -> None:
990998

991999
called = []
9921000

993-
def info(conn, where, ret):
1001+
def info(conn: Connection, where: int, ret: int) -> None:
9941002
called.append((conn, where, ret))
9951003

9961004
context = Context(SSLv23_METHOD)
@@ -1028,7 +1036,7 @@ def test_set_keylog_callback(self) -> None:
10281036
"""
10291037
called = []
10301038

1031-
def keylog(conn, line):
1039+
def keylog(conn: Connection, line: bytes) -> None:
10321040
called.append((conn, line))
10331041

10341042
server_context = Context(TLSv1_2_METHOD)
@@ -1385,9 +1393,9 @@ def test_set_verify_callback_connection_argument(self) -> None:
13851393
serverConnection = Connection(serverContext, None)
13861394

13871395
class VerifyCallback:
1388-
def callback(self, connection, *args):
1396+
def callback(self, connection: Connection, *args) -> bool:
13891397
self.connection = connection
1390-
return 1
1398+
return True
13911399

13921400
verify = VerifyCallback()
13931401
clientContext = Context(SSLv23_METHOD)
@@ -1415,9 +1423,11 @@ def test_x509_in_verify_works(self) -> None:
14151423
)
14161424
serverConnection = Connection(serverContext, None)
14171425

1418-
def verify_cb_get_subject(conn, cert, errnum, depth, ok):
1426+
def verify_cb_get_subject(
1427+
conn: Connection, cert: X509, errnum: int, depth: int, ok: int
1428+
) -> bool:
14191429
assert cert.get_subject()
1420-
return 1
1430+
return True
14211431

14221432
clientContext = Context(SSLv23_METHOD)
14231433
clientContext.set_verify(VERIFY_PEER, verify_cb_get_subject)
@@ -1817,10 +1827,10 @@ def test_old_callback_forgotten(self) -> None:
18171827
a new callback, the one it replaces is dereferenced.
18181828
"""
18191829

1820-
def callback(connection): # pragma: no cover
1830+
def callback(connection: Connection) -> None: # pragma: no cover
18211831
pass
18221832

1823-
def replacement(connection): # pragma: no cover
1833+
def replacement(connection: Connection) -> None: # pragma: no cover
18241834
pass
18251835

18261836
context = Context(SSLv23_METHOD)
@@ -1851,7 +1861,7 @@ def test_no_servername(self) -> None:
18511861
"""
18521862
args = []
18531863

1854-
def servername(conn):
1864+
def servername(conn: Connection) -> None:
18551865
args.append((conn, conn.get_servername()))
18561866

18571867
context = Context(SSLv23_METHOD)
@@ -1888,7 +1898,7 @@ def test_servername(self) -> None:
18881898
"""
18891899
args = []
18901900

1891-
def servername(conn):
1901+
def servername(conn: Connection) -> None:
18921902
args.append((conn, conn.get_servername()))
18931903

18941904
context = Context(SSLv23_METHOD)
@@ -1926,7 +1936,7 @@ def test_alpn_success(self) -> None:
19261936
"""
19271937
select_args = []
19281938

1929-
def select(conn, options):
1939+
def select(conn: Connection, options: list[bytes]) -> bytes:
19301940
select_args.append((conn, options))
19311941
return b"spdy/2"
19321942

@@ -1974,7 +1984,7 @@ def test_alpn_set_on_connection(self) -> None:
19741984
"""
19751985
select_args = []
19761986

1977-
def select(conn, options):
1987+
def select(conn: Connection, options: list[bytes]) -> bytes:
19781988
select_args.append((conn, options))
19791989
return b"spdy/2"
19801990

@@ -2015,7 +2025,7 @@ def test_alpn_server_fail(self) -> None:
20152025
"""
20162026
select_args = []
20172027

2018-
def select(conn, options):
2028+
def select(conn: Connection, options: list[bytes]) -> bytes:
20192029
select_args.append((conn, options))
20202030
return b""
20212031

@@ -2054,7 +2064,7 @@ def test_alpn_no_server_overlap(self) -> None:
20542064
"""
20552065
refusal_args = []
20562066

2057-
def refusal(conn, options):
2067+
def refusal(conn: Connection, options: list[bytes]):
20582068
refusal_args.append((conn, options))
20592069
return NO_OVERLAPPING_PROTOCOLS
20602070

@@ -2094,15 +2104,15 @@ def test_alpn_select_cb_returns_invalid_value(self) -> None:
20942104
"""
20952105
invalid_cb_args = []
20962106

2097-
def invalid_cb(conn, options):
2107+
def invalid_cb(conn: Connection, options: list[bytes]) -> str:
20982108
invalid_cb_args.append((conn, options))
20992109
return "can't return unicode"
21002110

21012111
client_context = Context(SSLv23_METHOD)
21022112
client_context.set_alpn_protos([b"http/1.1", b"spdy/2"])
21032113

21042114
server_context = Context(SSLv23_METHOD)
2105-
server_context.set_alpn_select_callback(invalid_cb)
2115+
server_context.set_alpn_select_callback(invalid_cb) # type: ignore[arg-type]
21062116

21072117
# Necessary to actually accept the connection
21082118
server_context.use_privatekey(
@@ -2163,7 +2173,7 @@ def test_alpn_callback_exception(self) -> None:
21632173
"""
21642174
select_args = []
21652175

2166-
def select(conn, options):
2176+
def select(conn: Connection, options: list[bytes]) -> bytes:
21672177
select_args.append((conn, options))
21682178
raise TypeError()
21692179

@@ -2790,8 +2800,10 @@ def test_set_verify_callback_reference(self) -> None:
27902800
the context and all connections created by it do not use it anymore.
27912801
"""
27922802

2793-
def callback(conn, cert, errnum, depth, ok): # pragma: no cover
2794-
return ok
2803+
def callback(
2804+
conn: Connection, cert: X509, errnum: int, depth: int, ok: int
2805+
) -> bool: # pragma: no cover
2806+
return bool(ok)
27952807

27962808
tracker = ref(callback)
27972809

@@ -2872,7 +2884,7 @@ def test_client_set_session(self) -> None:
28722884
ctx.use_certificate(cert)
28732885
ctx.set_session_id(b"unity-test")
28742886

2875-
def makeServer(socket):
2887+
def makeServer(socket: socket) -> Connection:
28762888
server = Connection(ctx, socket)
28772889
server.set_accept_state()
28782890
return server
@@ -2881,7 +2893,7 @@ def makeServer(socket):
28812893
originalSession = originalClient.get_session()
28822894
assert originalSession is not None
28832895

2884-
def makeClient(socket):
2896+
def makeClient(socket: socket) -> Connection:
28852897
client = loopback_client_factory(socket)
28862898
client.set_session(originalSession)
28872899
return client
@@ -2914,12 +2926,12 @@ def test_set_session_wrong_method(self) -> None:
29142926
ctx.use_certificate(cert)
29152927
ctx.set_session_id(b"unity-test")
29162928

2917-
def makeServer(socket):
2929+
def makeServer(socket: socket) -> Connection:
29182930
server = Connection(ctx, socket)
29192931
server.set_accept_state()
29202932
return server
29212933

2922-
def makeOriginalClient(socket):
2934+
def makeOriginalClient(socket: socket) -> Connection:
29232935
client = Connection(Context(v1), socket)
29242936
client.set_connect_state()
29252937
return client
@@ -2930,7 +2942,7 @@ def makeOriginalClient(socket):
29302942
originalSession = originalClient.get_session()
29312943
assert originalSession is not None
29322944

2933-
def makeClient(socket):
2945+
def makeClient(socket: socket) -> Connection:
29342946
# Intentionally use a different, incompatible method here.
29352947
client = Connection(Context(v2), socket)
29362948
client.set_connect_state()
@@ -3193,7 +3205,7 @@ class VeryLarge(bytes):
31933205
Mock object so that we don't have to allocate 2**31 bytes
31943206
"""
31953207

3196-
def __len__(self):
3208+
def __len__(self) -> int:
31973209
return 2**31
31983210

31993211

@@ -3275,7 +3287,7 @@ def test_buf_too_large(self) -> None:
32753287
exc_info.match(r"Cannot send more than .+ bytes at once")
32763288

32773289

3278-
def _make_memoryview(size):
3290+
def _make_memoryview(size: int) -> memoryview:
32793291
"""
32803292
Create a new ``memoryview`` wrapped around a ``bytearray`` of the given
32813293
size.
@@ -3933,7 +3945,7 @@ def test_set_empty_ca_list(self) -> None:
39333945
after the connection is set up.
39343946
"""
39353947

3936-
def no_ca(ctx):
3948+
def no_ca(ctx: Context) -> list[X509Name]:
39373949
ctx.set_client_ca_list([])
39383950
return []
39393951

@@ -3950,7 +3962,7 @@ def test_set_one_ca_list(self) -> None:
39503962
cacert = load_certificate(FILETYPE_PEM, root_cert_pem)
39513963
cadesc = cacert.get_subject()
39523964

3953-
def single_ca(ctx):
3965+
def single_ca(ctx: Context) -> list[X509Name]:
39543966
ctx.set_client_ca_list([cadesc])
39553967
return [cadesc]
39563968

@@ -3970,7 +3982,7 @@ def test_set_multiple_ca_list(self) -> None:
39703982
sedesc = secert.get_subject()
39713983
cldesc = clcert.get_subject()
39723984

3973-
def multiple_ca(ctx):
3985+
def multiple_ca(ctx: Context) -> list[X509Name]:
39743986
L = [sedesc, cldesc]
39753987
ctx.set_client_ca_list(L)
39763988
return L
@@ -3991,7 +4003,7 @@ def test_reset_ca_list(self) -> None:
39914003
sedesc = secert.get_subject()
39924004
cldesc = clcert.get_subject()
39934005

3994-
def changed_ca(ctx):
4006+
def changed_ca(ctx: Context) -> list[X509Name]:
39954007
ctx.set_client_ca_list([sedesc, cldesc])
39964008
ctx.set_client_ca_list([cadesc])
39974009
return [cadesc]
@@ -4010,7 +4022,7 @@ def test_mutated_ca_list(self) -> None:
40104022
cadesc = cacert.get_subject()
40114023
sedesc = secert.get_subject()
40124024

4013-
def mutated_ca(ctx):
4025+
def mutated_ca(ctx: Context) -> list[X509Name]:
40144026
L = [cadesc]
40154027
ctx.set_client_ca_list([cadesc])
40164028
L.append(sedesc)
@@ -4035,7 +4047,7 @@ def test_one_add_client_ca(self) -> None:
40354047
cacert = load_certificate(FILETYPE_PEM, root_cert_pem)
40364048
cadesc = cacert.get_subject()
40374049

4038-
def single_ca(ctx):
4050+
def single_ca(ctx: Context) -> list[X509Name]:
40394051
ctx.add_client_ca(cacert)
40404052
return [cadesc]
40414053

@@ -4052,7 +4064,7 @@ def test_multiple_add_client_ca(self) -> None:
40524064
cadesc = cacert.get_subject()
40534065
sedesc = secert.get_subject()
40544066

4055-
def multiple_ca(ctx):
4067+
def multiple_ca(ctx: Context) -> list[X509Name]:
40564068
ctx.add_client_ca(cacert)
40574069
ctx.add_client_ca(secert.to_cryptography())
40584070
return [cadesc, sedesc]
@@ -4073,7 +4085,7 @@ def test_set_and_add_client_ca(self) -> None:
40734085
sedesc = secert.get_subject()
40744086
cldesc = clcert.get_subject()
40754087

4076-
def mixed_set_add_ca(ctx):
4088+
def mixed_set_add_ca(ctx: Context) -> list[X509Name]:
40774089
ctx.set_client_ca_list([cadesc, sedesc])
40784090
ctx.add_client_ca(clcert)
40794091
return [cadesc, sedesc, cldesc]
@@ -4093,7 +4105,7 @@ def test_set_after_add_client_ca(self) -> None:
40934105
cadesc = cacert.get_subject()
40944106
sedesc = secert.get_subject()
40954107

4096-
def set_replaces_add_ca(ctx):
4108+
def set_replaces_add_ca(ctx: Context) -> list[X509Name]:
40974109
ctx.add_client_ca(clcert.to_cryptography())
40984110
ctx.set_client_ca_list([cadesc])
40994111
ctx.add_client_ca(secert)
@@ -4253,7 +4265,9 @@ def test_client_negotiates_without_server(self) -> None:
42534265
"""
42544266
called = []
42554267

4256-
def ocsp_callback(conn, ocsp_data, ignored):
4268+
def ocsp_callback(
4269+
conn: Connection, ocsp_data: bytes, ignored: None
4270+
) -> bool:
42574271
called.append(ocsp_data)
42584272
return True
42594273

@@ -4273,7 +4287,9 @@ def test_client_receives_servers_data(self) -> None:
42734287
def server_callback(*args, **kwargs):
42744288
return self.sample_ocsp_data
42754289

4276-
def client_callback(conn, ocsp_data, ignored):
4290+
def client_callback(
4291+
conn: Connection, ocsp_data: bytes, ignored: None
4292+
) -> bool:
42774293
calls.append(ocsp_data)
42784294
return True
42794295

@@ -4347,7 +4363,9 @@ def test_server_returns_empty_string(self) -> None:
43474363
def server_callback(*args):
43484364
return b""
43494365

4350-
def client_callback(conn, ocsp_data, ignored):
4366+
def client_callback(
4367+
conn: Connection, ocsp_data: bytes, ignored: None
4368+
) -> bool:
43514369
client_calls.append(ocsp_data)
43524370
return True
43534371

@@ -4509,10 +4527,10 @@ class TestDTLS:
45094527
def _test_handshake_and_data(self, srtp_profile: bytes | None) -> None:
45104528
s_ctx = Context(DTLS_METHOD)
45114529

4512-
def generate_cookie(ssl):
4530+
def generate_cookie(ssl: Connection) -> bytes:
45134531
return b"xyzzy"
45144532

4515-
def verify_cookie(ssl, cookie):
4533+
def verify_cookie(ssl: Connection, cookie: bytes) -> bool:
45164534
return cookie == b"xyzzy"
45174535

45184536
s_ctx.set_cookie_generate_callback(generate_cookie)

0 commit comments

Comments
 (0)