11import asyncio
22import pytest
3+ from typing import Any
4+ import contextlib
35
46from toller import CircuitBreaker , CircuitState , OpenCircuitError
57
6- pytestmark = pytest .mark .asyncio
7-
88
99class MockException (Exception ):
1010 """Custom exception for testing."""
@@ -18,26 +18,26 @@ class AnotherMockException(Exception):
1818 pass
1919
2020
21- async def mock_successful_call ():
21+ async def mock_successful_call () -> str :
2222 """Simulates a call that succeeds."""
2323 await asyncio .sleep (0.01 )
2424 return "Success"
2525
2626
27- async def mock_failing_call (exception_type = MockException ):
27+ async def mock_failing_call (exception_type : Any = MockException ) -> None :
2828 """Simulates a call that fails."""
2929 await asyncio .sleep (0.01 )
3030 raise exception_type ("Operation failed" )
3131
3232
33- async def test_initial_state_is_closed ():
33+ async def test_initial_state_is_closed () -> None :
3434 """Test that the initial state of the circuit breaker is CLOSED."""
3535 breaker = CircuitBreaker ()
3636 assert breaker .state == CircuitState .CLOSED
3737 assert breaker .current_failures == 0
3838
3939
40- async def test_successful_calls_keep_closed ():
40+ async def test_successful_calls_keep_closed () -> None :
4141 """Test that successful calls keep the circuit breaker in CLOSED state."""
4242 breaker = CircuitBreaker ()
4343 for _ in range (5 ):
@@ -47,8 +47,8 @@ async def test_successful_calls_keep_closed():
4747 assert breaker .current_failures == 0
4848
4949
50- async def test_failures_increment_count ():
51- """Failures should increment the failure count but not open the circuit if below threshold."""
50+ async def test_failures_increment_count () -> None :
51+ """Increment failure count but don't open the circuit if below threshold."""
5252 breaker = CircuitBreaker (failure_threshold = 5 )
5353 for i in range (3 ):
5454 with pytest .raises (MockException ):
@@ -58,7 +58,7 @@ async def test_failures_increment_count():
5858 assert breaker .state == CircuitState .CLOSED
5959
6060
61- async def test_failure_threshold_opens_circuit ():
61+ async def test_failure_threshold_opens_circuit () -> None :
6262 """Test that reaching the failure threshold opens the circuit."""
6363 threshold = 3
6464 breaker = CircuitBreaker (failure_threshold = threshold )
@@ -70,7 +70,7 @@ async def test_failure_threshold_opens_circuit():
7070 assert breaker .current_failures == threshold
7171
7272
73- async def test_open_circuit_blocks_calls ():
73+ async def test_open_circuit_blocks_calls () -> None :
7474 """Test that an open circuit blocks calls immediately."""
7575 breaker = CircuitBreaker (failure_threshold = 1 )
7676 with pytest .raises (MockException ):
@@ -83,7 +83,7 @@ async def test_open_circuit_blocks_calls():
8383 await mock_failing_call ()
8484
8585
86- async def test_recovery_timeout_moves_to_half_open ():
86+ async def test_recovery_timeout_moves_to_half_open () -> None :
8787 """Test that after the recovery timeout, the circuit moves to HALF_OPEN state."""
8888 recovery_time = 0.1
8989 breaker = CircuitBreaker (failure_threshold = 1 , recovery_timeout = recovery_time )
@@ -106,7 +106,7 @@ async def test_recovery_timeout_moves_to_half_open():
106106 assert breaker .state == CircuitState .OPEN
107107
108108
109- async def test_success_in_half_open_closes_circuit ():
109+ async def test_success_in_half_open_closes_circuit () -> None :
110110 """Test that a successful call in HALF_OPEN state closes the circuit."""
111111 recovery_time = 0.1
112112 breaker = CircuitBreaker (failure_threshold = 1 , recovery_timeout = recovery_time )
@@ -125,7 +125,7 @@ async def test_success_in_half_open_closes_circuit():
125125 assert breaker .current_failures == 0
126126
127127
128- async def test_failure_in_half_open_reopens_circuit ():
128+ async def test_failure_in_half_open_reopens_circuit () -> None :
129129 """Test that a failure in HALF_OPEN state reopens the circuit."""
130130 recovery_time = 0.1
131131 breaker = CircuitBreaker (failure_threshold = 1 , recovery_timeout = recovery_time )
@@ -148,7 +148,7 @@ async def test_failure_in_half_open_reopens_circuit():
148148 await mock_failing_call ()
149149
150150
151- async def test_success_resets_failure_count_when_closed ():
151+ async def test_success_resets_failure_count_when_closed () -> None :
152152 """Test that a successful call resets the failure count when in CLOSED state."""
153153 breaker = CircuitBreaker (failure_threshold = 3 )
154154
@@ -172,7 +172,7 @@ async def test_success_resets_failure_count_when_closed():
172172 assert breaker .current_failures == 1
173173
174174
175- async def test_expected_exception_filtering ():
175+ async def test_expected_exception_filtering () -> None :
176176 """Test that only expected exceptions trip the circuit breaker."""
177177 breaker = CircuitBreaker (failure_threshold = 2 , expected_exception = MockException )
178178
@@ -194,7 +194,7 @@ async def test_expected_exception_filtering():
194194 assert breaker .state == CircuitState .OPEN
195195
196196
197- async def test_expected_exception_tuple ():
197+ async def test_expected_exception_tuple () -> None :
198198 """Test that a tuple of exceptions can be used to trip the circuit breaker."""
199199 breaker = CircuitBreaker (
200200 failure_threshold = 2 , expected_exception = (MockException , AnotherMockException )
@@ -212,18 +212,16 @@ async def test_expected_exception_tuple():
212212 assert breaker .state == CircuitState .OPEN
213213
214214
215- async def test_concurrent_failures_open_circuit ():
215+ async def test_concurrent_failures_open_circuit () -> None :
216216 """Test that concurrent failures can open the circuit."""
217217 threshold = 5
218218 breaker = CircuitBreaker (failure_threshold = threshold )
219219 num_concurrent = 10
220220
221- async def concurrent_task ():
222- try :
221+ async def concurrent_task () -> None :
222+ with contextlib . suppress ( MockException , OpenCircuitError ) :
223223 async with breaker :
224224 await mock_failing_call ()
225- except (MockException , OpenCircuitError ):
226- pass
227225
228226 tasks = [asyncio .create_task (concurrent_task ()) for _ in range (num_concurrent )]
229227 await asyncio .gather (* tasks )
@@ -232,7 +230,7 @@ async def concurrent_task():
232230 assert breaker .current_failures >= threshold
233231
234232
235- async def test_invalid_init_args ():
233+ async def test_invalid_init_args () -> None :
236234 with pytest .raises (ValueError ):
237235 CircuitBreaker (failure_threshold = 0 )
238236 with pytest .raises (ValueError ):
0 commit comments