Skip to content

Commit a043f25

Browse files
authored
add contract tests for bedrock (#106)
*Issue #, if available:* *Description of changes:* Add contract tests for Bedrock instrumentation which include the following APIs: * BedrockRuntime.InvokeModel * Bedrock.GetGuardrail * BedrockAgentRuntime.InvokeAgent * BedrockAgentRuntime.Retrieve * BedrockAgent.GetAgent * BedrockAgent.GetKnowledgeBase * BedrockAgent.GetDataSource By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
1 parent 3773450 commit a043f25

File tree

3 files changed

+276
-0
lines changed

3 files changed

+276
-0
lines changed

contract-tests/images/applications/aws-sdk/package.json

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
"license": "ISC",
1111
"description": "",
1212
"dependencies": {
13+
"@aws-sdk/client-bedrock": "^3.675.0",
14+
"@aws-sdk/client-bedrock-agent": "^3.675.0",
15+
"@aws-sdk/client-bedrock-agent-runtime": "^3.676.0",
16+
"@aws-sdk/client-bedrock-runtime": "^3.675.0",
1317
"@aws-sdk/client-dynamodb": "^3.658.1",
1418
"@aws-sdk/client-kinesis": "^3.658.1",
1519
"@aws-sdk/client-s3": "^3.658.1",

contract-tests/images/applications/aws-sdk/server.js

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@ const { DynamoDBClient, CreateTableCommand, PutItemCommand } = require('@aws-sdk
1111
const { SQSClient, CreateQueueCommand, SendMessageCommand, ReceiveMessageCommand } = require('@aws-sdk/client-sqs');
1212
const { KinesisClient, CreateStreamCommand, PutRecordCommand } = require('@aws-sdk/client-kinesis');
1313
const fetch = require('node-fetch');
14+
const { BedrockClient, GetGuardrailCommand } = require('@aws-sdk/client-bedrock');
15+
const { BedrockAgentClient, GetKnowledgeBaseCommand, GetDataSourceCommand, GetAgentCommand } = require('@aws-sdk/client-bedrock-agent');
16+
const { BedrockRuntimeClient, InvokeModelCommand } = require('@aws-sdk/client-bedrock-runtime');
17+
const { BedrockAgentRuntimeClient, InvokeAgentCommand, RetrieveCommand } = require('@aws-sdk/client-bedrock-agent-runtime');
18+
1419

1520
const _PORT = 8080;
1621
const _ERROR = 'error';
@@ -141,6 +146,8 @@ async function handleGetRequest(req, res, path) {
141146
await handleSqsRequest(req, res, path);
142147
} else if (path.includes('kinesis')) {
143148
await handleKinesisRequest(req, res, path);
149+
} else if (path.includes('bedrock')) {
150+
await handleBedrockRequest(req, res, path);
144151
} else {
145152
res.writeHead(404);
146153
res.end();
@@ -485,6 +492,132 @@ async function handleKinesisRequest(req, res, path) {
485492
}
486493
}
487494

495+
async function handleBedrockRequest(req, res, path) {
496+
const bedrockClient = new BedrockClient({ endpoint: _AWS_SDK_ENDPOINT, region: _AWS_REGION });
497+
const bedrockAgentClient = new BedrockAgentClient({ endpoint: _AWS_SDK_ENDPOINT, region: _AWS_REGION });
498+
const bedrockRuntimeClient = new BedrockRuntimeClient({ endpoint: _AWS_SDK_ENDPOINT, region: _AWS_REGION });
499+
const bedrockAgentRuntimeClient = new BedrockAgentRuntimeClient({ endpoint: _AWS_SDK_ENDPOINT, region: _AWS_REGION });
500+
501+
try {
502+
if (path.includes('getknowledgebase/get_knowledge_base')) {
503+
await withInjected200Success(bedrockAgentClient, ['GetKnowledgeBaseCommand'], {}, async () => {
504+
await bedrockAgentClient.send(new GetKnowledgeBaseCommand({ knowledgeBaseId: 'invalid-knowledge-base-id' }));
505+
});
506+
res.statusCode = 200;
507+
} else if (path.includes('getdatasource/get_data_source')) {
508+
await withInjected200Success(bedrockAgentClient, ['GetDataSourceCommand'], {}, async () => {
509+
await bedrockAgentClient.send(new GetDataSourceCommand({ knowledgeBaseId: 'TESTKBSEID', dataSourceId: 'DATASURCID' }));
510+
});
511+
res.statusCode = 200;
512+
} else if (path.includes('getagent/get-agent')) {
513+
await withInjected200Success(bedrockAgentClient, ['GetAgentCommand'], {}, async () => {
514+
await bedrockAgentClient.send(new GetAgentCommand({ agentId: 'TESTAGENTID' }));
515+
});
516+
res.statusCode = 200;
517+
} else if (path.includes('getguardrail/get-guardrail')) {
518+
await withInjected200Success(
519+
bedrockClient,
520+
['GetGuardrailCommand'],
521+
{ guardrailId: 'bt4o77i015cu' },
522+
async () => {
523+
await bedrockClient.send(
524+
new GetGuardrailCommand({
525+
guardrailIdentifier: 'arn:aws:bedrock:us-east-1:000000000000:guardrail/bt4o77i015cu',
526+
})
527+
);
528+
}
529+
);
530+
res.statusCode = 200;
531+
} else if (path.includes('invokeagent/invoke_agent')) {
532+
await withInjected200Success(bedrockAgentRuntimeClient, ['InvokeAgentCommand'], {}, async () => {
533+
await bedrockAgentRuntimeClient.send(
534+
new InvokeAgentCommand({
535+
agentId: 'Q08WFRPHVL',
536+
agentAliasId: 'testAlias',
537+
sessionId: 'testSessionId',
538+
inputText: 'Invoke agent sample input text',
539+
})
540+
);
541+
});
542+
res.statusCode = 200;
543+
} else if (path.includes('retrieve/retrieve')) {
544+
await withInjected200Success(bedrockAgentRuntimeClient, ['RetrieveCommand'], {}, async () => {
545+
await bedrockAgentRuntimeClient.send(
546+
new RetrieveCommand({
547+
knowledgeBaseId: 'test-knowledge-base-id',
548+
retrievalQuery: {
549+
text: 'an example of retrieve query',
550+
},
551+
})
552+
);
553+
});
554+
res.statusCode = 200;
555+
} else if (path.includes('invokemodel/invoke-model')) {
556+
await withInjected200Success(bedrockRuntimeClient, ['InvokeModelCommand'], {}, async () => {
557+
const modelId = 'amazon.titan-text-premier-v1:0';
558+
const userMessage = "Describe the purpose of a 'hello world' program in one line.";
559+
const prompt = `<s>[INST] ${userMessage} [/INST]`;
560+
561+
const body = JSON.stringify({
562+
inputText: prompt,
563+
textGenerationConfig: {
564+
maxTokenCount: 3072,
565+
stopSequences: [],
566+
temperature: 0.7,
567+
topP: 0.9,
568+
},
569+
});
570+
571+
await bedrockRuntimeClient.send(
572+
new InvokeModelCommand({
573+
body: body,
574+
modelId: modelId,
575+
accept: 'application/json',
576+
contentType: 'application/json',
577+
})
578+
);
579+
});
580+
res.statusCode = 200;
581+
} else {
582+
res.statusCode = 404;
583+
}
584+
} catch (error) {
585+
console.error('An error occurred:', error);
586+
res.statusCode = 500;
587+
}
588+
589+
res.end();
590+
}
591+
592+
function inject200Success(client, commandNames, additionalResponse = {}, middlewareName = 'inject200SuccessMiddleware') {
593+
const middleware = (next, context) => async (args) => {
594+
const { commandName } = context;
595+
if (commandNames.includes(commandName)) {
596+
const response = {
597+
$metadata: {
598+
httpStatusCode: 200,
599+
requestId: 'mock-request-id',
600+
},
601+
Message: 'Request succeeded',
602+
...additionalResponse,
603+
};
604+
return { output: response };
605+
}
606+
return next(args);
607+
};
608+
// this middleware intercept the request and inject the response
609+
client.middlewareStack.add(middleware, { step: 'build', name: middlewareName, priority: 'high' });
610+
}
611+
612+
async function withInjected200Success(client, commandNames, additionalResponse, apiCall) {
613+
const middlewareName = 'inject200SuccessMiddleware';
614+
inject200Success(client, commandNames, additionalResponse, middlewareName);
615+
await apiCall();
616+
client.middlewareStack.remove(middlewareName);
617+
}
618+
619+
620+
488621
prepareAwsServer().then(() => {
489622
server.listen(_PORT, '0.0.0.0', () => {
490623
console.log('Server is listening on port', _PORT);

contract-tests/tests/test/amazon/aws-sdk/aws_sdk_test.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@
2929
_AWS_SQS_QUEUE_URL: str = "aws.sqs.queue.url"
3030
_AWS_SQS_QUEUE_NAME: str = "aws.sqs.queue.name"
3131
_AWS_KINESIS_STREAM_NAME: str = "aws.kinesis.stream.name"
32+
_AWS_BEDROCK_AGENT_ID: str = "aws.bedrock.agent.id"
33+
_AWS_BEDROCK_GUARDRAIL_ID: str = "aws.bedrock.guardrail.id"
34+
_AWS_BEDROCK_KNOWLEDGE_BASE_ID: str = "aws.bedrock.knowledge_base.id"
35+
_AWS_BEDROCK_DATA_SOURCE_ID: str = "aws.bedrock.data_source.id"
36+
_GEN_AI_REQUEST_MODEL: str = "gen_ai.request.model"
37+
3238

3339
# pylint: disable=too-many-public-methods
3440
class AWSSDKTest(ContractTestBase):
@@ -400,6 +406,139 @@ def test_kinesis_fault(self):
400406
span_name="Kinesis.PutRecord",
401407
)
402408

409+
def test_bedrock_runtime_invoke_model(self):
410+
self.do_test_requests(
411+
"bedrock/invokemodel/invoke-model",
412+
"GET",
413+
200,
414+
0,
415+
0,
416+
local_operation="GET /bedrock",
417+
rpc_service="BedrockRuntime",
418+
remote_service="AWS::BedrockRuntime",
419+
remote_operation="InvokeModel",
420+
remote_resource_type="AWS::Bedrock::Model",
421+
remote_resource_identifier="amazon.titan-text-premier-v1:0",
422+
request_specific_attributes={
423+
_GEN_AI_REQUEST_MODEL: "amazon.titan-text-premier-v1:0",
424+
},
425+
span_name="BedrockRuntime.InvokeModel",
426+
)
427+
428+
def test_bedrock_get_guardrail(self):
429+
self.do_test_requests(
430+
"bedrock/getguardrail/get-guardrail",
431+
"GET",
432+
200,
433+
0,
434+
0,
435+
local_operation="GET /bedrock",
436+
rpc_service="Bedrock",
437+
remote_service="AWS::Bedrock",
438+
remote_operation="GetGuardrail",
439+
remote_resource_type="AWS::Bedrock::Guardrail",
440+
remote_resource_identifier="bt4o77i015cu",
441+
request_specific_attributes={
442+
_AWS_BEDROCK_GUARDRAIL_ID: "bt4o77i015cu",
443+
},
444+
span_name="Bedrock.GetGuardrail",
445+
)
446+
447+
def test_bedrock_agent_runtime_invoke_agent(self):
448+
self.do_test_requests(
449+
"bedrock/invokeagent/invoke_agent",
450+
"GET",
451+
200,
452+
0,
453+
0,
454+
local_operation="GET /bedrock",
455+
rpc_service="BedrockAgentRuntime",
456+
remote_service="AWS::Bedrock",
457+
remote_operation="InvokeAgent",
458+
remote_resource_type="AWS::Bedrock::Agent",
459+
remote_resource_identifier="Q08WFRPHVL",
460+
request_specific_attributes={
461+
_AWS_BEDROCK_AGENT_ID: "Q08WFRPHVL",
462+
},
463+
span_name="BedrockAgentRuntime.InvokeAgent",
464+
)
465+
466+
def test_bedrock_agent_runtime_retrieve(self):
467+
self.do_test_requests(
468+
"bedrock/retrieve/retrieve",
469+
"GET",
470+
200,
471+
0,
472+
0,
473+
local_operation="GET /bedrock",
474+
rpc_service="BedrockAgentRuntime",
475+
remote_service="AWS::Bedrock",
476+
remote_operation="Retrieve",
477+
remote_resource_type="AWS::Bedrock::KnowledgeBase",
478+
remote_resource_identifier="test-knowledge-base-id",
479+
request_specific_attributes={
480+
_AWS_BEDROCK_KNOWLEDGE_BASE_ID: "test-knowledge-base-id",
481+
},
482+
span_name="BedrockAgentRuntime.Retrieve",
483+
)
484+
485+
def test_bedrock_agent_get_agent(self):
486+
self.do_test_requests(
487+
"bedrock/getagent/get-agent",
488+
"GET",
489+
200,
490+
0,
491+
0,
492+
local_operation="GET /bedrock",
493+
rpc_service="BedrockAgent",
494+
remote_service="AWS::Bedrock",
495+
remote_operation="GetAgent",
496+
remote_resource_type="AWS::Bedrock::Agent",
497+
remote_resource_identifier="TESTAGENTID",
498+
request_specific_attributes={
499+
_AWS_BEDROCK_AGENT_ID: "TESTAGENTID",
500+
},
501+
span_name="BedrockAgent.GetAgent",
502+
)
503+
504+
def test_bedrock_agent_get_knowledge_base(self):
505+
self.do_test_requests(
506+
"bedrock/getknowledgebase/get_knowledge_base",
507+
"GET",
508+
200,
509+
0,
510+
0,
511+
local_operation="GET /bedrock",
512+
rpc_service="BedrockAgent",
513+
remote_service="AWS::Bedrock",
514+
remote_operation="GetKnowledgeBase",
515+
remote_resource_type="AWS::Bedrock::KnowledgeBase",
516+
remote_resource_identifier="invalid-knowledge-base-id",
517+
request_specific_attributes={
518+
_AWS_BEDROCK_KNOWLEDGE_BASE_ID: "invalid-knowledge-base-id",
519+
},
520+
span_name="BedrockAgent.GetKnowledgeBase",
521+
)
522+
523+
def test_bedrock_agent_get_data_source(self):
524+
self.do_test_requests(
525+
"bedrock/getdatasource/get_data_source",
526+
"GET",
527+
200,
528+
0,
529+
0,
530+
local_operation="GET /bedrock",
531+
rpc_service="BedrockAgent",
532+
remote_service="AWS::Bedrock",
533+
remote_operation="GetDataSource",
534+
remote_resource_type="AWS::Bedrock::DataSource",
535+
remote_resource_identifier="DATASURCID",
536+
request_specific_attributes={
537+
_AWS_BEDROCK_DATA_SOURCE_ID: "DATASURCID",
538+
},
539+
span_name="BedrockAgent.GetDataSource",
540+
)
541+
403542
@override
404543
def _assert_aws_span_attributes(self, resource_scope_spans: List[ResourceScopeSpan], path: str, **kwargs) -> None:
405544
target_spans: List[Span] = []

0 commit comments

Comments
 (0)