Skip to content

Commit 4cc8bde

Browse files
committed
update invoke_model test case to support titan model attributes
1 parent 575014c commit 4cc8bde

File tree

4 files changed

+60
-7
lines changed

4 files changed

+60
-7
lines changed

test/contract-tests/images/applications/TestSimpleApp.AWSSDK.Core/BedrockTests.cs

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using System.IO;
12
using System.Net;
23
using System.Text;
34
using System.Text.Json;
@@ -42,14 +43,37 @@ public void InvokeModel()
4243
{
4344
bedrockRuntime.InvokeModelAsync(new InvokeModelRequest
4445
{
45-
ModelId = "test-model",
46+
ModelId = "amazon.titan-text-express-v1",
47+
Body = new MemoryStream(Encoding.UTF8.GetBytes(JsonSerializer.Serialize(new
48+
{
49+
inputText = "sample input text",
50+
textGenerationConfig = new
51+
{
52+
temperature = 0.123,
53+
topP = 0.456,
54+
maxTokenCount = 123,
55+
},
56+
}))),
57+
ContentType = "application/json",
4658
});
4759
return;
4860
}
4961

50-
public void InvokeModelResponse()
62+
public object InvokeModelResponse()
5163
{
52-
return;
64+
return new
65+
{
66+
inputTextTokenCount = 456,
67+
results = new object[]
68+
{
69+
new
70+
{
71+
outputText = "\nsample output text\n",
72+
tokenCount = 789,
73+
completionReason = "finish_reason"
74+
},
75+
},
76+
};
5377
}
5478

5579
public Task<GetAgentResponse> GetAgent()

test/contract-tests/images/applications/TestSimpleApp.AWSSDK.Core/Program.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@
152152
// Reroute the Bedrock API calls to our mock responses in BedrockTests. While other services use localstack to handle the requests,
153153
// we write our own responses with the necessary data to mimic the expected behavior of the Bedrock services.
154154
app.MapGet("guardrails/test-guardrail", (BedrockTests bedrock) => bedrock.GetGuardrailResponse());
155-
app.MapPost("model/test-model/invoke", (BedrockTests bedrock) => bedrock.InvokeModelResponse());
155+
app.MapPost("model/amazon.titan-text-express-v1/invoke", (BedrockTests bedrock) => bedrock.InvokeModelResponse());
156156
app.MapGet("agents/test-agent", (BedrockTests bedrock) => bedrock.GetAgentResponse());
157157
app.MapGet("knowledgebases/test-knowledge-base", (BedrockTests bedrock) => bedrock.GetKnowledgeBaseResponse());
158158
app.MapGet("knowledgebases/test-knowledge-base/datasources/test-data-source", (BedrockTests bedrock) => bedrock.GetDataSourceResponse());

test/contract-tests/tests/test/amazon/awssdk/awssdk_test.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@
3232
_AWS_BEDROCK_KNOWLEDGE_BASE_ID: str = "aws.bedrock.knowledge_base.id"
3333
_AWS_BEDROCK_DATA_SOURCE_ID: str = "aws.bedrock.data_source.id"
3434
_GEN_AI_REQUEST_MODEL: str = "gen_ai.request.model"
35+
_GEN_AI_REQUEST_TOP_P: str = "gen_ai.request.top_p"
36+
_GEN_AI_REQUEST_TEMPERATURE: str = "gen_ai.request.temperature"
37+
_GEN_AI_REQUEST_MAX_TOKENS: str = "gen_ai.request.max_tokens"
38+
_GEN_AI_USAGE_INPUT_TOKENS: str = "gen_ai.usage.input_tokens"
39+
_GEN_AI_USAGE_OUTPUT_TOKENS: str = "gen_ai.usage.output_tokens"
40+
_GEN_AI_RESPONSE_FINISH_REASONS: str = "gen_ai.response.finish_reasons"
3541

3642

