Skip to content
164 changes: 148 additions & 16 deletions contract-tests/images/applications/botocore/botocore_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
import tempfile
from collections import namedtuple
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from io import BytesIO
from threading import Thread

import boto3
import requests
from botocore.client import BaseClient
from botocore.config import Config
from botocore.exceptions import ClientError
from botocore.response import StreamingBody
from typing_extensions import Tuple, override

_PORT: int = 8080
Expand Down Expand Up @@ -285,28 +287,22 @@ def _handle_bedrock_request(self) -> None:
},
)
elif self.in_path("invokemodel/invoke-model"):
model_id, request_body, response_body = get_model_request_response(self.path)

set_main_status(200)
bedrock_runtime_client.meta.events.register(
"before-call.bedrock-runtime.InvokeModel",
inject_200_success,
)
model_id = "amazon.titan-text-premier-v1:0"
user_message = "Describe the purpose of a 'hello world' program in one line."
prompt = f"<s>[INST] {user_message} [/INST]"
body = json.dumps(
{
"inputText": prompt,
"textGenerationConfig": {
"maxTokenCount": 3072,
"stopSequences": [],
"temperature": 0.7,
"topP": 0.9,
},
}
lambda **kwargs: inject_200_success(
modelId=model_id,
body=response_body,
**kwargs,
),
)
accept = "application/json"
content_type = "application/json"
bedrock_runtime_client.invoke_model(body=body, modelId=model_id, accept=accept, contentType=content_type)
bedrock_runtime_client.invoke_model(
body=request_body, modelId=model_id, accept=accept, contentType=content_type
)
else:
set_main_status(404)

Expand Down Expand Up @@ -378,6 +374,137 @@ def _end_request(self, status_code: int):
self.end_headers()


def get_model_request_response(path):
prompt = "Describe the purpose of a 'hello world' program in one line."
model_id = ""
request_body = {}
response_body = {}

if "amazon.titan" in path:
model_id = "amazon.titan-text-premier-v1:0"

request_body = {
"inputText": prompt,
"textGenerationConfig": {
"maxTokenCount": 3072,
"stopSequences": [],
"temperature": 0.7,
"topP": 0.9,
},
}

response_body = {
"inputTextTokenCount": 15,
"results": [
{
"tokenCount": 13,
"outputText": "text-test-response",
"completionReason": "CONTENT_FILTERED",
},
],
}

if "anthropic.claude" in path:
model_id = "anthropic.claude-v2:1"

request_body = {
"anthropic_version": "bedrock-2023-05-31",
"max_tokens": 1000,
"temperature": 0.99,
"top_p": 1,
"messages": [
{
"role": "user",
"content": [{"type": "text", "text": prompt}],
},
],
}

response_body = {
"stop_reason": "end_turn",
"usage": {
"input_tokens": 15,
"output_tokens": 13,
},
}

if "meta.llama" in path:
model_id = "meta.llama2-13b-chat-v1"

request_body = {"prompt": prompt, "max_gen_len": 512, "temperature": 0.5, "top_p": 0.9}

response_body = {"prompt_token_count": 31, "generation_token_count": 49, "stop_reason": "stop"}

if "cohere.command" in path:
model_id = "cohere.command-r-v1:0"

request_body = {
"chat_history": [],
"message": prompt,
"max_tokens": 512,
"temperature": 0.5,
"p": 0.65,
}

response_body = {
"chat_history": [
{"role": "USER", "message": prompt},
{"role": "CHATBOT", "message": "test-text-output"},
],
"finish_reason": "COMPLETE",
"text": "test-generation-text",
}

if "ai21.jamba" in path:
model_id = "ai21.jamba-1-5-large-v1:0"

request_body = {
"messages": [
{
"role": "user",
"content": prompt,
},
],
"top_p": 0.8,
"temperature": 0.6,
"max_tokens": 512,
}

response_body = {
"stop_reason": "end_turn",
"usage": {
"prompt_tokens": 21,
"completion_tokens": 24,
},
"choices": [
{"finish_reason": "stop"},
],
}

if "mistral" in path:
model_id = "mistral.mistral-7b-instruct-v0:2"

request_body = {
"prompt": prompt,
"max_tokens": 4096,
"temperature": 0.75,
"top_p": 0.99,
}

response_body = {
"outputs": [
{
"text": "test-output-text",
"stop_reason": "stop",
},
]
}

json_bytes = json.dumps(response_body).encode("utf-8")

return model_id, json.dumps(request_body), StreamingBody(BytesIO(json_bytes), len(json_bytes))


def set_main_status(status: int) -> None:
RequestHandler.main_status = status

Expand Down Expand Up @@ -490,11 +617,16 @@ def inject_200_success(**kwargs):
guardrail_arn = kwargs.get("guardrailArn")
if guardrail_arn is not None:
response_body["guardrailArn"] = guardrail_arn
model_id = kwargs.get("modelId")
if model_id is not None:
response_body["modelId"] = model_id

HTTPResponse = namedtuple("HTTPResponse", ["status_code", "headers", "body"])
headers = kwargs.get("headers", {})
body = kwargs.get("body", "")
response_body["body"] = body
http_response = HTTPResponse(200, headers=headers, body=body)

return http_response, response_body


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,12 @@ def _assert_int_attribute(self, attributes_dict: Dict[str, AnyValue], key: str,
self.assertIsNotNone(actual_value)
self.assertEqual(expected_value, actual_value.int_value)

def _assert_float_attribute(self, attributes_dict: Dict[str, AnyValue], key: str, expected_value: float) -> None:
self.assertIn(key, attributes_dict)
actual_value: AnyValue = attributes_dict[key]
self.assertIsNotNone(actual_value)
self.assertEqual(expected_value, actual_value.double_value)

def _assert_match_attribute(self, attributes_dict: Dict[str, AnyValue], key: str, pattern: str) -> None:
self.assertIn(key, attributes_dict)
actual_value: AnyValue = attributes_dict[key]
Expand Down Expand Up @@ -237,5 +243,5 @@ def _is_valid_regex(self, pattern: str) -> bool:
try:
re.compile(pattern)
return True
except re.error:
except (re.error, StopIteration, RuntimeError, KeyError):
return False
Loading
Loading