
Commit 9668adb

Add JSON parsing of output.
1 parent 333e1ed commit 9668adb

File tree

1 file changed: +62 -26 lines changed

scripts/snippet_summarize.py

Lines changed: 62 additions & 26 deletions
@@ -1,10 +1,11 @@
 #!/usr/bin/env python
 
 import argparse
+import json
 import logging
 import os
 from pathlib import Path
-from typing import Dict, Optional
+from typing import Dict, List, NewType, Optional, TypedDict
 
 import boto3
 from botocore.exceptions import ClientError
@@ -16,43 +17,78 @@
 logger = logging.getLogger(__name__)
 
 
+class ConversationEntry(TypedDict):
+    role: str
+    content: str
+
+
+Conversation = NewType("Conversation", List[ConversationEntry])
+
+
+class BedrockRuntime:
+    def __init__(self, model_id: str = "us.anthropic.claude-3-7-sonnet-20250219-v1:0"):
+        self.client = boto3.client("bedrock-runtime")
+        self.model_id = model_id
+        self.base_prompt = Path(
+            os.path.dirname(__file__), "base_prompt.txt"
+        ).read_text()
+        self.conversation = [
+            ConversationEntry({"role": "user", "content": [{"text": self.base_prompt}]})
+        ]
+
+    def converse(self, conversation: Conversation):
+        self.conversation.extend(conversation)
+        response = self.client.converse(
+            modelId=self.model_id,
+            messages=self.conversation,
+            inferenceConfig={"maxTokens": 512, "temperature": 0.5, "topP": 0.9},
+        )
+        response_text = response["output"]["message"]["content"][0]["text"]
+        return response_text
+
+
 def make_doc_gen(root: Path):
     doc_gen = DocGen.from_root(root)
     doc_gen.collect_snippets()
     return doc_gen
 
 
+def generate_snippet_description(
+    bedrock_runtime: BedrockRuntime, snippet: Snippet, prompt: Optional[str]
+) -> Dict:
+    content = (
+        [{"text": prompt}, {"text": snippet.code}]
+        if prompt
+        else [{"text": snippet.code}]
+    )
+    conversation = [
+        {
+            "role": "user",
+            "content": content,
+        }
+    ]
+
+    response_text = bedrock_runtime.converse(conversation)
+
+    try:
+        # This assumes the response is JSON, which couples snippet
+        # description generation to a specific prompt.
+        return json.loads(response_text)
+    except Exception as e:
+        logger.warning("Failed to parse response (%s): %s", e, response_text)
+        return {}
+
+
 def generate_descriptions(snippets: Dict[str, Snippet], prompt: Optional[str]):
-    client = boto3.client("bedrock-runtime", region_name="us-west-2")
-    base_prompt = Path(os.path.dirname(__file__), "base_prompt.txt").read_text()
-    model_id = "us.anthropic.claude-3-7-sonnet-20250219-v1:0"
+    runtime = BedrockRuntime()
     results = []
     for snippet_id, snippet in snippets.items():
-        content = [{"text": base_prompt}]
-        if prompt:
-            content.append({"text": prompt})
-        content.append({"text": snippet.code})
-        conversation = [
-            {
-                "role": "user",
-                "content": content,
-            }
-        ]
-
         try:
-            response = client.converse(
-                modelId=model_id,
-                messages=conversation,
-                inferenceConfig={"maxTokens": 512, "temperature": 0.5, "topP": 0.9},
-            )
-
-            # Extract and print the response text.
-            response_text = response["output"]["message"]["content"][0]["text"]
-            results.append(response_text)
-
+            response = generate_snippet_description(runtime, snippet, prompt)
+            results.append(response)
         except (ClientError, Exception) as e:
             logger.warning(
-                f"ERROR: Can't invoke '{model_id}'. Name: {type(e).__name__}, Reason: {e}"
+                f"ERROR: Can't invoke '{runtime.model_id}'. Name: {type(e).__name__}, Reason: {e}"
             )
     print(results)
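The substance of this commit is the JSON-parsing fallback in generate_snippet_description: the model's reply is decoded with json.loads, and anything that is not valid JSON degrades to an empty dict instead of aborting the run. A minimal stand-alone sketch of that behaviour (parse_model_output is a hypothetical helper used only for illustration, not part of the script):

import json
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


def parse_model_output(response_text: str) -> dict:
    # Mirror the commit's fallback: a reply that is not valid JSON is logged
    # and replaced with an empty dict so one bad snippet does not stop the batch.
    try:
        return json.loads(response_text)
    except json.JSONDecodeError:
        logger.warning("Failed to parse response: %s", response_text)
        return {}


print(parse_model_output('{"title": "List buckets", "summary": "Calls s3.list_buckets."}'))
# {'title': 'List buckets', 'summary': 'Calls s3.list_buckets.'}

print(parse_model_output("Sure! Here is a summary of the snippet ..."))
# {} (and a warning is logged)

As the in-diff comment notes, this couples description generation to a prompt that reliably asks for JSON output; a reply that is not JSON is silently collapsed to {} in the results list.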
