Skip to content

Commit b53d557

Browse files
committed
- added support for additional gen ai attributes
- added contract tests for amazon titan and anthropic claude models - added support to compare float attributes
1 parent 7e70a46 commit b53d557

File tree

3 files changed

+77
-16
lines changed

3 files changed

+77
-16
lines changed

contract-tests/images/applications/aws-sdk/server.js

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -554,19 +554,39 @@ async function handleBedrockRequest(req, res, path) {
554554
res.statusCode = 200;
555555
} else if (path.includes('invokemodel/invoke-model')) {
556556
await withInjected200Success(bedrockRuntimeClient, ['InvokeModelCommand'], {}, async () => {
557-
const modelId = 'amazon.titan-text-premier-v1:0';
557+
let modelId = ''
558+
let body = {}
558559
const userMessage = "Describe the purpose of a 'hello world' program in one line.";
559560
const prompt = `<s>[INST] ${userMessage} [/INST]`;
560561

561-
const body = JSON.stringify({
562-
inputText: prompt,
563-
textGenerationConfig: {
564-
maxTokenCount: 3072,
565-
stopSequences: [],
566-
temperature: 0.7,
567-
topP: 0.9,
568-
},
569-
});
562+
if (path.includes('amazon.titan')) {
563+
modelId = 'amazon.titan-text-premier-v1:0';
564+
body = JSON.stringify({
565+
inputText: prompt,
566+
textGenerationConfig: {
567+
maxTokenCount: 3072,
568+
stopSequences: [],
569+
temperature: 0.7,
570+
topP: 0.9,
571+
},
572+
});
573+
}
574+
575+
if (path.includes('anthropic.claude')) {
576+
modelId = 'anthropic.claude-v2:1';
577+
body = JSON.stringify({
578+
anthropic_version: 'bedrock-2023-05-31',
579+
max_tokens: 1000,
580+
temperature: 1.1,
581+
top_p: 1,
582+
messages: [
583+
{
584+
role: 'user',
585+
content: [{ type: 'text', text: prompt }],
586+
},
587+
],
588+
});
589+
}
570590

571591
await bedrockRuntimeClient.send(
572592
new InvokeModelCommand({
@@ -577,6 +597,7 @@ async function handleBedrockRequest(req, res, path) {
577597
})
578598
);
579599
});
600+
580601
res.statusCode = 200;
581602
} else {
582603
res.statusCode = 404;

contract-tests/tests/test/amazon/aws-sdk/aws_sdk_test.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@
3434
_AWS_BEDROCK_KNOWLEDGE_BASE_ID: str = "aws.bedrock.knowledge_base.id"
3535
_AWS_BEDROCK_DATA_SOURCE_ID: str = "aws.bedrock.data_source.id"
3636
_GEN_AI_REQUEST_MODEL: str = "gen_ai.request.model"
37+
_GEN_AI_REQUEST_TEMPERATURE: str = "gen_ai.request.temperature"
38+
_GEN_AI_REQUEST_TOP_P: str = "gen_ai.request.top_p"
39+
_GEN_AI_REQUEST_MAX_TOKENS: str = "gen_ai.request.max_tokens"
3740

3841

3942
# pylint: disable=too-many-public-methods
@@ -406,9 +409,9 @@ def test_kinesis_fault(self):
406409
span_name="Kinesis.PutRecord",
407410
)
408411

409-
def test_bedrock_runtime_invoke_model(self):
412+
def test_bedrock_runtime_invoke_model_amazon_titan(self):
410413
self.do_test_requests(
411-
"bedrock/invokemodel/invoke-model",
414+
"bedrock/invokemodel/invoke-model/amazon.titan-text-premier-v1:0",
412415
"GET",
413416
200,
414417
0,
@@ -418,11 +421,36 @@ def test_bedrock_runtime_invoke_model(self):
418421
remote_service="AWS::BedrockRuntime",
419422
remote_operation="InvokeModel",
420423
remote_resource_type="AWS::Bedrock::Model",
421-
remote_resource_identifier="amazon.titan-text-premier-v1:0",
424+
remote_resource_identifier='amazon.titan-text-premier-v1:0',
422425
request_specific_attributes={
423-
_GEN_AI_REQUEST_MODEL: "amazon.titan-text-premier-v1:0",
424-
},
425-
span_name="BedrockRuntime.InvokeModel",
426+
_GEN_AI_REQUEST_MODEL: 'amazon.titan-text-premier-v1:0',
427+
_GEN_AI_REQUEST_MAX_TOKENS: 3072,
428+
_GEN_AI_REQUEST_TEMPERATURE: 0.7,
429+
_GEN_AI_REQUEST_TOP_P: 0.9
430+
},
431+
span_name="BedrockRuntime.InvokeModel"
432+
)
433+
434+
def test_bedrock_runtime_invoke_model_anthropic_claude(self):
435+
self.do_test_requests(
436+
"bedrock/invokemodel/invoke-model/anthropic.claude-v2:1",
437+
"GET",
438+
200,
439+
0,
440+
0,
441+
local_operation="GET /bedrock",
442+
rpc_service="BedrockRuntime",
443+
remote_service="AWS::BedrockRuntime",
444+
remote_operation="InvokeModel",
445+
remote_resource_type="AWS::Bedrock::Model",
446+
remote_resource_identifier='anthropic.claude-v2:1',
447+
request_specific_attributes={
448+
_GEN_AI_REQUEST_MODEL: 'anthropic.claude-v2:1',
449+
_GEN_AI_REQUEST_MAX_TOKENS: 1000,
450+
_GEN_AI_REQUEST_TEMPERATURE: 1.1,
451+
_GEN_AI_REQUEST_TOP_P: 1
452+
},
453+
span_name="BedrockRuntime.InvokeModel"
426454
)
427455

