Skip to content

Commit e194edb

Browse files
committed
Add Bedrock Anthropic provider support for test generation
Extended CLI and internal logic to support 'bedrock-anthropic' as a provider for test generation. Updated help messages, provider validation, and model handling to accommodate Bedrock Anthropic, including AWS credential requirements and model ID usage. Integrated ChatBedrockAnthropic client and adjusted model validation and selection accordingly.
1 parent 23a66cb commit e194edb

File tree

3 files changed

+91
-42
lines changed

3 files changed

+91
-42
lines changed

shiny/_main.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -555,14 +555,23 @@ def add() -> None:
555555
)
556556
@click.option(
557557
"--provider",
558-
type=click.Choice(["anthropic", "openai"]),
558+
type=click.Choice(["anthropic", "openai", "bedrock-anthropic"]),
559559
default="anthropic",
560-
help="AI provider to use for test generation.",
560+
help=(
561+
"AI provider to use for test generation. For 'bedrock-anthropic', "
562+
"make sure your AWS credentials are configured (env vars, profile, or role) "
563+
"and provide a Bedrock Anthropic model ID (e.g., "
564+
"us.anthropic.claude-3-7-sonnet-20250219-v1:0)."
565+
),
561566
)
562567
@click.option(
563568
"--model",
564569
type=str,
565-
help="Specific model to use (optional). Examples: haiku3.5, sonnet, gpt-5, gpt-5-mini",
570+
help=(
571+
"Specific model to use (optional). Examples: haiku3.5, sonnet, gpt-5, gpt-5-mini; "
572+
"or a Bedrock Anthropic model ID when using provider=bedrock-anthropic, e.g. "
573+
"us.anthropic.claude-3-7-sonnet-20250219-v1:0"
574+
),
566575
)
567576
# Param for app.py, param for test_name
568577
def test(

shiny/_main_generate_test.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,19 +54,27 @@ def validate_api_key(provider: str) -> None:
5454
"env_var": "OPENAI_API_KEY",
5555
"url": "https://platform.openai.com/api-keys",
5656
},
57+
"bedrock-anthropic": {
58+
"env_var": None,
59+
"url": "https://docs.aws.amazon.com/bedrock/latest/userguide/setting-up.html",
60+
},
5761
}
5862

5963
if provider not in api_configs:
6064
raise ValidationError(f"Unsupported provider: {provider}")
6165

6266
config = api_configs[provider]
63-
if not os.getenv(config["env_var"]):
64-
raise ValidationError(
65-
f"{config['env_var']} environment variable is not set.\n"
66-
f"Please set your {provider.title()} API key:\n"
67-
f" export {config['env_var']}='your-api-key-here'\n\n"
68-
f"Get your API key from: {config['url']}"
69-
)
67+
if provider in ("anthropic", "openai"):
68+
env_var = config["env_var"] # type: ignore[assignment]
69+
if not isinstance(env_var, str) or not os.getenv(env_var):
70+
raise ValidationError(
71+
f"{env_var} environment variable is not set.\n"
72+
f"Please set your {provider.title()} API key:\n"
73+
f" export {env_var}='your-api-key-here'\n\n"
74+
f"Get your API key from: {config['url']}"
75+
)
76+
else:
77+
pass
7078

7179

7280
def get_app_file_path(app_file: str | None) -> Path:

shiny/pytest/_generate/_main.py

Lines changed: 64 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from pathlib import Path
99
from typing import Literal, Optional, Tuple, Union
1010

11-
from chatlas import ChatAnthropic, ChatOpenAI, token_usage
11+
from chatlas import ChatAnthropic, ChatBedrockAnthropic, ChatOpenAI, token_usage
1212
from dotenv import load_dotenv
1313

