Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/types/src/provider-settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ const bedrockSchema = apiModelIdProviderModelSchema.extend({
awsSecretKey: z.string().optional(),
awsSessionToken: z.string().optional(),
awsRegion: z.string().optional(),
awsCustomRegion: z.string().optional(),
awsUseCrossRegionInference: z.boolean().optional(),
awsUsePromptCache: z.boolean().optional(),
awsProfile: z.string().optional(),
Expand Down
8 changes: 7 additions & 1 deletion packages/types/src/providers/bedrock.ts
Original file line number Diff line number Diff line change
Expand Up @@ -405,4 +405,10 @@ export const BEDROCK_REGIONS = [
{ value: "sa-east-1", label: "sa-east-1" },
{ value: "us-gov-east-1", label: "us-gov-east-1" },
{ value: "us-gov-west-1", label: "us-gov-west-1" },
].sort((a, b) => a.value.localeCompare(b.value))
{ value: "custom", label: "Custom region..." },
].sort((a, b) => {
// Keep "Custom region..." at the end
if (a.value === "custom") return 1
if (b.value === "custom") return -1
return a.value.localeCompare(b.value)
})
107 changes: 107 additions & 0 deletions src/api/providers/__tests__/bedrock-custom-region.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import { describe, it, expect, vi, beforeEach } from "vitest"
import { AwsBedrockHandler } from "../bedrock"
import { BedrockRuntimeClient } from "@aws-sdk/client-bedrock-runtime"

// Mock the AWS SDK
vi.mock("@aws-sdk/client-bedrock-runtime", () => ({
BedrockRuntimeClient: vi.fn().mockImplementation((config) => ({
config,
send: vi.fn(),
})),
ConverseCommand: vi.fn(),
ConverseStreamCommand: vi.fn(),
}))

describe("AwsBedrockHandler - Custom Region Support", () => {
beforeEach(() => {
vi.clearAllMocks()
})

it("should use custom region when awsRegion is 'custom' and awsCustomRegion is provided", () => {
const handler = new AwsBedrockHandler({
apiProvider: "bedrock",
apiModelId: "anthropic.claude-3-sonnet-20240229-v1:0",
awsAccessKey: "test-access-key",
awsSecretKey: "test-secret-key",
awsRegion: "custom",
awsCustomRegion: "us-west-3",
})

// Get the mock instance to check the config
const mockClientInstance = vi.mocked(BedrockRuntimeClient).mock.results[0]?.value
expect(mockClientInstance.config.region).toBe("us-west-3")
})

it("should use standard region when awsRegion is not 'custom'", () => {
const handler = new AwsBedrockHandler({
apiProvider: "bedrock",
apiModelId: "anthropic.claude-3-sonnet-20240229-v1:0",
awsAccessKey: "test-access-key",
awsSecretKey: "test-secret-key",
awsRegion: "us-east-1",
awsCustomRegion: "us-west-3", // This should be ignored
})

// Get the mock instance to check the config
const mockClientInstance = vi.mocked(BedrockRuntimeClient).mock.results[0]?.value
expect(mockClientInstance.config.region).toBe("us-east-1")
})

it("should use awsRegion when awsCustomRegion is not provided", () => {
const handler = new AwsBedrockHandler({
apiProvider: "bedrock",
apiModelId: "anthropic.claude-3-sonnet-20240229-v1:0",
awsAccessKey: "test-access-key",
awsSecretKey: "test-secret-key",
awsRegion: "custom",
// awsCustomRegion is not provided
})

// Get the mock instance to check the config
const mockClientInstance = vi.mocked(BedrockRuntimeClient).mock.results[0]?.value
expect(mockClientInstance.config.region).toBe("custom")
})

it("should use custom region for cross-region inference prefix calculation", () => {
const handler = new AwsBedrockHandler({
apiProvider: "bedrock",
apiModelId: "anthropic.claude-3-sonnet-20240229-v1:0",
awsAccessKey: "test-access-key",
awsSecretKey: "test-secret-key",
awsRegion: "custom",
awsCustomRegion: "us-west-3",
awsUseCrossRegionInference: true,
})

const model = handler.getModel()
// Should have the us. prefix for us-west-3
expect(model.id).toContain("us.")
})

it("should handle custom regions with different prefixes for cross-region inference", () => {
const testCases = [
{ customRegion: "eu-central-3", expectedPrefix: "eu." },
{ customRegion: "ap-southeast-4", expectedPrefix: "apac." },
{ customRegion: "ca-west-1", expectedPrefix: "ca." },
{ customRegion: "sa-east-2", expectedPrefix: "sa." },
{ customRegion: "us-gov-west-2", expectedPrefix: "ug." },
]

for (const { customRegion, expectedPrefix } of testCases) {
vi.clearAllMocks()

const handler = new AwsBedrockHandler({
apiProvider: "bedrock",
apiModelId: "anthropic.claude-3-sonnet-20240229-v1:0",
awsAccessKey: "test-access-key",
awsSecretKey: "test-secret-key",
awsRegion: "custom",
awsCustomRegion: customRegion,
awsUseCrossRegionInference: true,
})

const model = handler.getModel()
expect(model.id).toContain(expectedPrefix)
}
})
})
24 changes: 18 additions & 6 deletions src/api/providers/bedrock.ts
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,11 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
constructor(options: ProviderSettings) {
super()
this.options = options
let region = this.options.awsRegion
// Use custom region if awsRegion is "custom"
let region =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The region resolution logic (using awsRegion === 'custom' && awsCustomRegion) is duplicated (also at line 952). Consider extracting it into a helper function for clarity and consistency.

this.options.awsRegion === "custom" && this.options.awsCustomRegion
? this.options.awsCustomRegion
: this.options.awsRegion

// process the various user input options, be opinionated about the intent of the options
// and determine the model to use during inference and for cost calculations
Expand Down Expand Up @@ -216,7 +220,7 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
this.costModelConfig = this.getModel()

const clientConfig: BedrockRuntimeClientConfig = {
region: this.options.awsRegion,
region: region, // Use the resolved region (either standard or custom)
// Add the endpoint configuration when specified and enabled
...(this.options.awsBedrockEndpoint &&
this.options.awsBedrockEndpointEnabled && { endpoint: this.options.awsBedrockEndpoint }),
Expand Down Expand Up @@ -943,10 +947,18 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
modelConfig = this.getModelById(this.options.apiModelId as string)

// Add cross-region inference prefix if enabled
if (this.options.awsUseCrossRegionInference && this.options.awsRegion) {
const prefix = AwsBedrockHandler.getPrefixForRegion(this.options.awsRegion)
if (prefix) {
modelConfig.id = `${prefix}${modelConfig.id}`
if (this.options.awsUseCrossRegionInference) {
// Use custom region if awsRegion is "custom"
const regionToUse =
this.options.awsRegion === "custom" && this.options.awsCustomRegion
? this.options.awsCustomRegion
: this.options.awsRegion

if (regionToUse) {
const prefix = AwsBedrockHandler.getPrefixForRegion(regionToUse)
if (prefix) {
modelConfig.id = `${prefix}${modelConfig.id}`
}
}
}
}
Expand Down
72 changes: 71 additions & 1 deletion webview-ui/src/components/settings/providers/Bedrock.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue, Standard

import { inputEventTransform, noTransform } from "../transforms"

// AWS region format validation regex
const AWS_REGION_REGEX = /^[a-z]{2,}-[a-z]+-\d+$/

type BedrockProps = {
apiConfiguration: ProviderSettings
setApiConfigurationField: (field: keyof ProviderSettings, value: ProviderSettings[keyof ProviderSettings]) => void
Expand All @@ -18,12 +21,19 @@ type BedrockProps = {
export const Bedrock = ({ apiConfiguration, setApiConfigurationField, selectedModelInfo }: BedrockProps) => {
const { t } = useAppTranslation()
const [awsEndpointSelected, setAwsEndpointSelected] = useState(!!apiConfiguration?.awsBedrockEndpointEnabled)
const [customRegionSelected, setCustomRegionSelected] = useState(apiConfiguration?.awsRegion === "custom")
const [customRegionError, setCustomRegionError] = useState<string | null>(null)

// Update the endpoint enabled state when the configuration changes
useEffect(() => {
setAwsEndpointSelected(!!apiConfiguration?.awsBedrockEndpointEnabled)
}, [apiConfiguration?.awsBedrockEndpointEnabled])

// Update the custom region state when the configuration changes
useEffect(() => {
setCustomRegionSelected(apiConfiguration?.awsRegion === "custom")
}, [apiConfiguration?.awsRegion])

const handleInputChange = useCallback(
<K extends keyof ProviderSettings, E>(
field: K,
Expand All @@ -35,6 +45,34 @@ export const Bedrock = ({ apiConfiguration, setApiConfigurationField, selectedMo
[setApiConfigurationField],
)

// Validate custom region format
const validateCustomRegion = useCallback(
(value: string) => {
if (!value && customRegionSelected) {
setCustomRegionError(t("settings:providers.awsCustomRegion.validation.required"))
return false
}
if (value && !AWS_REGION_REGEX.test(value)) {
setCustomRegionError(t("settings:providers.awsCustomRegion.validation.format"))
return false
}
setCustomRegionError(null)
return true
},
[customRegionSelected, t],
)

// Handle custom region input change with validation
const handleCustomRegionChange = useCallback(
(event: Event | React.FormEvent<HTMLElement>) => {
const target = event.target as HTMLInputElement
const value = target.value
validateCustomRegion(value)
setApiConfigurationField("awsCustomRegion", value)
},
[setApiConfigurationField, validateCustomRegion],
)

return (
<>
<VSCodeRadioGroup
Expand Down Expand Up @@ -89,7 +127,18 @@ export const Bedrock = ({ apiConfiguration, setApiConfigurationField, selectedMo
<label className="block font-medium mb-1">{t("settings:providers.awsRegion")}</label>
<Select
value={apiConfiguration?.awsRegion || ""}
onValueChange={(value) => setApiConfigurationField("awsRegion", value)}>
onValueChange={(value) => {
setApiConfigurationField("awsRegion", value)
setCustomRegionSelected(value === "custom")
// Don't clear custom region when switching away - preserve the value
if (value === "custom" && apiConfiguration?.awsCustomRegion) {
// Validate the existing custom region value
validateCustomRegion(apiConfiguration.awsCustomRegion)
} else {
// Clear validation error when not using custom region
setCustomRegionError(null)
}
}}>
<SelectTrigger className="w-full">
<SelectValue placeholder={t("settings:common.select")} />
</SelectTrigger>
Expand All @@ -102,6 +151,27 @@ export const Bedrock = ({ apiConfiguration, setApiConfigurationField, selectedMo
</SelectContent>
</Select>
</div>
{customRegionSelected && (
<>
<VSCodeTextField
value={apiConfiguration?.awsCustomRegion || ""}
style={{ width: "100%", marginTop: 3, marginBottom: 5 }}
onInput={handleCustomRegionChange}
placeholder={t("settings:providers.awsCustomRegion.placeholder")}
data-testid="custom-region-input"
className={customRegionError ? "error" : ""}
/>
{customRegionError && (
<div className="text-sm text-vscode-errorForeground ml-6 mt-1 mb-2">{customRegionError}</div>
)}
<div className="text-sm text-vscode-descriptionForeground ml-6 mt-1 mb-3">
{t("settings:providers.awsCustomRegion.examples")}
<div className="ml-2">• us-west-3</div>
<div className="ml-2">• eu-central-3</div>
<div className="ml-2">• ap-southeast-3</div>
</div>
</>
)}
<Checkbox
checked={apiConfiguration?.awsUseCrossRegionInference || false}
onChange={handleInputChange("awsUseCrossRegionInference", noTransform)}>
Expand Down
Loading
Loading