428456
def test_bedrock_get_guardrail(self):
@@ -538,6 +566,9 @@ def test_bedrock_agent_get_data_source(self):
538566
},
539567
span_name="BedrockAgent.GetDataSource",
540568
)
569+
570+
# def test_bedrock_agent_runtime_invoke_agent(self):
571+
# return None
541572

542573
@override
543574
def _assert_aws_span_attributes(self, resource_scope_spans: List[ResourceScopeSpan], path: str, **kwargs) -> None:
@@ -591,6 +622,7 @@ def _assert_aws_attributes(
591622
def _assert_semantic_conventions_span_attributes(
592623
self, resource_scope_spans: List[ResourceScopeSpan], method: str, path: str, status_code: int, **kwargs
593624
) -> None:
625+
594626
target_spans: List[Span] = []
595627
for resource_scope_span in resource_scope_spans:
596628
# pylint: disable=no-member
@@ -629,6 +661,8 @@ def _assert_semantic_conventions_attributes(
629661
self._assert_str_attribute(attributes_dict, key, value)
630662
elif isinstance(value, int):
631663
self._assert_int_attribute(attributes_dict, key, value)
664+
elif isinstance(value, float):
665+
self._assert_float_attribute(attributes_dict, key, value)
632666
else:
633667
self._assert_array_value_ddb_table_name(attributes_dict, key, value)
634668

contract-tests/tests/test/amazon/base/contract_test_base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,12 @@ def _assert_int_attribute(self, attributes_dict: Dict[str, AnyValue], key: str,
230230
self.assertIsNotNone(actual_value)
231231
self.assertEqual(expected_value, actual_value.int_value)
232232

233+
def _assert_float_attribute(self, attributes_dict: Dict[str, AnyValue], key: str, expected_value: float) -> None:
234+
self.assertIn(key, attributes_dict)
235+
actual_value: AnyValue = attributes_dict[key]
236+
self.assertIsNotNone(actual_value)
237+
self.assertEqual(expected_value, actual_value.double_value)
238+
233239
def check_sum(self, metric_name: str, actual_sum: float, expected_sum: float) -> None:
234240
if metric_name is LATENCY_METRIC:
235241
self.assertTrue(0 < actual_sum < expected_sum)

0 commit comments

Comments
 (0)