Skip to content

Commit 2feeb60

Browse files
committed
added type fixes to tests
1 parent a45669f commit 2feeb60

File tree

1 file changed

+56
-58
lines changed

1 file changed

+56
-58
lines changed

tests/test_ssl.py

Lines changed: 56 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,9 @@ def get_ssl_error_reason(ssl_error: SSL.Error) -> typing.Optional[str]:
426426
if isinstance(error_details, list) and len(error_details) > 0:
427427
first_error_tuple = error_details[0]
428428
if isinstance(first_error_tuple, tuple) and len(first_error_tuple) >= 3:
429-
return first_error_tuple[2]
429+
reason = first_error_tuple[2]
430+
if isinstance(reason, str):
431+
return reason
430432
return None
431433

432434
def create_ssl_nonblocking_connection(mode: int) -> tuple[socket, socket, Connection, Connection]:
@@ -2472,7 +2474,7 @@ def test_bio_write(self) -> None:
24722474
context = Context(SSLv23_METHOD)
24732475
connection = Connection(context, None)
24742476
connection.bio_write(b"xy")
2475-
connection.bio_write(bytearray(b"za"))
2477+
connection.bio_write(bytearray(b"za")) # type: ignore[arg-type]
24762478
with pytest.warns(DeprecationWarning):
24772479
connection.bio_write("deprecated") # type: ignore[arg-type]
24782480

@@ -3113,69 +3115,63 @@ def test_wantWriteError(self) -> None:
31133115

31143116
# XXX want_read
31153117

3116-
def _badwriteretry(self, mode) -> bool:
3118+
def _badwriteretry(self, mode: int) -> bool:
31173119
"""
3118-
`Connection` methods which generate output raise
3119-
`OpenSSL.SSL.WantWriteError` if writing to the connection's BIO
3120-
fail indicating a should-write state.
3120+
Tries to force a "bad write retry" error over an SSL connection by using a moving buffer
3121+
Returns True if a bad write retry error occurs.
31213122
"""
31223123
client_socket, server_socket, client, server = create_ssl_nonblocking_connection(mode)
3123-
3124+
result = False # Default return value
31243125
written = 0
31253126

3126-
import os
3127-
print("Client PID:")
3128-
print(os.getpid())
3129-
3130-
# Fill up the client's raw send buffer so the SSL connection won't be able to write
3131-
# anything. Start by sending larger chunks
3132-
# and continue by writing smaller chunks so we can be sure we
3133-
# completely fill the buffer.
3134-
msg = 'test'
3135-
for msg in [b"x" * 65536, b"x" * 16]:
3136-
for i in range(1024 * 1024 * 64):
3137-
try:
3138-
written = client_socket.send(msg)
3139-
print(f"Sent {written} bytes to fill buffer")
3140-
except OSError as e:
3141-
if e.errno == EWOULDBLOCK:
3142-
break
3143-
raise
3144-
else:
3145-
pytest.fail(
3146-
"Failed to fill socket buffer, cannot test bad write error"
3147-
)
3127+
try:
3128+
# Fill up the client's raw send buffer so the SSL connection won't be able to write
3129+
# anything. Start by sending larger chunks
3130+
# and continue by writing smaller chunks so we can be sure we
3131+
# completely fill the buffer.
3132+
for msg in [b"x" * 65536, b"x" * 16]:
3133+
for i in range(1024 * 1024 * 64):
3134+
try:
3135+
written = client_socket.send(msg)
3136+
print(f"Sent {written} bytes to fill buffer")
3137+
except OSError as e:
3138+
if e.errno == EWOULDBLOCK:
3139+
break
3140+
raise
3141+
else:
3142+
pytest.fail(
3143+
"Failed to fill socket buffer, cannot test bad write error"
3144+
)
31483145

3149-
# Now, attempt to send application data over the *established* SSL connection.
3150-
# Since the underlying raw socket's buffer is full, this should cause a WantWriteError.
3151-
print("Attempting to send data over SSL connection with full buffer...")
3152-
msg2 = b"Y" * 65536 #b"This data should trigger WantWriteError due to full buffer."
3146+
# Now, attempt to send application data over the *established* SSL connection.
3147+
# Since the underlying raw socket's buffer is full, this should cause a WantWriteError.
3148+
print("Attempting to send data over SSL connection with full buffer...")
3149+
msg2 = b"Y" * 65536
31533150

