Skip to content

Commit 7a8d735

Browse files
committed
add short circuit on service throttling and apply much needed typing
1 parent d400e86 commit 7a8d735

File tree

3 files changed

+549
-0
lines changed

3 files changed

+549
-0
lines changed

src/layer_utils/circuit_state.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
"""
2+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
# SPDX-License-Identifier: MIT-0
4+
5+
Circuit breaker pattern implementation for AWS API calls
6+
"""
7+
import time
8+
import logging
9+
from threading import Lock
10+
from functools import wraps
11+
12+
# Set up logging
13+
logger = logging.getLogger()
14+
logger.setLevel("INFO")
15+
16+
# Circuit breaker state storage
17+
_circuit_states : dict = {}
18+
_circuit_lock : Lock = Lock()
19+
20+
class CircuitState:
21+
"""
22+
Represents the state of a circuit breaker for a specific operation.
23+
24+
Attributes:
25+
is_open (bool): Whether the circuit is currently open (failing fast)
26+
failure_count (int): Number of consecutive failures
27+
last_failure_time (float): Timestamp of the last failure
28+
threshold (int): Number of failures before opening the circuit
29+
timeout (int): Seconds to wait before trying again (half-open state)
30+
"""
31+
def __init__(self):
32+
self.is_open : bool = False
33+
self.failure_count : int = 0
34+
self.last_failure_time : float = 0.0
35+
self.threshold : int = 5 # Number of failures before opening
36+
self.timeout : int = 60 # Seconds to wait before trying again
37+
38+
def circuit_is_open(operation_name: str):
39+
"""
40+
Check if the circuit is open for a specific operation
41+
42+
Args:
43+
operation_name (str): The name of the operation to check
44+
45+
Returns:
46+
bool: True if the circuit is open and requests should fail fast
47+
"""
48+
with _circuit_lock:
49+
if operation_name not in _circuit_states:
50+
_circuit_states[operation_name] = CircuitState()
51+
52+
circuit = _circuit_states[operation_name]
53+
54+
# If circuit is open, check if timeout has elapsed to try again
55+
if circuit.is_open is True:
56+
if time.time() - circuit.last_failure_time > circuit.timeout:
57+
logger.info("Circuit for %s moving to half-open state", operation_name)
58+
circuit.is_open = False
59+
return False
60+
return True
61+
return False
62+
63+
# Move to half-open state by allowing one request through
64+
65+
def record_failure(operation_name: str):
66+
"""
67+
Record a failure for the circuit breaker
68+
69+
Args:
70+
operation_name (str): The name of the operation that failed
71+
"""
72+
with _circuit_lock:
73+
if operation_name not in _circuit_states:
74+
_circuit_states[operation_name] = CircuitState()
75+
76+
circuit = _circuit_states[operation_name]
77+
circuit.failure_count += 1
78+
circuit.last_failure_time = time.time()
79+
80+
# Open the circuit if we exceed the threshold
81+
if circuit.failure_count >= circuit.threshold:
82+
circuit.is_open = True
83+
logger.warning("Circuit breaker opened for %s after %i failures",
84+
operation_name, circuit.threshold)
85+
86+
def reset_circuit(operation_name):
87+
"""
88+
Reset the circuit after a successful operation
89+
90+
Args:
91+
operation_name (str): The name of the operation that succeeded
92+
"""
93+
with _circuit_lock:
94+
if operation_name not in _circuit_states:
95+
return
96+
97+
circuit = _circuit_states[operation_name]
98+
if circuit.is_open or circuit.failure_count > 0:
99+
logger.info("Circuit for %s reset after successful operation", operation_name)
100+
circuit.is_open = False
101+
circuit.failure_count = 0
102+
103+
class CircuitOpenError(Exception):
104+
"""Exception raised when a circuit is open"""
105+
pass
106+
107+
def with_circuit_breaker(operation_name, fallback_function=None):
108+
"""
109+
Decorator to apply circuit breaker pattern to a function
110+
111+
Args:
112+
operation_name (str): Name of the operation for the circuit breaker
113+
fallback_function (callable, optional): Function to call when circuit is open
114+
115+
Example:
116+
@with_circuit_breaker('get_thing_group')
117+
def get_thing_group_arn(thing_group_name):
118+
# Implementation
119+
"""
120+
def decorator(func):
121+
@wraps(func)
122+
def wrapper(*args, **kwargs):
123+
if circuit_is_open(operation_name):
124+
logger.warning("Circuit breaker open for %s, failing fast", operation_name)
125+
if fallback_function:
126+
return fallback_function(*args, **kwargs)
127+
raise CircuitOpenError(f"Circuit breaker open for {operation_name}")
128+
129+
try:
130+
result = func(*args, **kwargs)
131+
reset_circuit(operation_name)
132+
return result
133+
except Exception as e:
134+
record_failure(operation_name)
135+
raise
136+
137+
return wrapper
138+
return decorator
139+
140+
def add_circuit(name: str) -> CircuitState:
141+
""" Create an arbitrary circuit """
142+
_circuit_states[name] = CircuitState()
143+
return _circuit_states[name]
144+
145+
def remove_circuit(name: str):
146+
""" Remove an arbitrary circuit """
147+
del _circuit_states[name]
148+
149+
def clear_circuits():
150+
""" Clears all circuits """
151+
_circuit_states.clear()
152+
153+
def has_circuit(name: str) -> bool:
154+
if name in _circuit_states:
155+
return True
156+
return False
157+
158+
def get_circuit(name: str) -> CircuitState:
159+
if has_circuit(name):
160+
return _circuit_states[name]
161+
return add_circuit(name)

