Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,16 @@ def patch_extract_attributes(self, attributes: _AttributeMapT):
if queue_url:
attributes[AWS_SQS_QUEUE_URL] = queue_url

old_on_success = _SqsExtension.on_success

def patch_on_success(self, span: Span, result: _BotoResultT):
old_on_success(self, span, result)
queue_url = result.get("QueueUrl")
if queue_url:
span.set_attribute(AWS_SQS_QUEUE_URL, queue_url)

_SqsExtension.extract_attributes = patch_extract_attributes
_SqsExtension.on_success = patch_on_success


def _apply_botocore_bedrock_patch() -> None:
Expand Down
139 changes: 138 additions & 1 deletion contract-tests/images/applications/botocore/botocore_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import requests
from botocore.client import BaseClient
from botocore.config import Config
from botocore.exceptions import ClientError
from typing_extensions import Tuple, override

_PORT: int = 8080
Expand Down Expand Up @@ -45,6 +46,10 @@ def do_GET(self):
self._handle_kinesis_request()
if self.in_path("bedrock"):
self._handle_bedrock_request()
if self.in_path("secretsmanager"):
self._handle_secretsmanager_request()
if self.in_path("stepfunctions"):
self._handle_stepfunctions_request()

self._end_request(self.main_status)

Expand Down Expand Up @@ -246,7 +251,11 @@ def _handle_bedrock_request(self) -> None:
set_main_status(200)
bedrock_client.meta.events.register(
"before-call.bedrock.GetGuardrail",
lambda **kwargs: inject_200_success(guardrailId="bt4o77i015cu", **kwargs),
lambda **kwargs: inject_200_success(
guardrailId="bt4o77i015cu",
guardrailArn="arn:aws:bedrock:us-east-1:000000000000:guardrail/bt4o77i015cu",
**kwargs,
),
)
bedrock_client.get_guardrail(
guardrailIdentifier="arn:aws:bedrock:us-east-1:000000000000:guardrail/bt4o77i015cu"
Expand Down Expand Up @@ -301,6 +310,69 @@ def _handle_bedrock_request(self) -> None:
else:
set_main_status(404)

def _handle_secretsmanager_request(self) -> None:
secretsmanager_client = boto3.client("secretsmanager", endpoint_url=_AWS_SDK_ENDPOINT, region_name=_AWS_REGION)
if self.in_path(_ERROR):
set_main_status(400)
try:
error_client = boto3.client("secretsmanager", endpoint_url=_ERROR_ENDPOINT, region_name=_AWS_REGION)
error_client.describe_secret(
SecretId="arn:aws:secretsmanager:us-west-2:000000000000:secret:unExistSecret"
)
except Exception as exception:
print("Expected exception occurred", exception)
elif self.in_path(_FAULT):
set_main_status(500)
try:
fault_client = boto3.client(
"secretsmanager", endpoint_url=_FAULT_ENDPOINT, region_name=_AWS_REGION, config=_NO_RETRY_CONFIG
)
fault_client.get_secret_value(
SecretId="arn:aws:secretsmanager:us-west-2:000000000000:secret:nonexistent-secret"
)
except Exception as exception:
print("Expected exception occurred", exception)
elif self.in_path("describesecret/my-secret"):
set_main_status(200)
secretsmanager_client.describe_secret(SecretId="testSecret")
else:
set_main_status(404)

def _handle_stepfunctions_request(self) -> None:
sfn_client = boto3.client("stepfunctions", endpoint_url=_AWS_SDK_ENDPOINT, region_name=_AWS_REGION)
if self.in_path(_ERROR):
set_main_status(400)
try:
error_client = boto3.client("stepfunctions", endpoint_url=_ERROR_ENDPOINT, region_name=_AWS_REGION)
error_client.describe_state_machine(
stateMachineArn="arn:aws:states:us-west-2:000000000000:stateMachine:unExistStateMachine"
)
except Exception as exception:
print("Expected exception occurred", exception)
elif self.in_path(_FAULT):
set_main_status(500)
try:
fault_client = boto3.client("stepfunctions", endpoint_url=_FAULT_ENDPOINT, region_name=_AWS_REGION)
fault_client.meta.events.register(
"before-call.stepfunctions.ListStateMachineVersions",
lambda **kwargs: inject_500_error("ListStateMachineVersions", **kwargs),
)
fault_client.list_state_machine_versions(
stateMachineArn="arn:aws:states:us-west-2:000000000000:stateMachine:invalid-state-machine",
)
except Exception as exception:
print("Expected exception occurred", exception)
elif self.in_path("describestatemachine/my-state-machine"):
set_main_status(200)
sfn_client.describe_state_machine(
stateMachineArn="arn:aws:states:us-west-2:000000000000:stateMachine:testStateMachine"
)
elif self.in_path("describeactivity/my-activity"):
set_main_status(200)
sfn_client.describe_activity(activityArn="arn:aws:states:us-west-2:000000000000:activity:testActivity")
else:
set_main_status(404)

def _end_request(self, status_code: int):
self.send_response_only(status_code)
self.end_headers()
Expand All @@ -310,6 +382,7 @@ def set_main_status(status: int) -> None:
RequestHandler.main_status = status


# pylint: disable=too-many-locals
def prepare_aws_server() -> None:
requests.Request(method="POST", url="http://localhost:4566/_localstack/state/reset")
try:
Expand Down Expand Up @@ -345,6 +418,57 @@ def prepare_aws_server() -> None:
# Set up Kinesis so tests can access a stream.
kinesis_client: BaseClient = boto3.client("kinesis", endpoint_url=_AWS_SDK_ENDPOINT, region_name=_AWS_REGION)
kinesis_client.create_stream(StreamName="test_stream", ShardCount=1)

# Set up Secrets Manager so tests can access a secret.
secretsmanager_client: BaseClient = boto3.client(
"secretsmanager", endpoint_url=_AWS_SDK_ENDPOINT, region_name=_AWS_REGION
)
secretsmanager_response = secretsmanager_client.list_secrets()
secret = next((s for s in secretsmanager_response["SecretList"] if s["Name"] == "testSecret"), None)
if not secret:
secretsmanager_client.create_secret(
Name="testSecret", SecretString="secretValue", Description="This is a test secret"
)

# Set up Step Functions so tests can access a state machine and activity.
sfn_client: BaseClient = boto3.client("stepfunctions", endpoint_url=_AWS_SDK_ENDPOINT, region_name=_AWS_REGION)
sfn_response = sfn_client.list_state_machines()
state_machine_name = "testStateMachine"
activity_name = "testActivity"
state_machine = next((st for st in sfn_response["stateMachines"] if st["name"] == state_machine_name), None)
if not state_machine:
# create state machine needs an iam role so we create it here
iam_client: BaseClient = boto3.client("iam", endpoint_url=_AWS_SDK_ENDPOINT, region_name=_AWS_REGION)
iam_role_name = "testRole"
iam_role_arn = None
trust_policy = {
"Version": "2012-10-17",
"Statement": [
{"Effect": "Allow", "Principal": {"Service": "states.amazonaws.com"}, "Action": "sts:AssumeRole"}
],
}
try:
iam_response = iam_client.create_role(
RoleName=iam_role_name, AssumeRolePolicyDocument=json.dumps(trust_policy)
)
iam_client.attach_role_policy(
RoleName=iam_role_name, PolicyArn="arn:aws:iam::aws:policy/AWSStepFunctionsFullAccess"
)
print(f"IAM Role '{iam_role_name}' create successfully.")
iam_role_arn = iam_response["Role"]["Arn"]
sfn_defintion = {
"Comment": "A simple sequential workflow",
"StartAt": "FirstState",
"States": {"FirstState": {"Type": "Pass", "Result": "Hello, World!", "End": True}},
}
definition_string = json.dumps(sfn_defintion)
sfn_client.create_state_machine(
name=state_machine_name, definition=definition_string, roleArn=iam_role_arn
)
sfn_client.create_activity(name=activity_name)
except Exception as exception:
print("Something went wrong with Step Functions setup", exception)

except Exception as exception:
print("Unexpected exception occurred", exception)

Expand All @@ -363,6 +487,9 @@ def inject_200_success(**kwargs):
guardrail_id = kwargs.get("guardrailId")
if guardrail_id is not None:
response_body["guardrailId"] = guardrail_id
guardrail_arn = kwargs.get("guardrailArn")
if guardrail_arn is not None:
response_body["guardrailArn"] = guardrail_arn

HTTPResponse = namedtuple("HTTPResponse", ["status_code", "headers", "body"])
headers = kwargs.get("headers", {})
Expand All @@ -371,6 +498,16 @@ def inject_200_success(**kwargs):
return http_response, response_body


def inject_500_error(api_name: str, **kwargs):
raise ClientError(
{
"Error": {"Code": "InternalServerError", "Message": "Internal Server Error"},
"ResponseMetadata": {"HTTPStatusCode": 500, "RequestId": "mock-request-id"},
},
api_name,
)


def main() -> None:
prepare_aws_server()
server_address: Tuple[str, int] = ("0.0.0.0", _PORT)
Expand Down
14 changes: 14 additions & 0 deletions contract-tests/tests/test/amazon/base/contract_test_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import re
import time
from logging import INFO, Logger, getLogger
from typing import Dict, List
Expand Down Expand Up @@ -171,6 +172,12 @@ def _assert_int_attribute(self, attributes_dict: Dict[str, AnyValue], key: str,
self.assertIsNotNone(actual_value)
self.assertEqual(expected_value, actual_value.int_value)

def _assert_match_attribute(self, attributes_dict: Dict[str, AnyValue], key: str, pattern: str) -> None:
self.assertIn(key, attributes_dict)
actual_value: AnyValue = attributes_dict[key]
self.assertIsNotNone(actual_value)
self.assertRegex(actual_value.string_value, pattern)

def check_sum(self, metric_name: str, actual_sum: float, expected_sum: float) -> None:
if metric_name is LATENCY_METRIC:
self.assertTrue(0 < actual_sum < expected_sum)
Expand Down Expand Up @@ -221,3 +228,10 @@ def _assert_metric_attributes(
self, resource_scope_metrics: List[ResourceScopeMetric], metric_name: str, expected_sum: int, **kwargs
):
self.fail("Tests must implement this function")

def _is_valid_regex(self, pattern: str) -> bool:
try:
re.compile(pattern)
return True
except re.error:
return False
Loading
Loading