3154-
try:
3155-
written = client.send(msg2)
3156-
except SSL.WantWriteError as e:
3157-
print(f"Raised OpenSSL.SSL.WantWriteError as expected: {e} {written} bytes")
31583151
try:
3159-
# do a retry write which should fail unless SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER is set
3160-
# because we are passing a new buffer that has the same size as the previous but a different location
3161-
msg3 = b"Z" * 65536
3162-
written = client.send(msg3)
3163-
print(f"Retry succeeded unexpectedly: {written} bytes")
3152+
written = client.send(msg2)
3153+
except SSL.WantWriteError as e:
3154+
print(f"Raised OpenSSL.SSL.WantWriteError as expected: {e} {written} bytes")
3155+
try:
3156+
# do a retry write which should fail unless SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER is set
3157+
msg3 = b"Z" * 65536
3158+
written = client.send(msg3)
3159+
print(f"Retry succeeded unexpectedly: {written} bytes")
3160+
result = False # Unexpected success
3161+
except SSL.Error as e:
3162+
reason = get_ssl_error_reason(e)
3163+
if reason == "bad write retry":
3164+
print(f"Got expected SSL error on retry: {e}")
3165+
result = True # Expected behavior
3166+
else:
3167+
print(f"Got unexpected SSL error on retry: {e}")
3168+
result = False # Unexpected error
3169+
31643170
except SSL.Error as e:
31653171
reason = get_ssl_error_reason(e)
3166-
if reason == "bad write retry":
3167-
print(f"Got expected SSL error on retry: {e}")
3168-
return True
3169-
else:
3170-
print(f"Got unexpected SSL error on retry: {e}")
3171-
return False
3172-
#print(f"Got expected SSL error on retry: {e} {written} bytes")
3173-
3174-
except SSL.Error as e:
3175-
reason = get_ssl_error_reason(e)
3176-
pytest.fail(f"Got unexpected SSL error on retry: {e} {reason}")
3177-
except Exception as e:
3178-
pytest.fail(f"Unexpected exception during send: {e}")
3172+
pytest.fail(f"Got unexpected SSL error on retry: {e} {reason}")
3173+
except Exception as e:
3174+
pytest.fail(f"Unexpected exception during send: {e}")
31793175

31803176
finally:
31813177
# Cleanup: shut down SSL connections and close raw sockets
@@ -3191,6 +3187,8 @@ def _badwriteretry(self, mode) -> bool:
31913187
client_socket.close()
31923188
if server_socket:
31933189
server_socket.close()
3190+
3191+
return result # Return the result after cleanup
31943192

31953193
def test_moving_write_buffer_should_pass(self) -> None:
31963194
"""
@@ -3491,7 +3489,7 @@ def test_short_memoryview(self) -> None:
34913489
of bytes sent.
34923490
"""
34933491
server, client = loopback()
3494-
count = server.send(memoryview(b"xy"))
3492+
count = server.send(memoryview(b"xy")) # type: ignore[arg-type]
34953493
assert count == 2
34963494
assert client.recv(2) == b"xy"
34973495

@@ -3501,7 +3499,7 @@ def test_short_bytearray(self) -> None:
35013499
it and returns the number of bytes sent.
35023500
"""
35033501
server, client = loopback()
3504-
count = server.send(bytearray(b"xy"))
3502+
count = server.send(bytearray(b"xy")) # type: ignore[arg-type]
35053503
assert count == 2
35063504
assert client.recv(2) == b"xy"
35073505

@@ -3703,7 +3701,7 @@ def test_short_memoryview(self) -> None:
37033701
`Connection.sendall` transmits all of them.
37043702
"""
37053703
server, client = loopback()
3706-
server.sendall(memoryview(b"x"))
3704+
server.sendall(memoryview(b"x")) # type: ignore[arg-type]
37073705
assert client.recv(1) == b"x"
37083706

37093707
def test_long(self) -> None:

0 commit comments

Comments
 (0)