Skip to content

Commit 642427e

Browse files
liustve and yiyuan-he
authored
feat: Add Contract Tests for new Gen AI attributes for foundational models (#292)
contract tests for new gen_ai inference parameters added in #290 <img width="1563" alt="image" src="https://github.com/user-attachments/assets/3ea5979d-43b2-43d6-8730-708855969d8a"> By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. --------- Co-authored-by: Michael He <[email protected]>
1 parent d305721 commit 642427e

File tree

3 files changed

+327
-31
lines changed

3 files changed

+327
-31
lines changed

contract-tests/images/applications/botocore/botocore_server.py

Lines changed: 148 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66
import tempfile
77
from collections import namedtuple
88
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
9+
from io import BytesIO
910
from threading import Thread
1011

1112
import boto3
1213
import requests
1314
from botocore.client import BaseClient
1415
from botocore.config import Config
1516
from botocore.exceptions import ClientError
17+
from botocore.response import StreamingBody
1618
from typing_extensions import Tuple, override
1719

1820
_PORT: int = 8080
@@ -285,28 +287,22 @@ def _handle_bedrock_request(self) -> None:
285287
},
286288
)
287289
elif self.in_path("invokemodel/invoke-model"):
290+
model_id, request_body, response_body = get_model_request_response(self.path)
291+
288292
set_main_status(200)
289293
bedrock_runtime_client.meta.events.register(
290294
"before-call.bedrock-runtime.InvokeModel",
291-
inject_200_success,
292-
)
293-
model_id = "amazon.titan-text-premier-v1:0"
294-
user_message = "Describe the purpose of a 'hello world' program in one line."
295-
prompt = f"<s>[INST] {user_message} [/INST]"
296-
body = json.dumps(
297-
{
298-
"inputText": prompt,
299-
"textGenerationConfig": {
300-
"maxTokenCount": 3072,
301-
"stopSequences": [],
302-
"temperature": 0.7,
303-
"topP": 0.9,
304-
},
305-
}
295+
lambda **kwargs: inject_200_success(
296+
modelId=model_id,
297+
body=response_body,
298+
**kwargs,
299+
),
306300
)
307301
accept = "application/json"
308302
content_type = "application/json"
309-
bedrock_runtime_client.invoke_model(body=body, modelId=model_id, accept=accept, contentType=content_type)
303+
bedrock_runtime_client.invoke_model(
304+
body=request_body, modelId=model_id, accept=accept, contentType=content_type
305+
)
310306
else:
311307
set_main_status(404)
312308

@@ -378,6 +374,137 @@ def _end_request(self, status_code: int):
378374
self.end_headers()
379375

380376

