
Commit 4797a7d

added assertions for response and usage inference parameter attributes
1 parent 60eec28 commit 4797a7d

File tree

2 files changed: +174, -66 lines changed

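Note: this commit makes the mock Bedrock server return a canned InvokeModel response body for each supported model, and extends the contract tests to assert the `gen_ai.response.finish_reasons`, `gen_ai.usage.input_tokens`, and `gen_ai.usage.output_tokens` span attributes derived from those bodies. For `cohere.command` and `mistral.mistral`, whose canned bodies carry no token counts, the test expectations fall back to a rough six-characters-per-token estimate. A minimal sketch of that arithmetic (the `estimateTokens` helper name is illustrative, not part of the diff):

```js
// Rough ~6-characters-per-token estimate, mirrored by the Python tests
// as math.ceil(len(text) / 6). Assumed here to match the instrumentation's
// fallback when a model's response reports no usage counts.
const estimateTokens = (text) => Math.ceil(text.length / 6);

const prompt = "Describe the purpose of a 'hello world' program in one line.";
console.log(estimateTokens(prompt));             // expected gen_ai.usage.input_tokens
console.log(estimateTokens('test-output-text')); // expected gen_ai.usage.output_tokens
```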

contract-tests/images/applications/aws-sdk/server.js

Lines changed: 130 additions & 59 deletions
@@ -10,7 +10,6 @@ const { S3Client, CreateBucketCommand, PutObjectCommand, GetObjectCommand } = re
 const { DynamoDBClient, CreateTableCommand, PutItemCommand } = require('@aws-sdk/client-dynamodb');
 const { SQSClient, CreateQueueCommand, SendMessageCommand, ReceiveMessageCommand } = require('@aws-sdk/client-sqs');
 const { KinesisClient, CreateStreamCommand, PutRecordCommand } = require('@aws-sdk/client-kinesis');
-const fetch = require('node-fetch');
 const { BedrockClient, GetGuardrailCommand } = require('@aws-sdk/client-bedrock');
 const { BedrockAgentClient, GetKnowledgeBaseCommand, GetDataSourceCommand, GetAgentCommand } = require('@aws-sdk/client-bedrock-agent');
 const { BedrockRuntimeClient, InvokeModelCommand } = require('@aws-sdk/client-bedrock-runtime');
@@ -553,28 +552,44 @@ async function handleBedrockRequest(req, res, path) {
       });
       res.statusCode = 200;
     } else if (path.includes('invokemodel/invoke-model')) {
-      await withInjected200Success(bedrockRuntimeClient, ['InvokeModelCommand'], {}, async () => {
-        let modelId = ''
-        let body = {}
-        const userMessage = "Describe the purpose of a 'hello world' program in one line.";
-        const prompt = `<s>[INST] ${userMessage} [/INST]`;
-
-        if (path.includes('amazon.titan')) {
+      const get_model_request_response = function () {
+        const prompt = "Describe the purpose of a 'hello world' program in one line.";
+        let modelId = ''
+        let request_body = {}
+        let response_body = {}
+
+        if (path.includes('amazon.titan')) {
+
           modelId = 'amazon.titan-text-premier-v1:0';
-          body = JSON.stringify({
+
+          request_body = {
             inputText: prompt,
             textGenerationConfig: {
               maxTokenCount: 3072,
               stopSequences: [],
               temperature: 0.7,
               topP: 0.9,
             },
-          });
-        }
+          };
+
+          response_body = {
+            inputTextTokenCount: 15,
+            results: [
+              {
+                tokenCount: 13,
+                outputText: 'text-test-response',
+                completionReason: 'CONTENT_FILTERED',
+              },
+            ],
+          }
+
+        }
 
-        if (path.includes('anthropic.claude')) {
+        if (path.includes('anthropic.claude')) {
+
           modelId = 'anthropic.claude-v2:1';
-          body = JSON.stringify({
+
+          request_body = {
             anthropic_version: 'bedrock-2023-05-31',
             max_tokens: 1000,
             temperature: 0.99,
@@ -585,64 +600,120 @@ async function handleBedrockRequest(req, res, path) {
               content: [{ type: 'text', text: prompt }],
             },
           ],
-          });
-        }
-
-        if (path.includes('meta.llama')) {
-          modelId = 'meta.llama2-13b-chat-v1';
-          body = JSON.stringify({
-            prompt,
-            max_gen_len: 512,
-            temperature: 0.5,
-            top_p: 0.9
-          });
-        }
+          };
 
-        if (path.includes('cohere.command')) {
-          modelId = 'cohere.command-light-text-v14';
-          body = JSON.stringify({
-            prompt,
-            max_tokens: 512,
-            temperature: 0.5,
-            p: 0.65,
-          });
-        }
-
-        if (path.includes('ai21.jamba')) {
-          modelId = 'ai21.jamba-1-5-large-v1:0';
-          body = JSON.stringify({
-            messages: [
-              {
-                role: 'user',
-                content: prompt,
+          response_body = {
+            stop_reason: 'end_turn',
+            usage: {
+              input_tokens: 15,
+              output_tokens: 13,
             },
-            ],
-            top_p: 0.8,
-            temperature: 0.6,
-            max_tokens: 512,
-          });
-        }
-
-        if (path.includes('mistral.mistral')) {
-          modelId = 'mistral.mistral-7b-instruct-v0:2';
-          body = JSON.stringify({
-            prompt,
-            max_tokens: 4096,
-            temperature: 0.75,
-            top_p: 0.99,
-          });
+          }
+        }
+
+        if (path.includes('meta.llama')) {
+          modelId = 'meta.llama2-13b-chat-v1';
+
+          request_body = {
+            prompt,
+            max_gen_len: 512,
+            temperature: 0.5,
+            top_p: 0.9
+          };
+
+          response_body = {
+            prompt_token_count: 31,
+            generation_token_count: 49,
+            stop_reason: 'stop'
+          }
+        }
+
+        if (path.includes('cohere.command')) {
+          modelId = 'cohere.command-light-text-v14';
+
+          request_body = {
+            prompt,
+            max_tokens: 512,
+            temperature: 0.5,
+            p: 0.65,
+          };
+
+          response_body = {
+            generations: [
+              {
+                finish_reason: 'COMPLETE',
+                text: 'test-generation-text',
+              },
+            ],
+            prompt: prompt,
+          };
+        }
+
+        if (path.includes('ai21.jamba')) {
+          modelId = 'ai21.jamba-1-5-large-v1:0';
+
+          request_body = {
+            messages: [
+              {
+                role: 'user',
+                content: prompt,
+              },
+            ],
+            top_p: 0.8,
+            temperature: 0.6,
+            max_tokens: 512,
+          };
+
+          response_body = {
+            stop_reason: 'end_turn',
+            usage: {
+              prompt_tokens: 21,
+              completion_tokens: 24,
+            },
+            choices: [
+              {
+                finish_reason: 'stop',
+              },
+            ],
+          }
+        }
+
+        if (path.includes('mistral.mistral')) {
+          modelId = 'mistral.mistral-7b-instruct-v0:2';
+
+          request_body = {
+            prompt,
+            max_tokens: 4096,
+            temperature: 0.75,
+            top_p: 0.99,
+          };
+
+          response_body = {
+            outputs: [
+              {
+                text: 'test-output-text',
+                stop_reason: 'stop',
+              },
+            ]
+          }
+        }
+
+        return [modelId, JSON.stringify(request_body), new TextEncoder().encode(JSON.stringify(response_body))]
       }
+
+      const [modelId, request_body, response_body] = get_model_request_response();
 
+      await withInjected200Success(bedrockRuntimeClient, ['InvokeModelCommand'], { body: response_body }, async () => {
         await bedrockRuntimeClient.send(
           new InvokeModelCommand({
-            body: body,
+            body: request_body,
             modelId: modelId,
             accept: 'application/json',
             contentType: 'application/json',
           })
        );
       });
-
+
       res.statusCode = 200;
     } else {
       res.statusCode = 404;
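Two things to note in the server change above: request/response pairs are now built per model by `get_model_request_response`, and the canned body is handed to `withInjected200Success` as `{ body: response_body }` so the instrumentation under test has real fields to parse. The body is encoded with `TextEncoder` because a real InvokeModel response exposes its body as bytes. `withInjected200Success` itself is defined elsewhere in server.js and is not part of this diff; the sketch below is only a guess at its shape, inferred from its call sites and the AWS SDK v3 middleware API:

```js
// Hypothetical sketch -- NOT the helper's actual implementation. Its call
// sites suggest a deserialize-step middleware that skips the network round
// trip and resolves matching commands with a canned 200 output.
async function withInjected200SuccessSketch(client, commandNames, extraOutput, fn) {
  client.middlewareStack.add(
    (next, context) => async (args) => {
      if (!commandNames.includes(context.commandName)) return next(args);
      return {
        response: { statusCode: 200 },
        output: { $metadata: { httpStatusCode: 200 }, ...extraOutput },
      };
    },
    { step: 'deserialize', name: 'inject200Success', override: true }
  );
  try {
    await fn(); // e.g. client.send(new InvokeModelCommand({ ... }))
  } finally {
    client.middlewareStack.remove('inject200Success');
  }
}
```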

contract-tests/tests/test/amazon/aws-sdk/aws_sdk_test.py

Lines changed: 44 additions & 7 deletions
@@ -1,6 +1,7 @@
 # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 # SPDX-License-Identifier: Apache-2.0
 from logging import INFO, Logger, getLogger
+import math
 from typing import Dict, List
 
 from docker.types import EndpointConfig
@@ -37,7 +38,9 @@
 _GEN_AI_REQUEST_TEMPERATURE: str = "gen_ai.request.temperature"
 _GEN_AI_REQUEST_TOP_P: str = "gen_ai.request.top_p"
 _GEN_AI_REQUEST_MAX_TOKENS: str = "gen_ai.request.max_tokens"
-
+_GEN_AI_RESPONSE_FINISH_REASONS: str = "gen_ai.response.finish_reasons"
+_GEN_AI_USAGE_INPUT_TOKENS: str = 'gen_ai.usage.input_tokens'
+_GEN_AI_USAGE_OUTPUT_TOKENS: str = 'gen_ai.usage.output_tokens'
 
 # pylint: disable=too-many-public-methods
 class AWSSDKTest(ContractTestBase):
@@ -410,7 +413,7 @@ def test_kinesis_fault(self):
         )
 
     def test_bedrock_runtime_invoke_model_amazon_titan(self):
-        self.do_test_requests(
+        result = self.do_test_requests(
             "bedrock/invokemodel/invoke-model/amazon.titan-text-premier-v1:0",
             "GET",
             200,
@@ -428,9 +431,15 @@ def test_bedrock_runtime_invoke_model_amazon_titan(self):
                 _GEN_AI_REQUEST_TEMPERATURE: 0.7,
                 _GEN_AI_REQUEST_TOP_P: 0.9
             },
+            response_specific_attributes={
+                _GEN_AI_RESPONSE_FINISH_REASONS: ['CONTENT_FILTERED'],
+                _GEN_AI_USAGE_INPUT_TOKENS: 15,
+                _GEN_AI_USAGE_OUTPUT_TOKENS: 13
+            },
+
             span_name="BedrockRuntime.InvokeModel"
         )
-
+
     def test_bedrock_runtime_invoke_model_anthropic_claude(self):
         self.do_test_requests(
             "bedrock/invokemodel/invoke-model/anthropic.claude-v2:1",
@@ -450,6 +459,11 @@ def test_bedrock_runtime_invoke_model_anthropic_claude(self):
                 _GEN_AI_REQUEST_TEMPERATURE: 0.99,
                 _GEN_AI_REQUEST_TOP_P: 1
             },
+            response_specific_attributes={
+                _GEN_AI_RESPONSE_FINISH_REASONS: ['end_turn'],
+                _GEN_AI_USAGE_INPUT_TOKENS: 15,
+                _GEN_AI_USAGE_OUTPUT_TOKENS: 13
+            },
             span_name="BedrockRuntime.InvokeModel"
         )
 
@@ -472,6 +486,11 @@ def test_bedrock_runtime_invoke_model_meta_llama(self):
                 _GEN_AI_REQUEST_TEMPERATURE: 0.5,
                 _GEN_AI_REQUEST_TOP_P: 0.9
             },
+            response_specific_attributes={
+                _GEN_AI_RESPONSE_FINISH_REASONS: ['stop'],
+                _GEN_AI_USAGE_INPUT_TOKENS: 31,
+                _GEN_AI_USAGE_OUTPUT_TOKENS: 49
+            },
             span_name="BedrockRuntime.InvokeModel"
         )
 
@@ -494,6 +513,11 @@ def test_bedrock_runtime_invoke_model_cohere_command(self):
                 _GEN_AI_REQUEST_TEMPERATURE: 0.5,
                 _GEN_AI_REQUEST_TOP_P: 0.65
             },
+            response_specific_attributes={
+                _GEN_AI_RESPONSE_FINISH_REASONS: ['COMPLETE'],
+                _GEN_AI_USAGE_INPUT_TOKENS: math.ceil(len("Describe the purpose of a 'hello world' program in one line.") / 6),
+                _GEN_AI_USAGE_OUTPUT_TOKENS: math.ceil(len("test-generation-text") / 6)
+            },
             span_name="BedrockRuntime.InvokeModel"
         )
 
@@ -516,6 +540,11 @@ def test_bedrock_runtime_invoke_model_ai21_jamba(self):
                 _GEN_AI_REQUEST_TEMPERATURE: 0.6,
                 _GEN_AI_REQUEST_TOP_P: 0.8
             },
+            response_specific_attributes={
+                _GEN_AI_RESPONSE_FINISH_REASONS: ['stop'],
+                _GEN_AI_USAGE_INPUT_TOKENS: 21,
+                _GEN_AI_USAGE_OUTPUT_TOKENS: 24
+            },
             span_name="BedrockRuntime.InvokeModel"
         )
 
@@ -538,6 +567,11 @@ def test_bedrock_runtime_invoke_model_mistral_mistral(self):
                 _GEN_AI_REQUEST_TEMPERATURE: 0.75,
                 _GEN_AI_REQUEST_TOP_P: 0.99
             },
+            response_specific_attributes={
+                _GEN_AI_RESPONSE_FINISH_REASONS: ['stop'],
+                _GEN_AI_USAGE_INPUT_TOKENS: math.ceil(len("Describe the purpose of a 'hello world' program in one line.") / 6),
+                _GEN_AI_USAGE_OUTPUT_TOKENS: math.ceil(len("test-output-text") / 6)
+            },
             span_name="BedrockRuntime.InvokeModel"
         )
 
@@ -654,9 +688,6 @@ def test_bedrock_agent_get_data_source(self):
             },
             span_name="BedrockAgent.GetDataSource",
         )
-
-    # def test_bedrock_agent_runtime_invoke_agent(self):
-    #     return None
 
     @override
     def _assert_aws_span_attributes(self, resource_scope_spans: List[ResourceScopeSpan], path: str, **kwargs) -> None:
@@ -726,6 +757,7 @@ def _assert_semantic_conventions_span_attributes(
             kwargs.get("remote_operation"),
             status_code,
             kwargs.get("request_specific_attributes", {}),
+            kwargs.get("response_specific_attributes", {}),
         )
 
     # pylint: disable=unidiomatic-typecheck
@@ -736,6 +768,7 @@ def _assert_semantic_conventions_attributes(
         operation: str,
         status_code: int,
         request_specific_attributes: dict,
+        response_specific_attributes: dict,
     ) -> None:
         attributes_dict: Dict[str, AnyValue] = self._get_attributes_dict(attributes_list)
         self._assert_str_attribute(attributes_dict, SpanAttributes.RPC_METHOD, operation)
@@ -744,7 +777,11 @@ def _assert_semantic_conventions_attributes(
         self._assert_int_attribute(attributes_dict, SpanAttributes.HTTP_STATUS_CODE, status_code)
         # TODO: aws sdk instrumentation is not respecting PEER_SERVICE
         # self._assert_str_attribute(attributes_dict, SpanAttributes.PEER_SERVICE, "backend:8080")
-        for key, value in request_specific_attributes.items():
+        self._assert_specific_attributes(attributes_dict, request_specific_attributes)
+        self._assert_specific_attributes(attributes_dict, response_specific_attributes)
+
+    def _assert_specific_attributes(self, attributes_dict: Dict[str, AnyValue], specific_attributes: Dict[str, AnyValue]) -> None:
+        for key, value in specific_attributes.items():
             if isinstance(value, str):
                 self._assert_str_attribute(attributes_dict, key, value)
             elif isinstance(value, int):
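The new `response_specific_attributes` expectations line up field-by-field with the canned bodies the mock server injects; each model family reports finish reasons and token usage under different names (`completionReason` vs. `stop_reason` vs. `finish_reason`, `inputTextTokenCount` vs. `usage.input_tokens` vs. `prompt_token_count`, and so on). A self-contained sketch of the amazon.titan case, showing how the byte-encoded body decodes into the values the test asserts (variable names are illustrative):

```js
// The mock injects the titan body as bytes, matching a real InvokeModel response.
const responseBody = new TextEncoder().encode(JSON.stringify({
  inputTextTokenCount: 15,
  results: [{ tokenCount: 13, outputText: 'text-test-response', completionReason: 'CONTENT_FILTERED' }],
}));

// Decoding side: the fields the asserted span attributes are derived from.
const parsed = JSON.parse(new TextDecoder().decode(responseBody));
console.log(parsed.results.map((r) => r.completionReason)); // gen_ai.response.finish_reasons -> ['CONTENT_FILTERED']
console.log(parsed.inputTextTokenCount);                    // gen_ai.usage.input_tokens -> 15
console.log(parsed.results[0].tokenCount);                  // gen_ai.usage.output_tokens -> 13
```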