3743
# pylint: disable=too-many-public-methods
@@ -328,9 +334,15 @@ def test_bedrock_runtime_invoke_model(self):
328334
remote_service="AWS::BedrockRuntime",
329335
remote_operation="InvokeModel",
330336
remote_resource_type="AWS::Bedrock::Model",
331-
remote_resource_identifier="test-model",
337+
remote_resource_identifier="amazon.titan-text-express-v1",
332338
request_specific_attributes={
333-
_GEN_AI_REQUEST_MODEL: "test-model",
339+
_GEN_AI_REQUEST_MODEL: "amazon.titan-text-express-v1",
340+
_GEN_AI_REQUEST_TEMPERATURE: 0.123,
341+
_GEN_AI_REQUEST_TOP_P: 0.456,
342+
_GEN_AI_REQUEST_MAX_TOKENS: 123,
343+
_GEN_AI_USAGE_INPUT_TOKENS: 456,
344+
_GEN_AI_USAGE_OUTPUT_TOKENS: 789,
345+
_GEN_AI_RESPONSE_FINISH_REASONS: ["finish_reason"],
334346
},
335347
span_name="Bedrock Runtime.InvokeModel",
336348
)
@@ -505,6 +517,11 @@ def _assert_semantic_conventions_attributes(
505517
self._assert_str_attribute(attributes_dict, key, value)
506518
elif isinstance(value, int):
507519
self._assert_int_attribute(attributes_dict, key, value)
520+
elif isinstance(value, float):
521+
self._assert_float_attribute(attributes_dict, key, value)
522+
# value is a list: gen_ai.response.finish_reasons or aws.table_name
523+
elif key == _GEN_AI_RESPONSE_FINISH_REASONS:
524+
self._assert_invoke_model_finish_reasons(attributes_dict, key, value)
508525
else:
509526
self._assert_array_value_ddb_table_name(attributes_dict, key, value)
510527

@@ -580,6 +597,12 @@ def _assert_service_dp_attributes(self, service_dp: ExponentialHistogramDataPoin
580597
def _assert_array_value_ddb_table_name(self, attributes_dict: Dict[str, AnyValue], key: str, expect_values: list):
581598
self.assertIn(key, attributes_dict)
582599
self.assertEqual(attributes_dict[key].string_value, expect_values[0])
600+
601+
def _assert_invoke_model_finish_reasons(self, attributes_dict: Dict[str, AnyValue], key: str, expect_values: list):
602+
self.assertIn(key, attributes_dict)
603+
self.assertEqual(len(attributes_dict[key].array_value.values), len(expect_values))
604+
for i, value in enumerate(expect_values):
605+
self.assertEqual(attributes_dict[key].array_value.values[i].string_value, value)
583606

584607
def _filter_bedrock_metrics(self, target_metrics: List[Metric]):
585608
bedrock_calls = {
@@ -588,7 +611,7 @@ def _filter_bedrock_metrics(self, target_metrics: List[Metric]):
588611
"GET knowledgebases/test-knowledge-base",
589612
"GET knowledgebases/test-knowledge-base/datasources/test-data-source",
590613
"POST agents/test-agent/agentAliases/test-agent-alias/sessions/test-session/text",
591-
"POST model/test-model/invoke",
614+
"POST model/amazon.titan-text-express-v1/invoke",
592615
"POST knowledgebases/test-knowledge-base/retrieve"
593616
}
594617
for metric in target_metrics:

test/contract-tests/tests/test/amazon/base/contract_test_base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,12 @@ def _assert_int_attribute(self, attributes_dict: Dict[str, AnyValue], key: str,
174174
self.assertIsNotNone(actual_value)
175175
self.assertEqual(expected_value, actual_value.int_value)
176176

177+
def _assert_float_attribute(self, attributes_dict: Dict[str, AnyValue], key: str, expected_value: float) -> None:
178+
self.assertIn(key, attributes_dict)
179+
actual_value: AnyValue = attributes_dict[key]
180+
self.assertIsNotNone(actual_value)
181+
self.assertEqual(expected_value, actual_value.double_value)
182+
177183
def check_sum(self, metric_name: str, actual_sum: float, expected_sum: float) -> None:
178184
if metric_name is LATENCY_METRIC:
179185
self.assertTrue(0 < actual_sum < expected_sum)

0 commit comments

Comments
 (0)