Skip to content

Commit 27c0f80

Browse files
liustveyiyuan-he
andauthored
feat: Add Contract Tests for new Gen AI attributes for foundational models (#119)
*Description of changes:* contract tests for new gen_ai inference parameter added in e8c96ae#diff-20c2ca1cb28cda6e03ec0cb986933b2abd103bee39995ad232cc2e8c2d23e4aaR368 <img width="1344" alt="image" src="https://github.com/user-attachments/assets/1d63b019-fe49-4222-9663-34e4f10d3d5b"> By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. --------- Co-authored-by: Michael He <[email protected]>
1 parent 25fa5e9 commit 27c0f80

File tree

3 files changed

+377
-24
lines changed

3 files changed

+377
-24
lines changed

contract-tests/images/applications/aws-sdk/server.js

Lines changed: 176 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ const { S3Client, CreateBucketCommand, PutObjectCommand, GetObjectCommand } = re
1010
const { DynamoDBClient, CreateTableCommand, PutItemCommand } = require('@aws-sdk/client-dynamodb');
1111
const { SQSClient, CreateQueueCommand, SendMessageCommand, ReceiveMessageCommand } = require('@aws-sdk/client-sqs');
1212
const { KinesisClient, CreateStreamCommand, PutRecordCommand } = require('@aws-sdk/client-kinesis');
13-
const fetch = require('node-fetch');
1413
const { BedrockClient, GetGuardrailCommand } = require('@aws-sdk/client-bedrock');
1514
const { BedrockAgentClient, GetKnowledgeBaseCommand, GetDataSourceCommand, GetAgentCommand } = require('@aws-sdk/client-bedrock-agent');
1615
const { BedrockRuntimeClient, InvokeModelCommand } = require('@aws-sdk/client-bedrock-runtime');
@@ -553,30 +552,190 @@ async function handleBedrockRequest(req, res, path) {
553552
});
554553
res.statusCode = 200;
555554
} else if (path.includes('invokemodel/invoke-model')) {
556-
await withInjected200Success(bedrockRuntimeClient, ['InvokeModelCommand'], {}, async () => {
557-
const modelId = 'amazon.titan-text-premier-v1:0';
558-
const userMessage = "Describe the purpose of a 'hello world' program in one line.";
559-
const prompt = `<s>[INST] ${userMessage} [/INST]`;
560-
561-
const body = JSON.stringify({
562-
inputText: prompt,
563-
textGenerationConfig: {
564-
maxTokenCount: 3072,
565-
stopSequences: [],
566-
temperature: 0.7,
567-
topP: 0.9,
568-
},
569-
});
555+
const get_model_request_response = function () {
556+
const prompt = "Describe the purpose of a 'hello world' program in one line.";
557+
let modelId = ''
558+
let request_body = {}
559+
let response_body = {}
560+
561+
if (path.includes('amazon.titan')) {
562+
563+
modelId = 'amazon.titan-text-premier-v1:0';
564+
565+
request_body = {
566+
inputText: prompt,
567+
textGenerationConfig: {
568+
maxTokenCount: 3072,
569+
stopSequences: [],
570+
temperature: 0.7,
571+
topP: 0.9,
572+
},
573+
};
574+
575+
response_body = {
576+
inputTextTokenCount: 15,
577+
results: [
578+
{
579+
tokenCount: 13,
580+
outputText: 'text-test-response',
581+
completionReason: 'CONTENT_FILTERED',
582+
},
583+
],
584+
}
585+
586+
}
587+
588+
if (path.includes('anthropic.claude')) {
589+
590+
modelId = 'anthropic.claude-v2:1';
591+
592+
request_body = {
593+
anthropic_version: 'bedrock-2023-05-31',
594+
max_tokens: 1000,
595+
temperature: 0.99,
596+
top_p: 1,
597+
messages: [
598+
{
599+
role: 'user',
600+
content: [{ type: 'text', text: prompt }],
601+
},
602+
],
603+
};
604+
605+
response_body = {
606+
stop_reason: 'end_turn',
607+
usage: {
608+
input_tokens: 15,
609+
output_tokens: 13,
610+
},
611+
}
612+
}
613+
614+
if (path.includes('meta.llama')) {
615+
modelId = 'meta.llama2-13b-chat-v1';
616+
617+
request_body = {
618+
prompt,
619+
max_gen_len: 512,
620+
temperature: 0.5,
621+
top_p: 0.9
622+
};
623+
624+
response_body = {
625+
prompt_token_count: 31,
626+
generation_token_count: 49,
627+
stop_reason: 'stop'
628+
}
629+
}
630+
631+
if (path.includes('cohere.command')) {
632+
modelId = 'cohere.command-light-text-v14';
633+
634+
request_body = {
635+
prompt,
636+
max_tokens: 512,
637+
temperature: 0.5,
638+
p: 0.65,
639+
};
640+
641+
response_body = {
642+
generations: [
643+
{
644+
finish_reason: 'COMPLETE',
645+
text: 'test-generation-text',
646+
},
647+
],
648+
prompt: prompt,
649+
};
650+
}
651+
652+
if (path.includes('cohere.command-r')) {
653+
modelId = 'cohere.command-r-v1:0';
654+
655+
request_body = {
656+
message: prompt,
657+
max_tokens: 512,
658+
temperature: 0.5,
659+
p: 0.65,
660+
};
661+
662+
response_body = {
663+
finish_reason: 'COMPLETE',
664+
text: 'test-generation-text',
665+
prompt: prompt,
666+
request: {
667+
commandInput: {
668+
modelId: modelId,
669+
},
670+
},
671+
}
672+
}
673+
674+
if (path.includes('ai21.jamba')) {
675+
modelId = 'ai21.jamba-1-5-large-v1:0';
676+
677+
request_body = {
678+
messages: [
679+
{
680+
role: 'user',
681+
content: prompt,
682+
},
683+
],
684+
top_p: 0.8,
685+
temperature: 0.6,
686+
max_tokens: 512,
687+
};
688+
689+
response_body = {
690+
stop_reason: 'end_turn',
691+
usage: {
692+
prompt_tokens: 21,
693+
completion_tokens: 24,
694+
},
695+
choices: [
696+
{
697+
finish_reason: 'stop',
698+
},
699+
],
700+
}
701+
}
702+
703+
if (path.includes('mistral')) {
704+
modelId = 'mistral.mistral-7b-instruct-v0:2';
705+
706+
request_body = {
707+
prompt,
708+
max_tokens: 4096,
709+
temperature: 0.75,
710+
top_p: 0.99,
711+
};
712+
713+
response_body = {
714+
outputs: [
715+
{
716+
text: 'test-output-text',
717+
stop_reason: 'stop',
718+
},
719+
]
720+
}
721+
}
722+
723+
return [modelId, JSON.stringify(request_body), new TextEncoder().encode(JSON.stringify(response_body))]
724+
}
725+
726+
const [modelId, request_body, response_body] = get_model_request_response();
570727

728+
await withInjected200Success(bedrockRuntimeClient, ['InvokeModelCommand'], { body: response_body }, async () => {
571729
await bedrockRuntimeClient.send(
572730
new InvokeModelCommand({
573-
body: body,
731+
body: request_body,
574732
modelId: modelId,
575733
accept: 'application/json',
576734
contentType: 'application/json',
577735
})
578736
);
579737
});
738+
580739
res.statusCode = 200;
581740
} else {
582741
res.statusCode = 404;
@@ -624,3 +783,4 @@ prepareAwsServer().then(() => {
624783
console.log('Ready');
625784
});
626785
});
786+

0 commit comments

Comments
 (0)