test/conftest.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
"""
2+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
# SPDX-License-Identifier: MIT-0
4+
5+
Configure pytest environment
6+
"""
7+
import os
8+
import sys
9+
from pathlib import Path
10+
import pytest
11+
12+
# Add the project root directory to the Python path
13+
project_root = Path(__file__).parent.parent
14+
sys.path.insert(0, str(project_root))
15+
16+
# Add specific module directories to the Python path
17+
# This makes modules directly importable in tests without relative imports
18+
module_paths = [
19+
# Add paths to specific modules
20+
os.path.join(project_root, "src", "layer_utils"),
21+
os.path.join(project_root, "src", "bulk_importer"),
22+
os.path.join(project_root, "src", "product_provider"),
23+
os.path.join(project_root, "src", "provider_espressif"),
24+
os.path.join(project_root, "src", "provider_infineon", "provider_infineon"),
25+
os.path.join(project_root, "src", "provider_microchip", "provider_microchip"),
26+
]
27+
28+
# Add each module path to sys.path
29+
for path in module_paths:
30+
if path not in sys.path and os.path.exists(path):
31+
sys.path.insert(0, path)
32+
33+
# Create module aliases for compatibility
34+
try:
35+
# Create aliases for common modules
36+
# For example, make src.layer_utils.aws_utils available as just aws_utils
37+
import src.layer_utils.aws_utils
38+
sys.modules['aws_utils'] = sys.modules['src.layer_utils.aws_utils']
39+
40+
import src.layer_utils.cert_utils
41+
sys.modules['cert_utils'] = sys.modules['src.layer_utils.cert_utils']
42+
except ImportError:
43+
# Handle case where modules aren't found
44+
pass
45+
46+
# Reset circuit state before each test
47+
@pytest.fixture(autouse=True)
48+
def reset_circuit_state(request):
49+
"""Reset circuit state before each test"""
50+
# Skip for aws_utils tests
51+
if 'test_aws_utils' in request.node.name:
52+
yield
53+
return
54+
55+
# Reset circuit state for other tests
56+
try:
57+
from src.layer_utils.circuit_state import _circuit_states
58+
_circuit_states.clear()
59+
except ImportError:
60+
pass
61+
62+
yield

0 commit comments

Comments
 (0)