@@ -6,13 +6,15 @@
 import tempfile
 from collections import namedtuple
 from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
+from io import BytesIO
 from threading import Thread
 
 import boto3
 import requests
 from botocore.client import BaseClient
 from botocore.config import Config
 from botocore.exceptions import ClientError
+from botocore.response import StreamingBody
 from typing_extensions import Tuple, override
 
 _PORT: int = 8080
@@ -285,28 +287,22 @@ def _handle_bedrock_request(self) -> None:
             },
         )
     elif self.in_path("invokemodel/invoke-model"):
+        model_id, request_body, response_body = get_model_request_response(self.path)
+
         set_main_status(200)
         bedrock_runtime_client.meta.events.register(
             "before-call.bedrock-runtime.InvokeModel",
-            inject_200_success,
-        )
-        model_id = "amazon.titan-text-premier-v1:0"
-        user_message = "Describe the purpose of a 'hello world' program in one line."
-        prompt = f"<s>[INST] {user_message} [/INST]"
-        body = json.dumps(
-            {
-                "inputText": prompt,
-                "textGenerationConfig": {
-                    "maxTokenCount": 3072,
-                    "stopSequences": [],
-                    "temperature": 0.7,
-                    "topP": 0.9,
-                },
-            }
+            lambda **kwargs: inject_200_success(
+                modelId=model_id,
+                body=response_body,
+                **kwargs,
+            ),
         )
         accept = "application/json"
         content_type = "application/json"
-        bedrock_runtime_client.invoke_model(body=body, modelId=model_id, accept=accept, contentType=content_type)
+        bedrock_runtime_client.invoke_model(
+            body=request_body, modelId=model_id, accept=accept, contentType=content_type
+        )
     else:
         set_main_status(404)
 
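The rewritten branch leans on a botocore mechanism worth spelling out: handlers registered for a `before-call.<service>.<Operation>` event may return a value, and if any handler returns non-None, botocore's `BaseClient._make_api_call` unpacks it as `(http_response, parsed_response)` and skips the HTTP request entirely. That is what lets the lambda above feed a canned Bedrock response through `inject_200_success`. A minimal standalone sketch of the mechanism (the handler and payload below are illustrative, not part of this diff):

    # Standalone sketch, not from the diff: short-circuiting a boto3 call
    # via the before-call event.
    from collections import namedtuple

    import boto3

    HTTPResponse = namedtuple("HTTPResponse", ["status_code", "headers", "body"])

    def fake_invoke_model(**kwargs):
        # Returning a non-None (http, parsed) tuple from a before-call handler
        # makes botocore skip the request; no connection is ever opened.
        parsed = {"ResponseMetadata": {"HTTPStatusCode": 200}}
        return HTTPResponse(200, headers={}, body=b"{}"), parsed

    client = boto3.client("bedrock-runtime", region_name="us-west-2")
    client.meta.events.register("before-call.bedrock-runtime.InvokeModel", fake_invoke_model)
    # client.invoke_model(...) would now return `parsed` without contacting AWS.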
@@ -378,6 +374,137 @@ def _end_request(self, status_code: int):
         self.end_headers()
 
 
+def get_model_request_response(path):
+    prompt = "Describe the purpose of a 'hello world' program in one line."
+    model_id = ""
+    request_body = {}
+    response_body = {}
+
+    if "amazon.titan" in path:
+        model_id = "amazon.titan-text-premier-v1:0"
+
+        request_body = {
+            "inputText": prompt,
+            "textGenerationConfig": {
+                "maxTokenCount": 3072,
+                "stopSequences": [],
+                "temperature": 0.7,
+                "topP": 0.9,
+            },
+        }
+
+        response_body = {
+            "inputTextTokenCount": 15,
+            "results": [
+                {
+                    "tokenCount": 13,
+                    "outputText": "text-test-response",
+                    "completionReason": "CONTENT_FILTERED",
+                },
+            ],
+        }
+
+    if "anthropic.claude" in path:
+        model_id = "anthropic.claude-v2:1"
+
+        request_body = {
+            "anthropic_version": "bedrock-2023-05-31",
+            "max_tokens": 1000,
+            "temperature": 0.99,
+            "top_p": 1,
+            "messages": [
+                {
+                    "role": "user",
+                    "content": [{"type": "text", "text": prompt}],
+                },
+            ],
+        }
+
+        response_body = {
+            "stop_reason": "end_turn",
+            "usage": {
+                "input_tokens": 15,
+                "output_tokens": 13,
+            },
+        }
+
+    if "meta.llama" in path:
+        model_id = "meta.llama2-13b-chat-v1"
+
+        request_body = {"prompt": prompt, "max_gen_len": 512, "temperature": 0.5, "top_p": 0.9}
+
+        response_body = {"prompt_token_count": 31, "generation_token_count": 49, "stop_reason": "stop"}
+
+    if "cohere.command" in path:
+        model_id = "cohere.command-r-v1:0"
+
+        request_body = {
+            "chat_history": [],
+            "message": prompt,
+            "max_tokens": 512,
+            "temperature": 0.5,
+            "p": 0.65,
+        }
+
+        response_body = {
+            "chat_history": [
+                {"role": "USER", "message": prompt},
+                {"role": "CHATBOT", "message": "test-text-output"},
+            ],
+            "finish_reason": "COMPLETE",
+            "text": "test-generation-text",
+        }
+
+    if "ai21.jamba" in path:
+        model_id = "ai21.jamba-1-5-large-v1:0"
+
+        request_body = {
+            "messages": [
+                {
+                    "role": "user",
+                    "content": prompt,
+                },
+            ],
+            "top_p": 0.8,
+            "temperature": 0.6,
+            "max_tokens": 512,
+        }
+
+        response_body = {
+            "stop_reason": "end_turn",
+            "usage": {
+                "prompt_tokens": 21,
+                "completion_tokens": 24,
+            },
+            "choices": [
+                {"finish_reason": "stop"},
+            ],
+        }
+
+    if "mistral" in path:
+        model_id = "mistral.mistral-7b-instruct-v0:2"
+
+        request_body = {
+            "prompt": prompt,
+            "max_tokens": 4096,
+            "temperature": 0.75,
+            "top_p": 0.99,
+        }
+
+        response_body = {
+            "outputs": [
+                {
+                    "text": "test-output-text",
+                    "stop_reason": "stop",
+                },
+            ]
+        }
+
+    json_bytes = json.dumps(response_body).encode("utf-8")
+
+    return model_id, json.dumps(request_body), StreamingBody(BytesIO(json_bytes), len(json_bytes))
+
+
 def set_main_status(status: int) -> None:
     RequestHandler.main_status = status
 
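A note on the helper's return value: a real InvokeModel response exposes its payload as a one-shot stream, so the canned JSON is wrapped in `StreamingBody` rather than returned as raw bytes; anything that calls `.read()` on it behaves exactly as it would against a live response. A quick standalone illustration (not part of the diff):

    # Sketch, not from the diff: StreamingBody mimics a real response payload.
    import json
    from io import BytesIO

    from botocore.response import StreamingBody

    payload = json.dumps({"results": [{"tokenCount": 13}]}).encode("utf-8")
    body = StreamingBody(BytesIO(payload), len(payload))

    assert json.loads(body.read()) == {"results": [{"tokenCount": 13}]}
    # The stream is now exhausted, just like a real response body.
    assert body.read() == b""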
@@ -490,11 +617,16 @@ def inject_200_success(**kwargs):
     guardrail_arn = kwargs.get("guardrailArn")
     if guardrail_arn is not None:
         response_body["guardrailArn"] = guardrail_arn
+    model_id = kwargs.get("modelId")
+    if model_id is not None:
+        response_body["modelId"] = model_id
 
     HTTPResponse = namedtuple("HTTPResponse", ["status_code", "headers", "body"])
     headers = kwargs.get("headers", {})
     body = kwargs.get("body", "")
+    response_body["body"] = body
     http_response = HTTPResponse(200, headers=headers, body=body)
+
    return http_response, response_body
 
 
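Net effect of the `inject_200_success` changes: the model id and the prepared body are echoed into the parsed response, so downstream assertions can match on them. Called directly, the handler would behave roughly like this (hypothetical values; the rest of the function sits outside this hunk):

    # Hypothetical direct call; argument values are illustrative only.
    http_response, parsed = inject_200_success(
        modelId="amazon.titan-text-premier-v1:0",
        body=b'{"inputTextTokenCount": 15}',
    )
    assert http_response.status_code == 200
    assert parsed["modelId"] == "amazon.titan-text-premier-v1:0"
    assert parsed["body"] == b'{"inputTextTokenCount": 15}'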