Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/types/src/provider-settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ const bedrockSchema = apiModelIdProviderModelSchema.extend({
awsSecretKey: z.string().optional(),
awsSessionToken: z.string().optional(),
awsRegion: z.string().optional(),
awsCustomRegion: z.string().optional(),
awsUseCrossRegionInference: z.boolean().optional(),
awsUsePromptCache: z.boolean().optional(),
awsProfile: z.string().optional(),
Expand Down
8 changes: 7 additions & 1 deletion packages/types/src/providers/bedrock.ts
Original file line number Diff line number Diff line change
Expand Up @@ -405,4 +405,10 @@ export const BEDROCK_REGIONS = [
{ value: "sa-east-1", label: "sa-east-1" },
{ value: "us-gov-east-1", label: "us-gov-east-1" },
{ value: "us-gov-west-1", label: "us-gov-west-1" },
].sort((a, b) => a.value.localeCompare(b.value))
{ value: "custom", label: "Custom region..." },
].sort((a, b) => {
// Keep "Custom region..." at the end
if (a.value === "custom") return 1
if (b.value === "custom") return -1
return a.value.localeCompare(b.value)
})
107 changes: 107 additions & 0 deletions src/api/providers/__tests__/bedrock-custom-region.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import { describe, it, expect, vi, beforeEach } from "vitest"
import { AwsBedrockHandler } from "../bedrock"
import { BedrockRuntimeClient } from "@aws-sdk/client-bedrock-runtime"

// Mock the AWS SDK
vi.mock("@aws-sdk/client-bedrock-runtime", () => ({
BedrockRuntimeClient: vi.fn().mockImplementation((config) => ({
config,
send: vi.fn(),
})),
ConverseCommand: vi.fn(),
ConverseStreamCommand: vi.fn(),
}))

describe("AwsBedrockHandler - Custom Region Support", () => {
beforeEach(() => {
vi.clearAllMocks()
})

it("should use custom region when awsRegion is 'custom' and awsCustomRegion is provided", () => {
const handler = new AwsBedrockHandler({
apiProvider: "bedrock",
apiModelId: "anthropic.claude-3-sonnet-20240229-v1:0",
awsAccessKey: "test-access-key",
awsSecretKey: "test-secret-key",
awsRegion: "custom",
awsCustomRegion: "us-west-3",
})

// Get the mock instance to check the config
const mockClientInstance = vi.mocked(BedrockRuntimeClient).mock.results[0]?.value
expect(mockClientInstance.config.region).toBe("us-west-3")
})

it("should use standard region when awsRegion is not 'custom'", () => {
const handler = new AwsBedrockHandler({
apiProvider: "bedrock",
apiModelId: "anthropic.claude-3-sonnet-20240229-v1:0",
awsAccessKey: "test-access-key",
awsSecretKey: "test-secret-key",
awsRegion: "us-east-1",
awsCustomRegion: "us-west-3", // This should be ignored
})

// Get the mock instance to check the config
const mockClientInstance = vi.mocked(BedrockRuntimeClient).mock.results[0]?.value
expect(mockClientInstance.config.region).toBe("us-east-1")
})

it("should use awsRegion when awsCustomRegion is not provided", () => {
const handler = new AwsBedrockHandler({
apiProvider: "bedrock",
apiModelId: "anthropic.claude-3-sonnet-20240229-v1:0",
awsAccessKey: "test-access-key",
awsSecretKey: "test-secret-key",
awsRegion: "custom",
// awsCustomRegion is not provided
})

// Get the mock instance to check the config
const mockClientInstance = vi.mocked(BedrockRuntimeClient).mock.results[0]?.value
expect(mockClientInstance.config.region).toBe("custom")
})

it("should use custom region for cross-region inference prefix calculation", () => {
const handler = new AwsBedrockHandler({
apiProvider: "bedrock",
apiModelId: "anthropic.claude-3-sonnet-20240229-v1:0",
awsAccessKey: "test-access-key",
awsSecretKey: "test-secret-key",
awsRegion: "custom",
awsCustomRegion: "us-west-3",
awsUseCrossRegionInference: true,
})

const model = handler.getModel()
// Should have the us. prefix for us-west-3
expect(model.id).toContain("us.")
})

it("should handle custom regions with different prefixes for cross-region inference", () => {
const testCases = [
{ customRegion: "eu-central-3", expectedPrefix: "eu." },
{ customRegion: "ap-southeast-4", expectedPrefix: "apac." },
{ customRegion: "ca-west-1", expectedPrefix: "ca." },
{ customRegion: "sa-east-2", expectedPrefix: "sa." },
{ customRegion: "us-gov-west-2", expectedPrefix: "ug." },
]

for (const { customRegion, expectedPrefix } of testCases) {
vi.clearAllMocks()

const handler = new AwsBedrockHandler({
apiProvider: "bedrock",
apiModelId: "anthropic.claude-3-sonnet-20240229-v1:0",
awsAccessKey: "test-access-key",
awsSecretKey: "test-secret-key",
awsRegion: "custom",
awsCustomRegion: customRegion,
awsUseCrossRegionInference: true,
})

const model = handler.getModel()
expect(model.id).toContain(expectedPrefix)
}
})
})
24 changes: 18 additions & 6 deletions src/api/providers/bedrock.ts
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,11 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
constructor(options: ProviderSettings) {
super()
this.options = options
let region = this.options.awsRegion
// Use custom region if awsRegion is "custom"
let region =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The region resolution logic (using awsRegion === 'custom' && awsCustomRegion) is duplicated (also at line 952). Consider extracting it into a helper function for clarity and consistency.

this.options.awsRegion === "custom" && this.options.awsCustomRegion
? this.options.awsCustomRegion
: this.options.awsRegion

// process the various user input options, be opinionated about the intent of the options
// and determine the model to use during inference and for cost calculations
Expand Down Expand Up @@ -216,7 +220,7 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
this.costModelConfig = this.getModel()

const clientConfig: BedrockRuntimeClientConfig = {
region: this.options.awsRegion,
region: region, // Use the resolved region (either standard or custom)
// Add the endpoint configuration when specified and enabled
...(this.options.awsBedrockEndpoint &&
this.options.awsBedrockEndpointEnabled && { endpoint: this.options.awsBedrockEndpoint }),
Expand Down Expand Up @@ -943,10 +947,18 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
modelConfig = this.getModelById(this.options.apiModelId as string)

// Add cross-region inference prefix if enabled
if (this.options.awsUseCrossRegionInference && this.options.awsRegion) {
const prefix = AwsBedrockHandler.getPrefixForRegion(this.options.awsRegion)
if (prefix) {
modelConfig.id = `${prefix}${modelConfig.id}`
if (this.options.awsUseCrossRegionInference) {
// Use custom region if awsRegion is "custom"
const regionToUse =
this.options.awsRegion === "custom" && this.options.awsCustomRegion
? this.options.awsCustomRegion
: this.options.awsRegion

if (regionToUse) {
const prefix = AwsBedrockHandler.getPrefixForRegion(regionToUse)
if (prefix) {
modelConfig.id = `${prefix}${modelConfig.id}`
}
}
}
}
Expand Down
72 changes: 71 additions & 1 deletion webview-ui/src/components/settings/providers/Bedrock.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue, Standard

import { inputEventTransform, noTransform } from "../transforms"

// AWS region format validation regex
const AWS_REGION_REGEX = /^[a-z]{2,}-[a-z]+-\d+$/

type BedrockProps = {
apiConfiguration: ProviderSettings
setApiConfigurationField: (field: keyof ProviderSettings, value: ProviderSettings[keyof ProviderSettings]) => void
Expand All @@ -18,12 +21,19 @@ type BedrockProps = {
export const Bedrock = ({ apiConfiguration, setApiConfigurationField, selectedModelInfo }: BedrockProps) => {
const { t } = useAppTranslation()
const [awsEndpointSelected, setAwsEndpointSelected] = useState(!!apiConfiguration?.awsBedrockEndpointEnabled)
const [customRegionSelected, setCustomRegionSelected] = useState(apiConfiguration?.awsRegion === "custom")
const [customRegionError, setCustomRegionError] = useState<string | null>(null)

// Update the endpoint enabled state when the configuration changes
useEffect(() => {
setAwsEndpointSelected(!!apiConfiguration?.awsBedrockEndpointEnabled)
}, [apiConfiguration?.awsBedrockEndpointEnabled])

// Update the custom region state when the configuration changes
useEffect(() => {
setCustomRegionSelected(apiConfiguration?.awsRegion === "custom")
}, [apiConfiguration?.awsRegion])

const handleInputChange = useCallback(
<K extends keyof ProviderSettings, E>(
field: K,
Expand All @@ -35,6 +45,34 @@ export const Bedrock = ({ apiConfiguration, setApiConfigurationField, selectedMo
[setApiConfigurationField],
)

// Validate custom region format
const validateCustomRegion = useCallback(
(value: string) => {
if (!value && customRegionSelected) {
setCustomRegionError(t("settings:providers.awsCustomRegion.validation.required"))
return false
}
if (value && !AWS_REGION_REGEX.test(value)) {
setCustomRegionError(t("settings:providers.awsCustomRegion.validation.format"))
return false
}
setCustomRegionError(null)
return true
},
[customRegionSelected, t],
)

// Handle custom region input change with validation
const handleCustomRegionChange = useCallback(
(event: Event | React.FormEvent<HTMLElement>) => {
const target = event.target as HTMLInputElement
const value = target.value
validateCustomRegion(value)
setApiConfigurationField("awsCustomRegion", value)
},
[setApiConfigurationField, validateCustomRegion],
)

return (
<>
<VSCodeRadioGroup
Expand Down Expand Up @@ -89,7 +127,18 @@ export const Bedrock = ({ apiConfiguration, setApiConfigurationField, selectedMo
<label className="block font-medium mb-1">{t("settings:providers.awsRegion")}</label>
<Select
value={apiConfiguration?.awsRegion || ""}
onValueChange={(value) => setApiConfigurationField("awsRegion", value)}>
onValueChange={(value) => {
setApiConfigurationField("awsRegion", value)
setCustomRegionSelected(value === "custom")
// Don't clear custom region when switching away - preserve the value
if (value === "custom" && apiConfiguration?.awsCustomRegion) {
// Validate the existing custom region value
validateCustomRegion(apiConfiguration.awsCustomRegion)
} else {
// Clear validation error when not using custom region
setCustomRegionError(null)
}
}}>
<SelectTrigger className="w-full">
<SelectValue placeholder={t("settings:common.select")} />
</SelectTrigger>
Expand All @@ -102,6 +151,27 @@ export const Bedrock = ({ apiConfiguration, setApiConfigurationField, selectedMo
</SelectContent>
</Select>
</div>
{customRegionSelected && (
<>
<VSCodeTextField
value={apiConfiguration?.awsCustomRegion || ""}
style={{ width: "100%", marginTop: 3, marginBottom: 5 }}
onInput={handleCustomRegionChange}
placeholder={t("settings:providers.awsCustomRegion.placeholder")}
data-testid="custom-region-input"
className={customRegionError ? "error" : ""}
/>
{customRegionError && (
<div className="text-sm text-vscode-errorForeground ml-6 mt-1 mb-2">{customRegionError}</div>
)}
<div className="text-sm text-vscode-descriptionForeground ml-6 mt-1 mb-3">
{t("settings:providers.awsCustomRegion.examples")}
<div className="ml-2">• us-west-3</div>
<div className="ml-2">• eu-central-3</div>
<div className="ml-2">• ap-southeast-3</div>
</div>
</>
)}
<Checkbox
checked={apiConfiguration?.awsUseCrossRegionInference || false}
onChange={handleInputChange("awsUseCrossRegionInference", noTransform)}>
Expand Down
Loading
Loading