377+
def get_model_request_response(path):
    """Return a canned (model_id, request JSON, response StreamingBody) triple for the
    Bedrock model whose marker substring appears in *path*.

    Used by the contract-test server to fake InvokeModel traffic: the request body is
    what the test client sends, the response body is what the injected 200 reply carries.
    An unrecognized path yields ("", "{}", empty-body StreamingBody).
    """
    prompt = "Describe the purpose of a 'hello world' program in one line."

    # (path marker, model id, request payload, canned response payload).
    # Checked in order with overwrite-on-match, mirroring the original
    # sequence of independent `if` tests (a later match wins).
    fixtures = (
        (
            "amazon.titan",
            "amazon.titan-text-premier-v1:0",
            {
                "inputText": prompt,
                "textGenerationConfig": {
                    "maxTokenCount": 3072,
                    "stopSequences": [],
                    "temperature": 0.7,
                    "topP": 0.9,
                },
            },
            {
                "inputTextTokenCount": 15,
                "results": [
                    {
                        "tokenCount": 13,
                        "outputText": "text-test-response",
                        "completionReason": "CONTENT_FILTERED",
                    },
                ],
            },
        ),
        (
            "anthropic.claude",
            "anthropic.claude-v2:1",
            {
                "anthropic_version": "bedrock-2023-05-31",
                "max_tokens": 1000,
                "temperature": 0.99,
                "top_p": 1,
                "messages": [
                    {
                        "role": "user",
                        "content": [{"type": "text", "text": prompt}],
                    },
                ],
            },
            {
                "stop_reason": "end_turn",
                "usage": {
                    "input_tokens": 15,
                    "output_tokens": 13,
                },
            },
        ),
        (
            "meta.llama",
            "meta.llama2-13b-chat-v1",
            {"prompt": prompt, "max_gen_len": 512, "temperature": 0.5, "top_p": 0.9},
            {"prompt_token_count": 31, "generation_token_count": 49, "stop_reason": "stop"},
        ),
        (
            "cohere.command",
            "cohere.command-r-v1:0",
            {
                "chat_history": [],
                "message": prompt,
                "max_tokens": 512,
                "temperature": 0.5,
                "p": 0.65,
            },
            {
                "chat_history": [
                    {"role": "USER", "message": prompt},
                    {"role": "CHATBOT", "message": "test-text-output"},
                ],
                "finish_reason": "COMPLETE",
                "text": "test-generation-text",
            },
        ),
        (
            "ai21.jamba",
            "ai21.jamba-1-5-large-v1:0",
            {
                "messages": [
                    {
                        "role": "user",
                        "content": prompt,
                    },
                ],
                "top_p": 0.8,
                "temperature": 0.6,
                "max_tokens": 512,
            },
            {
                "stop_reason": "end_turn",
                "usage": {
                    "prompt_tokens": 21,
                    "completion_tokens": 24,
                },
                "choices": [
                    {"finish_reason": "stop"},
                ],
            },
        ),
        (
            "mistral",
            "mistral.mistral-7b-instruct-v0:2",
            {
                "prompt": prompt,
                "max_tokens": 4096,
                "temperature": 0.75,
                "top_p": 0.99,
            },
            {
                "outputs": [
                    {
                        "text": "test-output-text",
                        "stop_reason": "stop",
                    },
                ],
            },
        ),
    )

    model_id, request_body, response_body = "", {}, {}
    for marker, fixture_model_id, fixture_request, fixture_response in fixtures:
        if marker in path:
            model_id = fixture_model_id
            request_body = fixture_request
            response_body = fixture_response

    # Wrap the response the way botocore would deliver it: a StreamingBody
    # over the UTF-8 encoded JSON payload.
    payload = json.dumps(response_body).encode("utf-8")
    return model_id, json.dumps(request_body), StreamingBody(BytesIO(payload), len(payload))
506+
507+
381508
def set_main_status(status: int) -> None:
382509
RequestHandler.main_status = status
383510

@@ -490,11 +617,16 @@ def inject_200_success(**kwargs):
490617
guardrail_arn = kwargs.get("guardrailArn")
491618
if guardrail_arn is not None:
492619
response_body["guardrailArn"] = guardrail_arn
620+
model_id = kwargs.get("modelId")
621+
if model_id is not None:
622+
response_body["modelId"] = model_id
493623

494624
HTTPResponse = namedtuple("HTTPResponse", ["status_code", "headers", "body"])
495625
headers = kwargs.get("headers", {})
496626
body = kwargs.get("body", "")
627+
response_body["body"] = body
497628
http_response = HTTPResponse(200, headers=headers, body=body)
629+
498630
return http_response, response_body
499631

500632

contract-tests/tests/test/amazon/base/contract_test_base.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,12 @@ def _assert_int_attribute(self, attributes_dict: Dict[str, AnyValue], key: str,
173173
self.assertIsNotNone(actual_value)
174174
self.assertEqual(expected_value, actual_value.int_value)
175175

176+
def _assert_float_attribute(self, attributes_dict: Dict[str, AnyValue], key: str, expected_value: float) -> None:
    """Assert *attributes_dict* contains *key* and its double_value equals *expected_value*."""
    self.assertIn(key, attributes_dict)
    value: AnyValue = attributes_dict[key]
    self.assertIsNotNone(value)
    self.assertEqual(expected_value, value.double_value)
181+
176182
def _assert_match_attribute(self, attributes_dict: Dict[str, AnyValue], key: str, pattern: str) -> None:
177183
self.assertIn(key, attributes_dict)
178184
actual_value: AnyValue = attributes_dict[key]
@@ -237,5 +243,5 @@ def _is_valid_regex(self, pattern: str) -> bool:
237243
try:
238244
re.compile(pattern)
239245
return True
240-
except re.error:
246+
except (re.error, StopIteration, RuntimeError, KeyError):
241247
return False

0 commit comments

Comments
 (0)