|
1 | 1 | #!/usr/bin/env python |
2 | 2 |
|
3 | 3 | import argparse |
| 4 | +import json |
4 | 5 | import logging |
5 | 6 | import os |
6 | 7 | from pathlib import Path |
7 | | -from typing import Dict, Optional |
| 8 | +from typing import Dict, List, NewType, Optional, TypedDict |
8 | 9 |
|
9 | 10 | import boto3 |
10 | 11 | from botocore.exceptions import ClientError |
|
16 | 17 | logger = logging.getLogger(__name__) |
17 | 18 |
|
18 | 19 |
|
| 20 | +class ConversationEntry(TypedDict): |
| 21 | + role: str |
| 22 | + content: str |
| 23 | + |
| 24 | + |
| 25 | +Conversation = NewType("Conversation", List[ConversationEntry]) |
| 26 | + |
| 27 | + |
| 28 | +class BedrockRuntime: |
| 29 | + def __init__(self, model_id: str = "us.anthropic.claude-3-7-sonnet-20250219-v1:0"): |
| 30 | + self.client = boto3.client("bedrock-runtime") |
| 31 | + self.model_id = model_id |
| 32 | + self.base_prompt = Path( |
| 33 | + os.path.dirname(__file__), "base_prompt.txt" |
| 34 | + ).read_text() |
| 35 | + self.conversation = [ |
| 36 | + ConversationEntry({"role": "user", "content": [{"text": self.base_prompt}]}) |
| 37 | + ] |
| 38 | + |
| 39 | + def converse(self, conversation: Conversation): |
| 40 | + self.conversation.extend(conversation) |
| 41 | + response = self.client.converse( |
| 42 | + modelId=self.model_id, |
| 43 | + messages=self.conversation, |
| 44 | + inferenceConfig={"maxTokens": 512, "temperature": 0.5, "topP": 0.9}, |
| 45 | + ) |
| 46 | + response_text = response["output"]["message"]["content"][0]["text"] |
| 47 | + return response_text |
| 48 | + |
| 49 | + |
19 | 50 | def make_doc_gen(root: Path): |
20 | 51 | doc_gen = DocGen.from_root(root) |
21 | 52 | doc_gen.collect_snippets() |
22 | 53 | return doc_gen |
23 | 54 |
|
24 | 55 |
|
| 56 | +def generate_snippet_description( |
| 57 | + bedrock_runtime: BedrockRuntime, snippet: Snippet, prompt: Optional[str] |
| 58 | +) -> Dict: |
| 59 | + content = ( |
| 60 | + [{"text": prompt}, {"text": snippet.code}] |
| 61 | + if prompt |
| 62 | + else [{"text": snippet.code}] |
| 63 | + ) |
| 64 | + conversation = [ |
| 65 | + { |
| 66 | + "role": "user", |
| 67 | + "content": content, |
| 68 | + } |
| 69 | + ] |
| 70 | + |
| 71 | + response_text = bedrock_runtime.converse(conversation) |
| 72 | + |
| 73 | + try: |
| 74 | + # This assumes the response is JSON, which couples snippet |
| 75 | + # description generation to a specific prompt. |
| 76 | + return json.loads(response_text) |
| 77 | + except Exception as e: |
| 78 | + logger.warning("Failed to parse response.", response=response_text) |
| 79 | + return {} |
| 80 | + |
| 81 | + |
25 | 82 | def generate_descriptions(snippets: Dict[str, Snippet], prompt: Optional[str]): |
26 | | - client = boto3.client("bedrock-runtime", region_name="us-west-2") |
27 | | - base_prompt = Path(os.path.dirname(__file__), "base_prompt.txt").read_text() |
28 | | - model_id = "us.anthropic.claude-3-7-sonnet-20250219-v1:0" |
| 83 | + runtime = BedrockRuntime() |
29 | 84 | results = [] |
30 | 85 | for snippet_id, snippet in snippets.items(): |
31 | | - content = [{"text": base_prompt}] |
32 | | - if prompt: |
33 | | - content.append({"text": prompt}) |
34 | | - content.append({"text": snippet.code}) |
35 | | - conversation = [ |
36 | | - { |
37 | | - "role": "user", |
38 | | - "content": content, |
39 | | - } |
40 | | - ] |
41 | | - |
42 | 86 | try: |
43 | | - response = client.converse( |
44 | | - modelId=model_id, |
45 | | - messages=conversation, |
46 | | - inferenceConfig={"maxTokens": 512, "temperature": 0.5, "topP": 0.9}, |
47 | | - ) |
48 | | - |
49 | | - # Extract and print the response text. |
50 | | - response_text = response["output"]["message"]["content"][0]["text"] |
51 | | - results.append(response_text) |
52 | | - |
| 87 | + response = generate_snippet_description(runtime, snippet, prompt) |
| 88 | + results.append(response) |
53 | 89 | except (ClientError, Exception) as e: |
54 | 90 | logger.warning( |
55 | | - f"ERROR: Can't invoke '{model_id}'. Name: {type(e).__name__}, Reason: {e}" |
| 91 | + f"ERROR: Can't invoke '{runtime.model_id}'. Name: {type(e).__name__}, Reason: {e}" |
56 | 92 | ) |
57 | 93 | print(results) |
58 | 94 |
|
|
0 commit comments