Skip to content

Commit 2875b4f

Browse files
authored
Feat/aws client next (#961)
* feat: add AWS Service Next module for centralized error handling and pagination * refactor: update pagination check to use client's can_paginate method * refactor: use force_paginate for clarity in API call execution * feat: add AWS Identity Store Next module for user and group management * feat: add AWS schemas for user and group management * feat: implement performance comparison * chore: fmt * fix: remove return none on error for consistency * fix: simplify client logic * feat: add standardized integration response model and error handling functions * refactor: move error handling and response building to models.integrations for reuse * refactor: standardize return types to IntegrationResponse for AWS Identity Store functions * refactor: update module documentation for clarity * refactor: streamline AWS dev command * test: add comprehensive tests for AWS API call execution and pagination * feat: add FakePaginator and FakeClient classes for testing AWS API interactions * feat: add integration tests for AWS Identity Store operations * refactor: reorganize user and group tests for clarity
1 parent 5f653bb commit 2875b4f

File tree

9 files changed

+2767
-6
lines changed

9 files changed

+2767
-6
lines changed
Lines changed: 361 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,361 @@
1+
"""
2+
AWS Service Next Module
3+
4+
Provides a unified, simplified interface for making AWS API calls with consistent error handling, retry logic, and standardized responses.
5+
6+
Features:
7+
- Centralized error handling and configurable retry logic for all AWS API calls
8+
- Standardized IntegrationResponse model for all results and errors
9+
- Automatic pagination for supported list operations
10+
- Optional role assumption for cross-account access
11+
- Handles non-critical errors and throttling transparently
12+
13+
Usage:
14+
response = execute_aws_api_call(
15+
service_name="ec2",
16+
method="describe_instances",
17+
Filters=[{"Name": "instance-state-name", "Values": ["running"]}],
18+
)
19+
if response.success:
20+
data = response.data
21+
else:
22+
error = response.error
23+
24+
Notes:
25+
- All AWS API calls are executed with consistent error handling and response formatting.
26+
- Batch operations are supported for services with native batch APIs; others are handled via repeated calls.
27+
- All responses use the IntegrationResponse model for easy downstream processing.
28+
"""
29+
30+
import time
31+
from typing import Any, List, Optional, Callable, cast
32+
33+
import boto3 # type: ignore
34+
from botocore.client import BaseClient # type: ignore
35+
from botocore.exceptions import BotoCoreError, ClientError # type: ignore
36+
from core.config import settings
37+
from core.logging import get_module_logger
38+
from models.integrations import (
39+
IntegrationResponse,
40+
build_success_response,
41+
build_error_response,
42+
)
43+
44+
# Module-level logger used by every helper in this file.
logger = get_module_logger()

# AWS values read once from the application settings at import time.
AWS_REGION = settings.aws.AWS_REGION
THROTTLING_ERRS = settings.aws.THROTTLING_ERRS
RESOURCE_NOT_FOUND_ERRS = settings.aws.RESOURCE_NOT_FOUND_ERRS

# Central error-handling configuration consulted by the retry/error helpers.
ERROR_CONFIG = {
    # Per-function lists of lowercase message fragments; an error whose
    # message contains one of them is logged as a warning, not an error.
    "non_critical_errors": {
        "get_user": ["not found", "timed out"],
        "describe_user": ["not found", "user not found"],
        "get_group": ["not found", "group not found"],
        "describe_group": ["not found", "group not found"],
        "get_role": ["not found", "role not found"],
        "describe_role": ["not found", "role not found"],
    },
    # AWS error codes that make a failed call eligible for another attempt.
    "retry_errors": [
        "Throttling",
        "RequestLimitExceeded",
        "ProvisionedThroughputExceededException",
    ],
    "rate_limit_delay": 5,
    # Number of retries performed after the first attempt, unless overridden.
    "default_max_retries": 3,
    # Base multiplier of the exponential backoff delay.
    "default_backoff_factor": 1.0,
}
68+
69+
70+
class AWSAPIError(Exception):
    """Exception describing a failed AWS API call.

    Attributes:
        message: Description of the failure; also used as the exception text.
        error_code: Optional error code associated with the failure.
        function_name: Optional name of the function in which it occurred.
    """

    def __init__(
        self,
        message: str,
        error_code: Optional[str] = None,
        function_name: Optional[str] = None,
    ):
        # Initialise the base Exception first so str(exc) yields the message.
        super().__init__(message)
        self.function_name = function_name
        self.error_code = error_code
        self.message = message
83+
84+
85+
def _should_retry(error: Exception, attempt: int, max_attempts: int) -> bool:
    """Tell whether a failed call should be attempted again.

    Args:
        error: The exception raised by the call.
        attempt: Zero-based index of the attempt that just failed.
        max_attempts: Upper bound (exclusive) on the attempt index that may
            still be retried.

    Returns:
        bool: True when the error carries a code listed under
        ``ERROR_CONFIG["retry_errors"]`` and ``attempt < max_attempts``.
    """
    # Only errors exposing a ``response`` mapping can carry an error code.
    code = None
    if hasattr(error, "response"):
        code = getattr(error, "response", {}).get("Error", {}).get("Code")

    retryable_codes = ERROR_CONFIG.get("retry_errors", [])
    if not isinstance(retryable_codes, (list, set, tuple)):
        # A malformed configuration value means nothing is retryable.
        retryable_codes = []

    if code not in retryable_codes:
        return False
    return attempt < max_attempts
95+
96+
97+
def _calculate_retry_delay(attempt: int) -> float:
    """Compute the exponential backoff delay for a given attempt.

    Args:
        attempt: Zero-based index of the attempt that just failed.

    Returns:
        float: ``backoff_factor * 2**attempt``, where the factor comes from
        ``ERROR_CONFIG["default_backoff_factor"]`` and falls back to 0.5 when
        the key is missing or its value cannot be converted to a float.
    """
    fallback_factor = 0.5
    configured = ERROR_CONFIG.get("default_backoff_factor", fallback_factor)
    try:
        # Cast to Any so type checkers accept passing the value to float().
        factor = float(cast(Any, configured))
    except (TypeError, ValueError):
        factor = fallback_factor
    return factor * (2**attempt)
105+
106+
107+
def _handle_final_error(
    error: Exception,
    function_name: str,
) -> IntegrationResponse:
    """Log a terminal failure and convert it into an error response.

    Called once no further retry will be made. Errors whose message contains
    one of the fragments configured under
    ``ERROR_CONFIG["non_critical_errors"]`` are logged as warnings; all other
    errors are logged as errors. Either way an error response is returned.

    A configuration key applies when it equals ``function_name`` or when
    ``function_name`` ends with ``"_<key>"``. The suffix form is needed because
    ``execute_aws_api_call`` names calls ``"<service>_<method>"`` (for example
    ``"identitystore_describe_user"``), while the configuration is keyed by
    the bare method name (``"describe_user"``).

    Args:
        error: The exception that ended the call.
        function_name: Name of the failing call, used for logging and for the
            non-critical lookup.

    Returns:
        IntegrationResponse: The result of ``build_error_response``.
    """
    error_message = str(error).lower()

    configured = (
        ERROR_CONFIG.get("non_critical_errors")
        if isinstance(ERROR_CONFIG, dict)
        else None
    )

    # Gather every message fragment whose configuration key applies.
    fragments: List[str] = []
    if isinstance(configured, dict):
        for key, key_fragments in configured.items():
            if not isinstance(key, str):
                continue
            if not isinstance(key_fragments, (list, tuple, set)):
                continue
            if function_name == key or function_name.endswith(f"_{key}"):
                fragments.extend(
                    item for item in key_fragments if isinstance(item, str)
                )

    # Compare case-insensitively: the message was lowercased above.
    is_non_critical = any(item.lower() in error_message for item in fragments)

    # Extract the AWS error code defensively so that a malformed ``response``
    # attribute cannot raise while an error is already being handled.
    response = getattr(error, "response", None)
    error_code = None
    if isinstance(response, dict):
        details = response.get("Error")
        if isinstance(details, dict):
            error_code = details.get("Code")

    if is_non_critical:
        logger.warning(
            "aws_api_non_critical_error",
            function=function_name,
            error=str(error),
            error_code=error_code,
        )
    else:
        logger.error(
            "aws_api_error_final",
            function=function_name,
            error=str(error),
            error_code=error_code,
        )

    return build_error_response(error, function_name, "aws")
150+
151+
152+
def _can_paginate_method(client: BaseClient, method: str) -> bool:
    """
    Ask the client whether one of its methods supports pagination.

    Args:
        client (BaseClient): The AWS service client
        method (str): The AWS API method name

    Returns:
        bool: The result of ``client.can_paginate(method)``, or False when
            that check raises AttributeError, TypeError or ValueError.
    """
    try:
        pageable = client.can_paginate(method)
    except (AttributeError, TypeError, ValueError):
        # The check itself failed, so treat the method as not pageable.
        pageable = False
    return pageable
168+
169+
170+
def get_aws_client(
    service_name: str,
    session_config: Optional[dict] = None,
    client_config: Optional[dict] = None,
    role_arn: Optional[str] = None,
    session_name: str = "DefaultSession",
) -> BaseClient:
    """
    Build a boto3 client for a service, assuming an IAM role first if asked.

    Args:
        service_name (str): The name of the AWS service.
        session_config (dict, optional): Keyword arguments for boto3.Session.
            Defaults to the module's AWS_REGION as ``region_name``.
        client_config (dict, optional): Keyword arguments for session.client.
            Defaults to the module's AWS_REGION as ``region_name``.
        role_arn (str, optional): The ARN of the IAM role to assume.
        session_name (str): The name for the assumed role session.

    Returns:
        BaseClient: The client created from the resulting session.
    """
    session_kwargs = session_config or {"region_name": AWS_REGION}
    client_kwargs = client_config or {"region_name": AWS_REGION}

    # Without a role, the session is built from the session settings alone.
    if not role_arn:
        return boto3.Session(**session_kwargs).client(service_name, **client_kwargs)

    # With a role, obtain temporary credentials from STS and build the
    # session from them.
    sts_client = boto3.client("sts")
    assume_result = sts_client.assume_role(
        RoleArn=role_arn, RoleSessionName=session_name
    )
    temp_credentials = assume_result["Credentials"]
    assumed_session = boto3.Session(
        aws_access_key_id=temp_credentials["AccessKeyId"],
        aws_secret_access_key=temp_credentials["SecretAccessKey"],
        aws_session_token=temp_credentials["SessionToken"],
        **session_kwargs,
    )
    return assumed_session.client(service_name, **client_kwargs)
204+
205+
206+
def _paginate_all_results(
    client: BaseClient, method: str, keys: Optional[List[str]] = None, **kwargs
) -> List[dict]:
    """Walk every page of a paginated call and flatten the results.

    For each page, the selected values are merged into a single list: list
    values contribute their items, any other value is added as one item.
    Both selection modes share this rule, so a non-list value requested
    through ``keys`` is kept whole rather than being iterated over.

    Args:
        client: The AWS service client providing the paginator.
        method: The paginated API method name.
        keys: Page keys to collect. When None, every page key except
            ``"ResponseMetadata"`` is collected.
        **kwargs: Arguments forwarded to ``paginator.paginate``.

    Returns:
        list: The collected items from all pages, in page order.
    """
    # NOTE(review): with keys=None every non-metadata field of a page is
    # collected, which would include scalar fields such as continuation
    # tokens if the service returns them in the page - pass ``keys`` to
    # restrict the output to the wanted collections.
    paginator = client.get_paginator(method)
    results: List[dict] = []
    for page in paginator.paginate(**kwargs):
        if keys is None:
            selected = [
                value for key, value in page.items() if key != "ResponseMetadata"
            ]
        else:
            selected = [page[key] for key in keys if key in page]

        for value in selected:
            if isinstance(value, list):
                results.extend(value)
            else:
                results.append(value)
    return results
224+
225+
226+
def execute_api_call(
    func_name: str,
    api_call: Callable[[], Any],
    max_retries: Optional[int] = None,
) -> IntegrationResponse:
    """
    Run an API call with centralized retry and error handling.

    The call is attempted up to ``max_retries + 1`` times. A BotoCoreError or
    ClientError accepted by ``_should_retry`` triggers a sleep of
    ``_calculate_retry_delay(attempt)`` seconds followed by another attempt;
    any other failure ends the loop immediately.

    Args:
        func_name (str): Name of the calling function for logging
        api_call (callable): The API call to execute
        max_retries (int): Override default max retries

    Returns:
        IntegrationResponse: A success response wrapping the call's result,
            or an error response built by ``_handle_final_error``.
    """
    if max_retries is None:
        retry_budget = cast(int, ERROR_CONFIG.get("default_max_retries", 3))
    else:
        retry_budget = max_retries
    total_attempts = retry_budget + 1
    last_exception: Optional[Exception] = None

    for attempt in range(total_attempts):
        attempt_number = attempt + 1  # one-based, for log readability
        try:
            logger.debug(
                "aws_api_call_start",
                function=func_name,
                attempt=attempt_number,
                max_attempts=total_attempts,
            )
            result = api_call()
            if attempt > 0:
                logger.info(
                    "aws_api_retry_success",
                    function=func_name,
                    attempt=attempt_number,
                )
            return build_success_response(result, func_name, "aws")
        except Exception as exc:  # pylint: disable=broad-except
            last_exception = exc

            # Only AWS SDK errors are candidates for a retry; everything else
            # is reported straight away.
            is_sdk_error = isinstance(exc, (BotoCoreError, ClientError))
            if is_sdk_error and _should_retry(exc, attempt, retry_budget):
                delay = _calculate_retry_delay(attempt)
                logger.warning(
                    "aws_api_retrying",
                    function=func_name,
                    attempt=attempt_number,
                    error=str(exc),
                    delay=delay,
                )
                time.sleep(delay)
                continue

            return _handle_final_error(exc, func_name)

    # Reached only when the loop ends without returning; report the last
    # captured exception, or a generic one if none was captured.
    if last_exception is None:
        last_exception = Exception("Unknown error after retries")
    return _handle_final_error(last_exception, func_name)
306+
307+
308+
def execute_aws_api_call(
    service_name: str,
    method: str,
    keys: Optional[List[str]] = None,
    role_arn: Optional[str] = None,
    session_config: Optional[dict] = None,
    client_config: Optional[dict] = None,
    max_retries: Optional[int] = None,
    force_paginate: bool = False,
    **kwargs,
) -> IntegrationResponse:
    """
    Call an AWS service method through the module-level error handling.

    A client is created for every attempt. The call is paginated when
    ``force_paginate`` is True or when the client reports the method as
    pageable; otherwise the method is invoked directly.

    Args:
        service_name (str): The name of the AWS service.
        method (str): The method to call on the service.
        keys (list, optional): The keys to extract from paginated results.
        role_arn (str, optional): The ARN of the IAM role to assume.
        session_config (dict, optional): Session configuration.
        client_config (dict, optional): Client configuration.
        max_retries (int, optional): Override default max retries.
        force_paginate (bool, optional): If True, paginate even when the
            client does not report the method as pageable.
        **kwargs: Additional keyword arguments for the API call.

    Returns:
        IntegrationResponse: Standardized response model for external API integrations.
    """

    def _invoke():
        client = get_aws_client(service_name, session_config, client_config, role_arn)
        # Resolve the method up front so an unknown name fails the same way
        # whichever path is taken below.
        direct_call = getattr(client, method)

        if force_paginate or _can_paginate_method(client, method):
            return _paginate_all_results(client, method, keys, **kwargs)
        return direct_call(**kwargs)

    # The logged function name combines the service and the method.
    return execute_api_call(
        f"{service_name}_{method}",
        _invoke,
        max_retries=max_retries,
    )

0 commit comments

Comments
 (0)