1414
__all__ = [
@@ -32,6 +32,7 @@ class Config:
3232

3333
DEFAULT_ANTHROPIC_MODEL = "claude-sonnet-4-20250514"
3434
DEFAULT_OPENAI_MODEL = "gpt-5-mini-2025-08-07"
35+
DEFAULT_BEDROCK_ANTHROPIC_MODEL = "us.anthropic.claude-3-7-sonnet-20250219-v1:0"
3536
DEFAULT_PROVIDER = "anthropic"
3637

3738
MAX_TOKENS = 8092
@@ -50,7 +51,9 @@ class ShinyTestGenerator:
5051

5152
def __init__(
5253
self,
53-
provider: Literal["anthropic", "openai"] = Config.DEFAULT_PROVIDER,
54+
provider: Literal[
55+
"anthropic", "openai", "bedrock-anthropic"
56+
] = Config.DEFAULT_PROVIDER,
5457
api_key: Optional[str] = None,
5558
log_file: str = Config.LOG_FILE,
5659
setup_logging: bool = True,
@@ -74,25 +77,28 @@ def __init__(
7477
self.setup_logging()
7578

7679
@property
77-
def client(self) -> Union[ChatAnthropic, ChatOpenAI]:
80+
def client(self) -> Union[ChatAnthropic, ChatOpenAI, ChatBedrockAnthropic]:
7881
"""Lazy-loaded chat client based on provider"""
7982
if self._client is None:
80-
if not self.api_key:
81-
env_var = (
82-
"ANTHROPIC_API_KEY"
83-
if self.provider == "anthropic"
84-
else "OPENAI_API_KEY"
85-
)
86-
self.api_key = os.getenv(env_var)
87-
if not self.api_key:
88-
raise ValueError(
89-
f"Missing API key for provider '{self.provider}'. Set the environment variable "
90-
f"{'ANTHROPIC_API_KEY' if self.provider == 'anthropic' else 'OPENAI_API_KEY'} or pass api_key explicitly."
91-
)
83+
if self.provider in ("anthropic", "openai"):
84+
if not self.api_key:
85+
env_var = (
86+
"ANTHROPIC_API_KEY"
87+
if self.provider == "anthropic"
88+
else "OPENAI_API_KEY"
89+
)
90+
self.api_key = os.getenv(env_var)
91+
if not self.api_key:
92+
raise ValueError(
93+
f"Missing API key for provider '{self.provider}'. Set the environment variable "
94+
f"{'ANTHROPIC_API_KEY' if self.provider == 'anthropic' else 'OPENAI_API_KEY'} or pass api_key explicitly."
95+
)
9296
if self.provider == "anthropic":
9397
self._client = ChatAnthropic(api_key=self.api_key)
9498
elif self.provider == "openai":
9599
self._client = ChatOpenAI(api_key=self.api_key)
100+
elif self.provider == "bedrock-anthropic":
101+
self._client = ChatBedrockAnthropic()
96102
else:
97103
raise ValueError(f"Unsupported provider: {self.provider}")
98104
return self._client
@@ -118,6 +124,8 @@ def default_model(self) -> str:
118124
return Config.DEFAULT_ANTHROPIC_MODEL
119125
elif self.provider == "openai":
120126
return Config.DEFAULT_OPENAI_MODEL
127+
elif self.provider == "bedrock-anthropic":
128+
return Config.DEFAULT_BEDROCK_ANTHROPIC_MODEL
121129
else:
122130
raise ValueError(f"Unsupported provider: {self.provider}")
123131

@@ -168,6 +176,15 @@ def _resolve_model(self, model: str) -> str:
168176

169177
def _validate_model_for_provider(self, model: str) -> str:
170178
"""Validate that the model is compatible with the current provider"""
179+
if self.provider == "bedrock-anthropic":
180+
resolved_model = model
181+
if resolved_model.startswith("gpt-") or resolved_model.startswith("o1-"):
182+
raise ValueError(
183+
f"Model '{model}' is an OpenAI model but provider is set to 'bedrock-anthropic'. "
184+
f"Use an Anthropic Bedrock model ID (e.g., 'anthropic.claude-3-5-sonnet-20240620-v1:0')."
185+
)
186+
return resolved_model
187+
171188
resolved_model = self._resolve_model(model)
172189

173190
if self.provider == "anthropic":
@@ -193,18 +210,19 @@ def get_llm_response(self, prompt: str, model: Optional[str] = None) -> str:
193210
model = self._validate_model_for_provider(model)
194211

195212
try:
196-
if not self.api_key:
197-
env_var = (
198-
"ANTHROPIC_API_KEY"
199-
if self.provider == "anthropic"
200-
else "OPENAI_API_KEY"
201-
)
202-
self.api_key = os.getenv(env_var)
203-
if not self.api_key:
204-
raise ValueError(
205-
f"Missing API key for provider '{self.provider}'. Set the environment variable "
206-
f"{'ANTHROPIC_API_KEY' if self.provider == 'anthropic' else 'OPENAI_API_KEY'} or pass api_key."
207-
)
213+
if self.provider in ("anthropic", "openai"):
214+
if not self.api_key:
215+
env_var = (
216+
"ANTHROPIC_API_KEY"
217+
if self.provider == "anthropic"
218+
else "OPENAI_API_KEY"
219+
)
220+
self.api_key = os.getenv(env_var)
221+
if not self.api_key:
222+
raise ValueError(
223+
f"Missing API key for provider '{self.provider}'. Set the environment variable "
224+
f"{'ANTHROPIC_API_KEY' if self.provider == 'anthropic' else 'OPENAI_API_KEY'} or pass api_key."
225+
)
208226
# Create chat client with the specified model
209227
if self.provider == "anthropic":
210228
chat = ChatAnthropic(
@@ -219,22 +237,25 @@ def get_llm_response(self, prompt: str, model: Optional[str] = None) -> str:
219237
system_prompt=self.system_prompt,
220238
api_key=self.api_key,
221239
)
240+
elif self.provider == "bedrock-anthropic":
241+
chat = ChatBedrockAnthropic(
242+
model=model,
243+
system_prompt=self.system_prompt,
244+
max_tokens=Config.MAX_TOKENS,
245+
)
222246
else:
223247
raise ValueError(f"Unsupported provider: {self.provider}")
224248

225249
start_time = time.perf_counter()
226250
response = chat.chat(prompt)
227251
elapsed = time.perf_counter() - start_time
228252
usage = token_usage()
229-
# For Anthropic, token_usage() includes costs. For OpenAI, use chat.get_cost with model pricing.
230253
token_price = None
231254
if self.provider == "openai":
232255
token_price = Config.OPENAI_PRICING.get(model)
233256
try:
234-
# Call to compute and cache costs internally; per-entry cost is computed below
235257
_ = chat.get_cost(options="all", token_price=token_price)
236258
except Exception:
237-
# If cost computation fails, continue without it
238259
pass
239260

240261
try:
@@ -530,7 +551,9 @@ def generate_test_from_code(
530551
)
531552

532553
def switch_provider(
533-
self, provider: Literal["anthropic", "openai"], api_key: Optional[str] = None
554+
self,
555+
provider: Literal["anthropic", "openai", "bedrock-anthropic"],
556+
api_key: Optional[str] = None,
534557
):
535558
self.provider = provider
536559
if api_key:
@@ -549,6 +572,11 @@ def create_openai_generator(
549572
) -> "ShinyTestGenerator":
550573
return cls(provider="openai", api_key=api_key, **kwargs)
551574

575+
@classmethod
576+
def create_bedrock_anthropic_generator(cls, **kwargs) -> "ShinyTestGenerator":
577+
# AWS credentials and region are resolved from environment or AWS config
578+
return cls(provider="bedrock-anthropic", api_key=None, **kwargs)
579+
552580
def get_available_models(self) -> list[str]:
553581
if self.provider == "anthropic":
554582
return [
@@ -562,6 +590,10 @@ def get_available_models(self) -> list[str]:
562590
for model in Config.MODEL_ALIASES.keys()
563591
if (model.startswith("gpt-") or model.startswith("o1-"))
564592
]
593+
elif self.provider == "bedrock-anthropic":
594+
# Bedrock requires full model IDs (e.g., 'us.anthropic.claude-sonnet-4-20250514-v1:0').
595+
# We don't provide aliases here because IDs are region/account specific.
596+
return []
565597
else:
566598
return []
567599

@@ -573,7 +605,7 @@ def cli():
573605
parser.add_argument("app_file", help="Path to the Shiny app file")
574606
parser.add_argument(
575607
"--provider",
576-
choices=["anthropic", "openai"],
608+
choices=["anthropic", "openai", "bedrock-anthropic"],
577609
default=Config.DEFAULT_PROVIDER,
578610
help="LLM provider to use",
579611
)

0 commit comments

Comments
 (0)